diff --git a/.gitignore b/.gitignore index ecb6570..737b9a4 100644 --- a/.gitignore +++ b/.gitignore @@ -94,9 +94,7 @@ ENV/ # Rope project settings .ropeproject -viera.yaml -viera.jsonld - -# config file -relay.yaml -relay.jsonld +# config and database +*.yaml +*.jsonld +*.sqlite3 diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..83229d2 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include data/statements.sql diff --git a/relay.spec b/relay.spec index 57fedc7..535e5fe 100644 --- a/relay.spec +++ b/relay.spec @@ -5,40 +5,43 @@ block_cipher = None a = Analysis( - ['relay/__main__.py'], - pathex=[], - binaries=[], - datas=[], - hiddenimports=[], - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - win_no_prefer_redirects=False, - win_private_assemblies=False, - cipher=block_cipher, - noarchive=False, + ['relay/__main__.py'], + pathex=[], + binaries=[], + datas=[ + ('relay/data', 'relay/data') + ], + hiddenimports=[], + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False, ) + pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) exe = EXE( - pyz, - a.scripts, - a.binaries, - a.zipfiles, - a.datas, - [], - name='activityrelay', - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=True, - upx_exclude=[], - runtime_tmpdir=None, - console=True, - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, + pyz, + a.scripts, + a.binaries, + a.zipfiles, + a.datas, + [], + name='activityrelay', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + upx_exclude=[], + runtime_tmpdir=None, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, ) diff --git a/relay.yaml.example b/relay.yaml.example index 
4e35697..90b9e8f 100644 --- a/relay.yaml.example +++ b/relay.yaml.example @@ -1,43 +1,35 @@ -# this is the path that the object graph will get dumped to (in JSON-LD format), -# you probably shouldn't change it, but you can if you want. -db: relay.jsonld +# [string] Domain the relay will be hosted on +domain: relay.example.com -# Listener +# [string] Address the relay will listen on listen: 0.0.0.0 + +# [integer] Port the relay will listen on port: 8080 -# Note -note: "Make a note about your instance here." +# [integer] Number of push workers to start (will get removed in a future update) +workers: 8 -# Number of worker threads to start. If 0, use asyncio futures instead of threads. -workers: 0 +# [string] Database backend to use. Valid values: sqlite, postgres +database_type: sqlite -# Maximum number of inbox posts to do at once -# If workers is set to 1 or above, this is the max for each worker -push_limit: 512 +# [string] Path to the sqlite database file if the sqlite backend is in use +sqlite_path: relay.sqlite3 -# The amount of json objects to cache from GET requests -json_cache: 1024 +# settings for the postgresql backend +postgres: -ap: - # This is used for generating activitypub messages, as well as instructions for - # linking AP identities. It should be an SSL-enabled domain reachable by https. - host: 'relay.example.com' + # [string] hostname or unix socket to connect to + host: /var/run/postgresql - blocked_instances: - - 'bad-instance.example.com' - - 'another-bad-instance.example.com' + # [integer] port of the server + port: 5432 - whitelist_enabled: false + # [string] username to use when logging into the server (default is the current system username) + user: null - whitelist: - - 'good-instance.example.com' - - 'another.good-instance.example.com' + # [string] password of the user + pass: null - # uncomment the lines below to prevent certain activitypub software from posting - # to the relay (all known relays by default). 
this uses the software name in nodeinfo - #blocked_software: - #- 'activityrelay' - #- 'aoderelay' - #- 'social.seattle.wa.us-relay' - #- 'unciarelay' + # [string] name of the database to use + name: activityrelay diff --git a/relay/application.py b/relay/application.py index a01aaec..9440098 100644 --- a/relay/application.py +++ b/relay/application.py @@ -8,52 +8,41 @@ import traceback import typing from aiohttp import web +from aputils.signer import Signer from datetime import datetime, timedelta from . import logger as logging -from .config import RelayConfig -from .database import RelayDatabase +from .config import Config +from .database import get_database from .http_client import HttpClient from .misc import check_open_port from .views import VIEWS if typing.TYPE_CHECKING: + from tinysql import Database from typing import Any from .misc import Message # pylint: disable=unsubscriptable-object - class Application(web.Application): + DEFAULT: Application = None + def __init__(self, cfgpath: str): web.Application.__init__(self) + Application.DEFAULT = self + + self['signer'] = None + self['config'] = Config(cfgpath, load = True) + self['database'] = get_database(self.config) + self['client'] = HttpClient() + self['workers'] = [] self['last_worker'] = 0 self['start_time'] = None self['running'] = False - self['config'] = RelayConfig(cfgpath) - - if not self.config.load(): - self.config.save() - - if self.config.is_docker: - self.config.update({ - 'db': '/data/relay.jsonld', - 'listen': '0.0.0.0', - 'port': 8080 - }) - - self['database'] = RelayDatabase(self.config) - self.database.load() - - self['client'] = HttpClient( - database = self.database, - limit = self.config.push_limit, - timeout = self.config.timeout, - cache_size = self.config.json_cache - ) for path, view in VIEWS: self.router.add_view(path, view) @@ -65,15 +54,29 @@ class Application(web.Application): @property - def config(self) -> RelayConfig: + def config(self) -> Config: return self['config'] 
@property - def database(self) -> RelayDatabase: + def database(self) -> Database: return self['database'] + @property + def signer(self) -> Signer: + return self['signer'] + + + @signer.setter + def signer(self, value: Signer | str) -> None: + if isinstance(value, Signer): + self['signer'] = value + return + + self['signer'] = Signer(value, self.config.keyid) + + @property def uptime(self) -> timedelta: if not self['start_time']: @@ -118,7 +121,7 @@ class Application(web.Application): logging.info( 'Starting webserver at %s (%s:%i)', - self.config.host, + self.config.domain, self.config.listen, self.config.port ) @@ -179,12 +182,7 @@ class PushWorker(threading.Thread): async def handle_queue(self) -> None: - self.client = HttpClient( - database = self.app.database, - limit = self.app.config.push_limit, - timeout = self.app.config.timeout, - cache_size = self.app.config.json_cache - ) + self.client = HttpClient() while self.app['running']: try: diff --git a/relay/database.py b/relay/compat.py similarity index 64% rename from relay/database.py rename to relay/compat.py index 5d059dd..16d6461 100644 --- a/relay/database.py +++ b/relay/compat.py @@ -1,17 +1,128 @@ from __future__ import annotations import json +import os import typing +import yaml -from aputils.signer import Signer +from functools import cached_property +from pathlib import Path from urllib.parse import urlparse from . 
import logger as logging +from .misc import Message, boolean if typing.TYPE_CHECKING: - from typing import Iterator, Optional - from .config import RelayConfig - from .misc import Message + from typing import Any, Iterator, Optional + + +# pylint: disable=duplicate-code + +class RelayConfig(dict): + def __init__(self, path: str): + dict.__init__(self, {}) + + if self.is_docker: + path = '/data/config.yaml' + + self._path = Path(path).expanduser().resolve() + self.reset() + + + def __setitem__(self, key: str, value: Any) -> None: + if key in ['blocked_instances', 'blocked_software', 'whitelist']: + assert isinstance(value, (list, set, tuple)) + + elif key in ['port', 'workers', 'json_cache', 'timeout']: + if not isinstance(value, int): + value = int(value) + + elif key == 'whitelist_enabled': + if not isinstance(value, bool): + value = boolean(value) + + super().__setitem__(key, value) + + + @property + def db(self) -> RelayDatabase: + return Path(self['db']).expanduser().resolve() + + + @property + def actor(self) -> str: + return f'https://{self["host"]}/actor' + + + @property + def inbox(self) -> str: + return f'https://{self["host"]}/inbox' + + + @property + def keyid(self) -> str: + return f'{self.actor}#main-key' + + + @cached_property + def is_docker(self) -> bool: + return bool(os.environ.get('DOCKER_RUNNING')) + + + def reset(self) -> None: + self.clear() + self.update({ + 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')), + 'listen': '0.0.0.0', + 'port': 8080, + 'note': 'Make a note about your instance here.', + 'push_limit': 512, + 'json_cache': 1024, + 'timeout': 10, + 'workers': 0, + 'host': 'relay.example.com', + 'whitelist_enabled': False, + 'blocked_software': [], + 'blocked_instances': [], + 'whitelist': [] + }) + + + def load(self) -> None: + self.reset() + + options = {} + + try: + options['Loader'] = yaml.FullLoader + + except AttributeError: + pass + + try: + with self._path.open('r', encoding = 'UTF-8') as fd: + config = 
yaml.load(fd, **options) + + except FileNotFoundError: + return + + if not config: + return + + for key, value in config.items(): + if key in ['ap']: + for k, v in value.items(): + if k not in self: + continue + + self[k] = v + + continue + + if key not in self: + continue + + self[key] = value class RelayDatabase(dict): @@ -37,9 +148,7 @@ class RelayDatabase(dict): return tuple(data['inbox'] for data in self['relay-list'].values()) - def load(self) -> bool: - new_db = True - + def load(self) -> None: try: with self.config.db.open() as fd: data = json.load(fd) @@ -65,17 +174,9 @@ class RelayDatabase(dict): self['relay-list'] = data.get('relay-list', {}) for domain, instance in self['relay-list'].items(): - if self.config.is_banned(domain) or \ - (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)): - - self.del_inbox(domain) - continue - if not instance.get('domain'): instance['domain'] = domain - new_db = False - except FileNotFoundError: pass @@ -83,17 +184,6 @@ class RelayDatabase(dict): if self.config.db.stat().st_size > 0: raise e from None - if not self['private-key']: - logging.info('No actor keys present, generating 4096-bit RSA keypair.') - self.signer = Signer.new(self.config.keyid, size=4096) - self['private-key'] = self.signer.export() - - else: - self.signer = Signer(self['private-key'], self.config.keyid) - - self.save() - return not new_db - def save(self) -> None: with self.config.db.open('w', encoding = 'UTF-8') as fd: diff --git a/relay/config.py b/relay/config.py index e684ead..937372f 100644 --- a/relay/config.py +++ b/relay/config.py @@ -1,76 +1,76 @@ from __future__ import annotations +import getpass import os import typing import yaml -from functools import cached_property from pathlib import Path -from urllib.parse import urlparse -from .misc import DotDict, boolean +from .misc import IS_DOCKER if typing.TYPE_CHECKING: - from typing import Any - from .database import RelayDatabase + from typing import Any, Optional 
-RELAY_SOFTWARE = [ - 'activityrelay', # https://git.pleroma.social/pleroma/relay - 'aoderelay', # https://git.asonix.dog/asonix/relay - 'feditools-relay' # https://git.ptzo.gdn/feditools/relay -] +DEFAULTS: dict[str, Any] = { + 'listen': '0.0.0.0', + 'port': 8080, + 'domain': 'relay.example.com', + 'workers': len(os.sched_getaffinity(0)), + 'db_type': 'sqlite', + 'sq_path': 'relay.sqlite3', + 'pg_host': '/var/run/postgresql', + 'pg_port': 5432, + 'pg_user': getpass.getuser(), + 'pg_pass': None, + 'pg_name': 'activityrelay' +} -APKEYS = [ - 'host', - 'whitelist_enabled', - 'blocked_software', - 'blocked_instances', - 'whitelist' -] +if IS_DOCKER: + DEFAULTS['sq_path'] = '/data/relay.jsonld' -class RelayConfig(DotDict): - __slots__ = ('path', ) +class Config: + def __init__(self, path: str, load: Optional[bool] = False): + self.path = Path(path).expanduser().resolve() - def __init__(self, path: str | Path): - DotDict.__init__(self, {}) + self.listen = None + self.port = None + self.domain = None + self.workers = None + self.db_type = None + self.sq_path = None + self.pg_host = None + self.pg_port = None + self.pg_user = None + self.pg_pass = None + self.pg_name = None - if self.is_docker: - path = '/data/config.yaml' + if load: + try: + self.load() - self._path = Path(path).expanduser().resolve() - self.reset() - - - def __setitem__(self, key: str, value: Any) -> None: - if key in ['blocked_instances', 'blocked_software', 'whitelist']: - assert isinstance(value, (list, set, tuple)) - - elif key in ['port', 'workers', 'json_cache', 'timeout']: - if not isinstance(value, int): - value = int(value) - - elif key == 'whitelist_enabled': - if not isinstance(value, bool): - value = boolean(value) - - super().__setitem__(key, value) + except FileNotFoundError: + self.save() @property - def db(self) -> RelayDatabase: - return Path(self['db']).expanduser().resolve() + def sqlite_path(self) -> Path: + if not os.path.isabs(self.sq_path): + return 
self.path.parent.joinpath(self.sq_path).resolve() + + return Path(self.sq_path).expanduser().resolve() @property def actor(self) -> str: - return f'https://{self.host}/actor' + return f'https://{self.domain}/actor' @property def inbox(self) -> str: - return f'https://{self.host}/inbox' + return f'https://{self.domain}/inbox' @property @@ -78,115 +78,7 @@ class RelayConfig(DotDict): return f'{self.actor}#main-key' - @cached_property - def is_docker(self) -> bool: - return bool(os.environ.get('DOCKER_RUNNING')) - - - def reset(self) -> None: - self.clear() - self.update({ - 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')), - 'listen': '0.0.0.0', - 'port': 8080, - 'note': 'Make a note about your instance here.', - 'push_limit': 512, - 'json_cache': 1024, - 'timeout': 10, - 'workers': 0, - 'host': 'relay.example.com', - 'whitelist_enabled': False, - 'blocked_software': [], - 'blocked_instances': [], - 'whitelist': [] - }) - - - def ban_instance(self, instance: str) -> bool: - if instance.startswith('http'): - instance = urlparse(instance).hostname - - if self.is_banned(instance): - return False - - self.blocked_instances.append(instance) - return True - - - def unban_instance(self, instance: str) -> bool: - if instance.startswith('http'): - instance = urlparse(instance).hostname - - try: - self.blocked_instances.remove(instance) - return True - - except ValueError: - return False - - - def ban_software(self, software: str) -> bool: - if self.is_banned_software(software): - return False - - self.blocked_software.append(software) - return True - - - def unban_software(self, software: str) -> bool: - try: - self.blocked_software.remove(software) - return True - - except ValueError: - return False - - - def add_whitelist(self, instance: str) -> bool: - if instance.startswith('http'): - instance = urlparse(instance).hostname - - if self.is_whitelisted(instance): - return False - - self.whitelist.append(instance) - return True - - - def del_whitelist(self, 
instance: str) -> bool: - if instance.startswith('http'): - instance = urlparse(instance).hostname - - try: - self.whitelist.remove(instance) - return True - - except ValueError: - return False - - - def is_banned(self, instance: str) -> bool: - if instance.startswith('http'): - instance = urlparse(instance).hostname - - return instance in self.blocked_instances - - - def is_banned_software(self, software: str) -> bool: - if not software: - return False - - return software.lower() in self.blocked_software - - - def is_whitelisted(self, instance: str) -> bool: - if instance.startswith('http'): - instance = urlparse(instance).hostname - - return instance in self.whitelist - - - def load(self) -> bool: + def load(self) -> None: self.reset() options = {} @@ -197,50 +89,69 @@ class RelayConfig(DotDict): except AttributeError: pass - try: - with self._path.open('r', encoding = 'UTF-8') as fd: - config = yaml.load(fd, **options) - - except FileNotFoundError: - return False + with self.path.open('r', encoding = 'UTF-8') as fd: + config = yaml.load(fd, **options) + pgcfg = config.get('postgresql', {}) if not config: - return False + raise ValueError('Config is empty') - for key, value in config.items(): - if key in ['ap']: - for k, v in value.items(): - if k not in self: - continue + if IS_DOCKER: + self.listen = '0.0.0.0' + self.port = 8080 + self.sq_path = '/data/relay.jsonld' - self[k] = v + else: + self.set('listen', config.get('listen', DEFAULTS['listen'])) + self.set('port', config.get('port', DEFAULTS['port'])) + self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path'])) + self.set('domain', config.get('domain', DEFAULTS['domain'])) + self.set('db_type', config.get('database_type', DEFAULTS['db_type'])) + + for key in DEFAULTS: + if not key.startswith('pg'): continue - if key not in self: + try: + self.set(key, pgcfg[key[3:]]) + + except KeyError: continue - self[key] = value - if self.host.endswith('example.com'): - return False - - return True + def 
reset(self) -> None: + for key, value in DEFAULTS.items(): + setattr(self, key, value) def save(self) -> None: + self.path.parent.mkdir(exist_ok = True, parents = True) + config = { - # just turning config.db into a string is good enough for now - 'db': str(self.db), 'listen': self.listen, 'port': self.port, - 'note': self.note, - 'push_limit': self.push_limit, - 'workers': self.workers, - 'json_cache': self.json_cache, - 'timeout': self.timeout, - 'ap': {key: self[key] for key in APKEYS} + 'domain': self.domain, + 'database_type': self.db_type, + 'sqlite_path': self.sq_path, + 'postgres': { + 'host': self.pg_host, + 'port': self.pg_port, + 'user': self.pg_user, + 'pass': self.pg_pass, + 'name': self.pg_name + } } - with self._path.open('w', encoding = 'utf-8') as fd: - yaml.dump(config, fd, sort_keys=False) + with self.path.open('w', encoding = 'utf-8') as fd: + yaml.dump(config, fd, sort_keys = False) + + + def set(self, key: str, value: Any) -> None: + if key not in DEFAULTS: + raise KeyError(key) + + if key in ('port', 'pg_port', 'workers') and not isinstance(value, int): + value = int(value) + + setattr(self, key, value) diff --git a/relay/data/statements.sql b/relay/data/statements.sql new file mode 100644 index 0000000..a262feb --- /dev/null +++ b/relay/data/statements.sql @@ -0,0 +1,79 @@ +-- name: get-config +SELECT * FROM config WHERE key = :key + + +-- name: get-config-all +SELECT * FROM config + + +-- name: put-config +INSERT INTO config (key, value, type) +VALUES (:key, :value, :type) +ON CONFLICT (key) DO UPDATE SET value = :value +RETURNING * + + +-- name: del-config +DELETE FROM config +WHERE key = :key + + +-- name: get-inbox +SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value + + +-- name: put-inbox +INSERT INTO inboxes (domain, actor, inbox, followid, software, created) +VALUES (:domain, :actor, :inbox, :followid, :software, :created) +ON CONFLICT (domain) DO UPDATE SET followid = :followid +RETURNING * + + +-- name: 
del-inbox +DELETE FROM inboxes +WHERE domain = :value or inbox = :value or actor = :value + + +-- name: get-software-ban +SELECT * FROM software_bans WHERE name = :name + + +-- name: put-software-ban +INSERT INTO software_bans (name, reason, note, created) +VALUES (:name, :reason, :note, :created) +RETURNING * + + +-- name: del-software-ban +DELETE FROM software_bans +WHERE name = :name + + +-- name: get-domain-ban +SELECT * FROM domain_bans WHERE domain = :domain + + +-- name: put-domain-ban +INSERT INTO domain_bans (domain, reason, note, created) +VALUES (:domain, :reason, :note, :created) +RETURNING * + + +-- name: del-domain-ban +DELETE FROM domain_bans +WHERE domain = :domain + + +-- name: get-domain-whitelist +SELECT * FROM whitelist WHERE domain = :domain + + +-- name: put-domain-whitelist +INSERT INTO whitelist (domain, created) +VALUES (:domain, :created) +RETURNING * + + +-- name: del-domain-whitelist +DELETE FROM whitelist +WHERE domain = :domain diff --git a/relay/database/__init__.py b/relay/database/__init__.py new file mode 100644 index 0000000..925c5e0 --- /dev/null +++ b/relay/database/__init__.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import tinysql +import typing + +from importlib.resources import files as pkgfiles + +from .config import get_default_value +from .connection import Connection +from .schema import VERSIONS, migrate_0 + +from .. 
import logger as logging + +if typing.TYPE_CHECKING: + from typing import Optional + from .config import Config + + +def get_database(config: Config, migrate: Optional[bool] = True) -> tinysql.Database: + if config.db_type == "sqlite": + db = tinysql.Database.sqlite(config.sqlite_path, connection_class = Connection) + + elif config.db_type == "postgres": + db = tinysql.Database.postgres( + config.pg_name, + config.pg_host, + config.pg_port, + config.pg_user, + config.pg_pass, + connection_class = Connection + ) + + db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql")) + + if not migrate: + return db + + with db.connection() as conn: + if 'config' not in conn.get_tables(): + logging.info("Creating database tables") + migrate_0(conn) + return db + + schema_ver = conn.get_config('schema-version') + + if schema_ver < get_default_value('schema-version'): + logging.info("Migrating database from version '%i'", schema_ver) + + for ver, func in VERSIONS: + if schema_ver < ver: + conn.begin() + + func(conn) + + conn.put_config('schema-version', ver) + conn.commit() + + if (privkey := conn.get_config('private-key')): + conn.app.signer = privkey + + logging.set_level(conn.get_config('log-level')) + + return db diff --git a/relay/database/config.py b/relay/database/config.py new file mode 100644 index 0000000..e132647 --- /dev/null +++ b/relay/database/config.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import typing + +from .. 
import logger as logging +from ..misc import boolean + +if typing.TYPE_CHECKING: + from typing import Any, Callable + + +CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = { + 'schema-version': ('int', 20240119), + 'log-level': ('loglevel', logging.LogLevel.INFO), + 'note': ('str', 'Make a note about your instance here.'), + 'private-key': ('str', None), + 'whitelist-enabled': ('bool', False) +} + +# serializer | deserializer +CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = { + 'str': (str, str), + 'int': (str, int), + 'bool': (str, boolean), + 'loglevel': (lambda x: x.name, logging.LogLevel.parse) +} + + +def get_default_value(key: str) -> Any: + return CONFIG_DEFAULTS[key][1] + + +def get_default_type(key: str) -> str: + return CONFIG_DEFAULTS[key][0] + + +def serialize(key: str, value: Any) -> str: + type_name = get_default_type(key) + return CONFIG_CONVERT[type_name][0](value) + + +def deserialize(key: str, value: str) -> Any: + type_name = get_default_type(key) + return CONFIG_CONVERT[type_name][1](value) diff --git a/relay/database/connection.py b/relay/database/connection.py new file mode 100644 index 0000000..43bbb7e --- /dev/null +++ b/relay/database/connection.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import tinysql +import typing + +from datetime import datetime, timezone +from urllib.parse import urlparse + +from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize + +from .. 
import logger as logging +from ..misc import get_app + +if typing.TYPE_CHECKING: + from tinysql import Cursor, Row + from typing import Any, Iterator, Optional + from .application import Application + from ..misc import Message + + +RELAY_SOFTWARE = [ + 'activityrelay', # https://git.pleroma.social/pleroma/relay + 'activity-relay', # https://github.com/yukimochi/Activity-Relay + 'aoderelay', # https://git.asonix.dog/asonix/relay + 'feditools-relay' # https://git.ptzo.gdn/feditools/relay +] + + +class Connection(tinysql.Connection): + @property + def app(self) -> Application: + return get_app() + + + def distill_inboxes(self, message: Message) -> Iterator[str]: + src_domains = { + message.domain, + urlparse(message.object_id).netloc + } + + for inbox in self.execute('SELECT * FROM inboxes'): + if inbox['domain'] not in src_domains: + yield inbox['inbox'] + + + def exec_statement(self, name: str, params: Optional[dict[str, Any]] = None) -> Cursor: + return self.execute(self.database.prepared_statements[name], params) + + + def get_config(self, key: str) -> Any: + if key not in CONFIG_DEFAULTS: + raise KeyError(key) + + with self.exec_statement('get-config', {'key': key}) as cur: + if not (row := cur.one()): + return get_default_value(key) + + if row['value']: + return deserialize(row['key'], row['value']) + + return None + + + def get_config_all(self) -> dict[str, Any]: + with self.exec_statement('get-config-all') as cur: + db_config = {row['key']: row['value'] for row in cur} + + config = {} + + for key, data in CONFIG_DEFAULTS.items(): + try: + config[key] = deserialize(key, db_config[key]) + + except KeyError: + if key == 'schema-version': + config[key] = 0 + + else: + config[key] = data[1] + + return config + + + def put_config(self, key: str, value: Any) -> Any: + if key not in CONFIG_DEFAULTS: + raise KeyError(key) + + if key == 'private-key': + self.app.signer = value + + elif key == 'log-level': + value = logging.LogLevel.parse(value) + 
logging.set_level(value) + + params = { + 'key': key, + 'value': serialize(key, value) if value is not None else None, + 'type': get_default_type(key) + } + + with self.exec_statement('put-config', params): + return value + + + def get_inbox(self, value: str) -> Row: + with self.exec_statement('get-inbox', {'value': value}) as cur: + return cur.one() + + + def put_inbox(self, + domain: str, + inbox: str, + actor: Optional[str] = None, + followid: Optional[str] = None, + software: Optional[str] = None) -> Row: + + params = { + 'domain': domain, + 'inbox': inbox, + 'actor': actor, + 'followid': followid, + 'software': software, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-inbox', params) as cur: + return cur.one() + + + def update_inbox(self, + inbox: str, + actor: Optional[str] = None, + followid: Optional[str] = None, + software: Optional[str] = None) -> Row: + + if not (actor or followid or software): + raise ValueError('Missing "actor", "followid", and/or "software"') + + data = {} + + if actor: + data['actor'] = actor + + if followid: + data['followid'] = followid + + if software: + data['software'] = software + + statement = tinysql.Update('inboxes', data, inbox = inbox) + + with self.query(statement): + return self.get_inbox(inbox) + + + def del_inbox(self, value: str) -> bool: + with self.exec_statement('del-inbox', {'value': value}) as cur: + if cur.modified_row_count > 1: + raise ValueError('More than one row was modified') + + return cur.modified_row_count == 1 + + + def get_domain_ban(self, domain: str) -> Row: + if domain.startswith('http'): + domain = urlparse(domain).netloc + + with self.exec_statement('get-domain-ban', {'domain': domain}) as cur: + return cur.one() + + + def put_domain_ban(self, + domain: str, + reason: Optional[str] = None, + note: Optional[str] = None) -> Row: + + params = { + 'domain': domain, + 'reason': reason, + 'note': note, + 'created': datetime.now(tz = timezone.utc) + } + + with 
self.exec_statement('put-domain-ban', params) as cur: + return cur.one() + + + def update_domain_ban(self, + domain: str, + reason: Optional[str] = None, + note: Optional[str] = None) -> tinysql.Row: + + if not (reason or note): + raise ValueError('"reason" and/or "note" must be specified') + + params = {} + + if reason: + params['reason'] = reason + + if note: + params['note'] = note + + statement = tinysql.Update('domain_bans', params, domain = domain) + + with self.query(statement) as cur: + if cur.modified_row_count > 1: + raise ValueError('More than one row was modified') + + return self.get_domain_ban(domain) + + + def del_domain_ban(self, domain: str) -> bool: + with self.exec_statement('del-domain-ban', {'domain': domain}) as cur: + if cur.modified_row_count > 1: + raise ValueError('More than one row was modified') + + return cur.modified_row_count == 1 + + + def get_software_ban(self, name: str) -> Row: + with self.exec_statement('get-software-ban', {'name': name}) as cur: + return cur.one() + + + def put_software_ban(self, + name: str, + reason: Optional[str] = None, + note: Optional[str] = None) -> Row: + + params = { + 'name': name, + 'reason': reason, + 'note': note, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-software-ban', params) as cur: + return cur.one() + + + def update_software_ban(self, + name: str, + reason: Optional[str] = None, + note: Optional[str] = None) -> tinysql.Row: + + if not (reason or note): + raise ValueError('"reason" and/or "note" must be specified') + + params = {} + + if reason: + params['reason'] = reason + + if note: + params['note'] = note + + statement = tinysql.Update('software_bans', params, name = name) + + with self.query(statement) as cur: + if cur.modified_row_count > 1: + raise ValueError('More than one row was modified') + + return self.get_software_ban(name) + + + def del_software_ban(self, name: str) -> bool: + with self.exec_statement('del-software-ban', {'name': name}) as 
cur: + if cur.modified_row_count > 1: + raise ValueError('More than one row was modified') + + return cur.modified_row_count == 1 + + + def get_domain_whitelist(self, domain: str) -> Row: + with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur: + return cur.one() + + + def put_domain_whitelist(self, domain: str) -> Row: + params = { + 'domain': domain, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-domain-whitelist', params) as cur: + return cur.one() + + + def del_domain_whitelist(self, domain: str) -> bool: + with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur: + if cur.modified_row_count > 1: + raise ValueError('More than one row was modified') + + return cur.modified_row_count == 1 diff --git a/relay/database/schema.py b/relay/database/schema.py new file mode 100644 index 0000000..15a1fae --- /dev/null +++ b/relay/database/schema.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing + +from tinysql import Column, Connection, Table + +from .config import get_default_value + +if typing.TYPE_CHECKING: + from typing import Callable + + +VERSIONS: list[Callable] = [] +TABLES: list[Table] = [ + Table( + 'config', + Column('key', 'text', primary_key = True, unique = True, nullable = False), + Column('value', 'text'), + Column('type', 'text', default = 'str') + ), + Table( + 'inboxes', + Column('domain', 'text', primary_key = True, unique = True, nullable = False), + Column('actor', 'text', unique = True), + Column('inbox', 'text', unique = True, nullable = False), + Column('followid', 'text'), + Column('software', 'text'), + Column('created', 'timestamp', nullable = False) + ), + Table( + 'whitelist', + Column('domain', 'text', primary_key = True, unique = True, nullable = True), + Column('created', 'timestamp') + ), + Table( + 'instance_bans', + Column('domain', 'text', primary_key = True, unique = True, nullable = True), + Column('reason', 'text'), + Column('note', 
'text'), + Column('created', 'timestamp', nullable = False) + ), + Table( + 'software_bans', + Column('name', 'text', primary_key = True, unique = True, nullable = True), + Column('reason', 'text'), + Column('note', 'text'), + Column('created', 'timestamp', nullable = False) + ) +] + + +def version(func: Callable) -> Callable: + ver = int(func.replace('migrate_', '')) + VERSIONS[ver] = func + return func + + +def migrate_0(conn: Connection) -> None: + conn.create_tables(TABLES) + conn.put_config('schema-version', get_default_value('schema-version')) diff --git a/relay/http_client.py b/relay/http_client.py index 6f2a044..52176b7 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -13,11 +13,10 @@ from urllib.parse import urlparse from . import __version__ from . import logger as logging -from .misc import MIMETYPES, Message +from .misc import MIMETYPES, Message, get_app if typing.TYPE_CHECKING: from typing import Any, Callable, Optional - from .database import RelayDatabase HEADERS = { @@ -28,12 +27,10 @@ HEADERS = { class HttpClient: def __init__(self, - database: RelayDatabase, limit: Optional[int] = 100, timeout: Optional[int] = 10, cache_size: Optional[int] = 1024): - self.database = database self.cache = LRUCache(cache_size) self.limit = limit self.timeout = timeout @@ -98,7 +95,7 @@ class HttpClient: headers = {} if sign_headers: - headers.update(self.database.signer.sign_headers('GET', url, algorithm='original')) + get_app().signer.sign_headers('GET', url, algorithm = 'original') try: logging.debug('Fetching resource: %s', url) @@ -150,23 +147,24 @@ class HttpClient: async def post(self, url: str, message: Message) -> None: await self.open() - instance = self.database.get_inbox(url) + with get_app().database.connection() as conn: + instance = conn.get_inbox(url) ## Using the old algo by default is probably a better idea right now - if instance and instance.get('software') in {'mastodon'}: + if instance and instance['software'] in {'mastodon'}: 
algorithm = 'hs2019' else: algorithm = 'original' headers = {'Content-Type': 'application/activity+json'} - headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm)) + headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm)) try: logging.verbose('Sending "%s" to %s', message.type, url) async with self._session.post(url, headers=headers, data=message.to_json()) as resp: - ## Not expecting a response, so just return + # Not expecting a response, so just return if resp.status in {200, 202}: logging.verbose('Successfully sent "%s" to %s', message.type, url) return @@ -181,7 +179,7 @@ class HttpClient: except (AsyncTimeoutError, ClientConnectionError): logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) - ## prevent workers from being brought down + # prevent workers from being brought down except Exception: traceback.print_exc() @@ -211,16 +209,16 @@ class HttpClient: return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None -async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None: - async with HttpClient(database) as client: +async def get(*args: Any, **kwargs: Any) -> Message | dict | None: + async with HttpClient() as client: return await client.get(*args, **kwargs) -async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None: - async with HttpClient(database) as client: +async def post(*args: Any, **kwargs: Any) -> None: + async with HttpClient() as client: return await client.post(*args, **kwargs) -async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None: - async with HttpClient(database) as client: +async def fetch_nodeinfo(*args: Any, **kwargs: Any) -> Nodeinfo | None: + async with HttpClient() as client: return await client.fetch_nodeinfo(*args, **kwargs) diff --git a/relay/logger.py b/relay/logger.py index 0d1d451..e822cb4 100644 --- a/relay/logger.py +++ 
b/relay/logger.py @@ -4,20 +4,62 @@ import logging import os import typing +from enum import IntEnum from pathlib import Path if typing.TYPE_CHECKING: - from typing import Any, Callable + from typing import Any, Callable, Type -LOG_LEVELS: dict[str, int] = { - 'DEBUG': logging.DEBUG, - 'VERBOSE': 15, - 'INFO': logging.INFO, - 'WARNING': logging.WARNING, - 'ERROR': logging.ERROR, - 'CRITICAL': logging.CRITICAL -} +class LogLevel(IntEnum): + DEBUG = logging.DEBUG + VERBOSE = 15 + INFO = logging.INFO + WARNING = logging.WARNING + ERROR = logging.ERROR + CRITICAL = logging.CRITICAL + + + def __str__(self) -> str: + return self.name + + + @classmethod + def parse(cls: Type[IntEnum], data: object) -> IntEnum: + if isinstance(data, cls): + return data + + if isinstance(data, str): + data = data.upper() + + try: + return cls[data] + + except KeyError: + pass + + try: + return cls(data) + + except ValueError: + pass + + raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}') + + +def get_level() -> LogLevel: + return LogLevel.parse(logging.root.level) + + +def set_level(level: LogLevel | str) -> None: + logging.root.setLevel(LogLevel.parse(level)) + + +def verbose(message: str, *args: Any, **kwargs: Any) -> None: + if not logging.root.isEnabledFor(LogLevel['VERBOSE']): + return + + logging.log(LogLevel['VERBOSE'], message, *args, **kwargs) debug: Callable = logging.debug @@ -27,14 +69,7 @@ error: Callable = logging.error critical: Callable = logging.critical -def verbose(message: str, *args: Any, **kwargs: Any) -> None: - if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']): - return - - logging.log(LOG_LEVELS['VERBOSE'], message, *args, **kwargs) - - -logging.addLevelName(LOG_LEVELS['VERBOSE'], 'VERBOSE') +logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE') env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper() try: @@ -45,11 +80,11 @@ except KeyError: try: - log_level = LOG_LEVELS[env_log_level] + log_level = LogLevel[env_log_level] except 
KeyError: - logging.warning('Invalid log level: %s', env_log_level) - log_level = logging.INFO + print('Invalid log level:', env_log_level) + log_level = LogLevel['INFO'] handlers = [logging.StreamHandler()] diff --git a/relay/manage.py b/relay/manage.py index b0c5cb3..c04235f 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -6,22 +6,49 @@ import click import platform import typing +from aputils.signer import Signer +from pathlib import Path +from shutil import copyfile from urllib.parse import urlparse -from . import misc, __version__ +from . import __version__ from . import http_client as http +from . import logger as logging from .application import Application -from .config import RELAY_SOFTWARE +from .compat import RelayConfig, RelayDatabase +from .database import get_database +from .database.connection import RELAY_SOFTWARE +from .misc import IS_DOCKER, Message, check_open_port if typing.TYPE_CHECKING: - from typing import Any + from tinysql import Row + from typing import Any, Optional # pylint: disable=unsubscriptable-object,unsupported-assignment-operation -app = None -CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} +CONFIG_IGNORE = ( + 'schema-version', + 'private-key' +) + +ACTOR_FORMATS = { + 'mastodon': 'https://{domain}/actor', + 'akkoma': 'https://{domain}/relay', + 'pleroma': 'https://{domain}/relay' +} + +SOFTWARE = ( + 'mastodon', + 'akkoma', + 'pleroma', + 'misskey', + 'friendica', + 'hubzilla', + 'firefish', + 'gotosocial' +) @click.group('cli', context_settings={'show_default': True}, invoke_without_command=True) @@ -29,11 +56,10 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} @click.version_option(version=__version__, prog_name='ActivityRelay') @click.pass_context def cli(ctx: click.Context, config: str) -> None: - global app - app = Application(config) + ctx.obj = Application(config) if not ctx.invoked_subcommand: - if app.config.host.endswith('example.com'): + if 
ctx.obj.config.domain.endswith('example.com'): cli_setup.callback() else: @@ -41,46 +67,92 @@ def cli(ctx: click.Context, config: str) -> None: @cli.command('setup') -def cli_setup() -> None: +@click.pass_context +def cli_setup(ctx: click.Context) -> None: 'Generate a new config' while True: - app.config.host = click.prompt( + ctx.obj.config.domain = click.prompt( 'What domain will the relay be hosted on?', - default = app.config.host + default = ctx.obj.config.domain ) - if not app.config.host.endswith('example.com'): + if not ctx.obj.config.domain.endswith('example.com'): break - click.echo('The domain must not be example.com') + click.echo('The domain must not end with "example.com"') - if not app.config.is_docker: - app.config.listen = click.prompt( + if not IS_DOCKER: + ctx.obj.config.listen = click.prompt( 'Which address should the relay listen on?', - default = app.config.listen + default = ctx.obj.config.listen ) - while True: - app.config.port = click.prompt( - 'What TCP port should the relay listen on?', - default = app.config.port, - type = int - ) + ctx.obj.config.port = click.prompt( + 'What TCP port should the relay listen on?', + default = ctx.obj.config.port, + type = int + ) - break + ctx.obj.config.db_type = click.prompt( + 'Which database backend will be used?', + default = ctx.obj.config.db_type, + type = click.Choice(['postgres', 'sqlite'], case_sensitive = False) + ) - app.config.save() + if ctx.obj.config.db_type == 'sqlite': + ctx.obj.config.sq_path = click.prompt( + 'Where should the database be stored?', + default = ctx.obj.config.sq_path + ) - if not app.config.is_docker and click.confirm('Relay all setup! 
Would you like to run it now?'): + elif ctx.obj.config.db_type == 'postgres': + ctx.obj.config.pg_name = click.prompt( + 'What is the name of the database?', + default = ctx.obj.config.pg_name + ) + + ctx.obj.config.pg_host = click.prompt( + 'What IP address or hostname does the server listen on?', + default = ctx.obj.config.pg_host + ) + + ctx.obj.config.pg_port = click.prompt( + 'What port does the server listen on?', + default = ctx.obj.config.pg_port, + type = int + ) + + ctx.obj.config.pg_user = click.prompt( + 'Which user will authenticate with the server?', + default = ctx.obj.config.pg_user + ) + + ctx.obj.config.pg_pass = click.prompt( + 'User password: ', + hide_input = True + ) or None + + ctx.obj.config.save() + + config = { + 'private-key': Signer.new('n/a').export() + } + + with ctx.obj.database.connection() as conn: + for key, value in config.items(): + conn.put_config(key, value) + + if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'): cli_run.callback() @cli.command('run') -def cli_run() -> None: +@click.pass_context +def cli_run(ctx: click.Context) -> None: 'Run the relay' - if app.config.host.endswith('example.com'): + if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer: click.echo( 'Relay is not set up. Please edit your relay config or run "activityrelay setup".' 
) @@ -104,40 +176,142 @@ def cli_run() -> None: click.echo(pip_command) return - if not misc.check_open_port(app.config.listen, app.config.port): - click.echo(f'Error: A server is already running on port {app.config.port}') + if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port): + click.echo(f'Error: A server is already running on port {ctx.obj.config.port}') return - app.run() + ctx.obj.run() + + +@cli.command('convert') +@click.option('--old-config', '-o', help = 'Path to the old config file to convert') +@click.pass_context +def cli_convert(ctx: click.Context, old_config: str) -> None: + 'Convert an old config and jsonld database to the new format.' + + old_config = Path(old_config).expanduser().resolve() if old_config else ctx.obj.config.path + backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml') + + if str(old_config) == str(ctx.obj.config.path) and not backup.exists(): + logging.info('Created backup config @ %s', backup) + copyfile(ctx.obj.config.path, backup) + + config = RelayConfig(old_config) + config.load() + + database = RelayDatabase(config) + database.load() + + ctx.obj.config.set('listen', config['listen']) + ctx.obj.config.set('port', config['port']) + ctx.obj.config.set('workers', config['workers']) + ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3')) + + with get_database(ctx.obj.config) as db: + with db.connection() as conn: + conn.put_config('private-key', database['private-key']) + conn.put_config('note', config['note']) + conn.put_config('whitelist-enabled', config['whitelist_enabled']) + + with click.progressbar( + database['relay-list'].values(), + label = 'Inboxes'.ljust(15), + width = 0 + ) as inboxes: + + for inbox in inboxes: + if inbox['software'] in ('akkoma', 'pleroma'): + actor = f'https://{inbox["domain"]}/relay' + + elif inbox['software'] == 'mastodon': + actor = f'https://{inbox["domain"]}/actor' + + else: + actor = None + + conn.put_inbox( + inbox['domain'], + inbox['inbox'], + actor = actor, + followid =
inbox['followid'], + software = inbox['software'] + ) + + with click.progressbar( + config['blocked_software'], + label = 'Banned software'.ljust(15), + width = 0 + ) as banned_software: + + for software in banned_software: + conn.put_software_ban( + software, + reason = 'relay' if software in RELAY_SOFTWARE else None + ) + + with click.progressbar( + config['blocked_instances'], + label = 'Banned domains'.ljust(15), + width = 0 + ) as banned_software: + + for domain in banned_software: + conn.put_domain_ban(domain) + + with click.progressbar( + config['whitelist'], + label = 'Whitelist'.ljust(15), + width = 0 + ) as whitelist: + + for instance in whitelist: + conn.put_domain_whitelist(instance) + + click.echo('Finished converting old config and database :3') + + +@cli.command('edit-config') +@click.option('--editor', '-e', help = 'Text editor to use') +@click.pass_context +def cli_editconfig(ctx: click.Context, editor: str) -> None: + 'Edit the config file' + + click.edit( + editor = editor, + filename = str(ctx.obj.config.path) + ) @cli.group('config') def cli_config() -> None: - 'Manage the relay config' + 'Manage the relay settings stored in the database' @cli_config.command('list') -def cli_config_list() -> None: +@click.pass_context +def cli_config_list(ctx: click.Context) -> None: 'List the current relay config' click.echo('Relay Config:') - for key, value in app.config.items(): - if key not in CONFIG_IGNORE: - key = f'{key}:'.ljust(20) - click.echo(f'- {key} {value}') + with ctx.obj.database.connection() as conn: + for key, value in conn.get_config_all().items(): + if key not in CONFIG_IGNORE: + key = f'{key}:'.ljust(20) + click.echo(f'- {key} {value}') @cli_config.command('set') @click.argument('key') @click.argument('value') -def cli_config_set(key: str, value: Any) -> None: +@click.pass_context +def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: 'Set a config value' - app.config[key] = value - app.config.save() + with 
ctx.obj.database.connection() as conn: + new_value = conn.put_config(key, value) - print(f'{key}: {app.config[key]}') + print(f'{key}: {repr(new_value)}') @cli.group('inbox') @@ -146,127 +320,145 @@ def cli_inbox() -> None: @cli_inbox.command('list') -def cli_inbox_list() -> None: +@click.pass_context +def cli_inbox_list(ctx: click.Context) -> None: 'List the connected instances or relays' click.echo('Connected to the following instances or relays:') - for inbox in app.database.inboxes: - click.echo(f'- {inbox}') + with ctx.obj.database.connection() as conn: + for inbox in conn.execute('SELECT * FROM inboxes'): + click.echo(f'- {inbox["inbox"]}') @cli_inbox.command('follow') @click.argument('actor') -def cli_inbox_follow(actor: str) -> None: +@click.pass_context +def cli_inbox_follow(ctx: click.Context, actor: str) -> None: 'Follow an actor (Relay must be running)' - if app.config.is_banned(actor): - click.echo(f'Error: Refusing to follow banned actor: {actor}') - return - - if not actor.startswith('http'): - domain = actor - actor = f'https://{actor}/actor' - - else: - domain = urlparse(actor).hostname - - try: - inbox_data = app.database['relay-list'][domain] - inbox = inbox_data['inbox'] - - except KeyError: - actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True)) - - if not actor_data: - click.echo(f'Failed to fetch actor: {actor}') + with ctx.obj.database.connection() as conn: + if conn.get_domain_ban(actor): + click.echo(f'Error: Refusing to follow banned actor: {actor}') return - inbox = actor_data.shared_inbox + if (inbox_data := conn.get_inbox(actor)): + inbox = inbox_data['inbox'] - message = misc.Message.new_follow( - host = app.config.host, + else: + if not actor.startswith('http'): + actor = f'https://{actor}/actor' + + actor_data = asyncio.run(http.get(actor, sign_headers = True)) + + if not actor_data: + click.echo(f'Failed to fetch actor: {actor}') + return + + inbox = actor_data.shared_inbox + + message = Message.new_follow( + 
host = ctx.obj.config.domain, actor = actor ) - asyncio.run(http.post(app.database, inbox, message)) + asyncio.run(http.post(inbox, message)) click.echo(f'Sent follow message to actor: {actor}') @cli_inbox.command('unfollow') @click.argument('actor') -def cli_inbox_unfollow(actor: str) -> None: +@click.pass_context +def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: 'Unfollow an actor (Relay must be running)' - if not actor.startswith('http'): - domain = actor - actor = f'https://{actor}/actor' + inbox_data: Row = None - else: - domain = urlparse(actor).hostname + with ctx.obj.database.connection() as conn: + if conn.get_domain_ban(actor): + click.echo(f'Error: Refusing to follow banned actor: {actor}') + return - try: - inbox_data = app.database['relay-list'][domain] - inbox = inbox_data['inbox'] - message = misc.Message.new_unfollow( - host = app.config.host, - actor = actor, - follow = inbox_data['followid'] - ) + if (inbox_data := conn.get_inbox(actor)): + inbox = inbox_data['inbox'] + message = Message.new_unfollow( + host = ctx.obj.config.domain, + actor = actor, + follow = inbox_data['followid'] + ) - except KeyError: - actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True)) - inbox = actor_data.shared_inbox - message = misc.Message.new_unfollow( - host = app.config.host, - actor = actor, - follow = { - 'type': 'Follow', - 'object': actor, - 'actor': f'https://{app.config.host}/actor' - } - ) + else: + if not actor.startswith('http'): + actor = f'https://{actor}/actor' - asyncio.run(http.post(app.database, inbox, message)) + actor_data = asyncio.run(http.get(actor, sign_headers = True)) + inbox = actor_data.shared_inbox + message = Message.new_unfollow( + host = ctx.obj.config.domain, + actor = actor, + follow = { + 'type': 'Follow', + 'object': actor, + 'actor': f'https://{ctx.obj.config.domain}/actor' + } + ) + + asyncio.run(http.post(inbox, message)) click.echo(f'Sent unfollow message to: {actor}') @cli_inbox.command('add') 
@click.argument('inbox') -def cli_inbox_add(inbox: str) -> None: +@click.option('--actor', '-a', help = 'Actor url for the inbox') +@click.option('--followid', '-f', help = 'Url for the follow activity') +@click.option('--software', '-s', type = click.Choice(SOFTWARE)) +@click.pass_context +def cli_inbox_add( + ctx: click.Context, + inbox: str, + actor: Optional[str] = None, + followid: Optional[str] = None, + software: Optional[str] = None) -> None: 'Add an inbox to the database' if not inbox.startswith('http'): + domain = inbox inbox = f'https://{inbox}/inbox' - if app.config.is_banned(inbox): - click.echo(f'Error: Refusing to add banned inbox: {inbox}') - return + else: + domain = urlparse(inbox).netloc - if app.database.get_inbox(inbox): - click.echo(f'Error: Inbox already in database: {inbox}') - return + if not actor and software: + try: + actor = ACTOR_FORMATS[software].format(domain = domain) - app.database.add_inbox(inbox) - app.database.save() + except KeyError: + pass + + with ctx.obj.database.connection() as conn: + if conn.get_domain_ban(domain): + click.echo(f'Refusing to add banned inbox: {inbox}') + return + + if conn.get_inbox(inbox): + click.echo(f'Error: Inbox already in database: {inbox}') + return + + conn.put_inbox(domain, inbox, actor, followid, software) click.echo(f'Added inbox to the database: {inbox}') @cli_inbox.command('remove') @click.argument('inbox') -def cli_inbox_remove(inbox: str) -> None: +@click.pass_context +def cli_inbox_remove(ctx: click.Context, inbox: str) -> None: 'Remove an inbox from the database' - try: - dbinbox = app.database.get_inbox(inbox, fail=True) - - except KeyError: - click.echo(f'Error: Inbox does not exist: {inbox}') - return - - app.database.del_inbox(dbinbox['domain']) - app.database.save() + with ctx.obj.database.connection() as conn: + if not conn.del_inbox(inbox): + click.echo(f'Inbox not in database: {inbox}') + return click.echo(f'Removed inbox from the database: {inbox}') @@ -277,47 +469,76 @@ def 
cli_instance() -> None: @cli_instance.command('list') -def cli_instance_list() -> None: +@click.pass_context +def cli_instance_list(ctx: click.Context) -> None: 'List all banned instances' - click.echo('Banned instances or relays:') + click.echo('Banned domains:') - for domain in app.config.blocked_instances: - click.echo(f'- {domain}') + with ctx.obj.database.connection() as conn: + for instance in conn.execute('SELECT * FROM domain_bans'): + if instance['reason']: + click.echo(f'- {instance["domain"]} ({instance["reason"]})') + + else: + click.echo(f'- {instance["domain"]}') @cli_instance.command('ban') -@click.argument('target') -def cli_instance_ban(target: str) -> None: +@click.argument('domain') +@click.option('--reason', '-r', help = 'Public note about why the domain is banned') +@click.option('--note', '-n', help = 'Internal note that will only be seen by admins and mods') +@click.pass_context +def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None: 'Ban an instance and remove the associated inbox if it exists' - if target.startswith('http'): - target = urlparse(target).hostname + with ctx.obj.database.connection() as conn: + if conn.get_domain_ban(domain): + click.echo(f'Domain already banned: {domain}') + return - if app.config.ban_instance(target): - app.config.save() - - if app.database.del_inbox(target): - app.database.save() - - click.echo(f'Banned instance: {target}') - return - - click.echo(f'Instance already banned: {target}') + conn.put_domain_ban(domain, reason, note) + conn.del_inbox(domain) + click.echo(f'Banned instance: {domain}') @cli_instance.command('unban') -@click.argument('target') -def cli_instance_unban(target: str) -> None: +@click.argument('domain') +@click.pass_context +def cli_instance_unban(ctx: click.Context, domain: str) -> None: 'Unban an instance' - if app.config.unban_instance(target): - app.config.save() + with ctx.obj.database.connection() as conn: + if not conn.del_domain_ban(domain): + 
click.echo(f'Instance wasn\'t banned: {domain}') + return - click.echo(f'Unbanned instance: {target}') - return + click.echo(f'Unbanned instance: {domain}') - click.echo(f'Instance wasn\'t banned: {target}') + +@cli_instance.command('update') +@click.argument('domain') +@click.option('--reason', '-r') +@click.option('--note', '-n') +@click.pass_context +def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) -> None: + 'Update the public reason or internal note for a domain ban' + + if not (reason or note): + ctx.fail('Must pass --reason or --note') + + with ctx.obj.database.connection() as conn: + if not (row := conn.update_domain_ban(domain, reason, note)): + click.echo(f'Failed to update domain ban: {domain}') + return + + click.echo(f'Updated domain ban: {domain}') + + if row['reason']: + click.echo(f'- {row["domain"]} ({row["reason"]})') + + else: + click.echo(f'- {row["domain"]}') @cli.group('software') @@ -326,79 +547,131 @@ def cli_software() -> None: @cli_software.command('list') -def cli_software_list() -> None: +@click.pass_context +def cli_software_list(ctx: click.Context) -> None: 'List all banned software' click.echo('Banned software:') - for software in app.config.blocked_software: - click.echo(f'- {software}') + with ctx.obj.database.connection() as conn: + for software in conn.execute('SELECT * FROM software_bans'): + if software['reason']: + click.echo(f'- {software["name"]} ({software["reason"]})') + + else: + click.echo(f'- {software["name"]}') @cli_software.command('ban') -@click.option( - '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False, - help = 'Treat NAME like a domain and try to fet the software name from nodeinfo' -) @click.argument('name') -def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None: +@click.option('--reason', '-r') +@click.option('--note', '-n') +@click.option( + '--fetch-nodeinfo', '-f', + is_flag = True, + help = 'Treat NAME like a domain and try to fetch the 
software name from nodeinfo' +) +@click.pass_context +def cli_software_ban(ctx: click.Context, + name: str, + reason: str, + note: str, + fetch_nodeinfo: bool) -> None: 'Ban software. Use RELAYS for NAME to ban relays' - if name == 'RELAYS': - for software in RELAY_SOFTWARE: - app.config.ban_software(software) + with ctx.obj.database.connection() as conn: + if name == 'RELAYS': + for software in RELAY_SOFTWARE: + if conn.get_software_ban(software): + click.echo(f'Relay already banned: {software}') + continue - app.config.save() - click.echo('Banned all relay software') - return + conn.put_software_ban(software, reason or 'relay', note) - if fetch_nodeinfo: - nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) + click.echo('Banned all relay software') + return - if not nodeinfo: - click.echo(f'Failed to fetch software name from domain: {name}') + if fetch_nodeinfo: + nodeinfo = asyncio.run(http.fetch_nodeinfo(name)) - name = nodeinfo.sw_name + if not nodeinfo: + click.echo(f'Failed to fetch software name from domain: {name}') + return + + name = nodeinfo.sw_name + + if conn.get_software_ban(name): + click.echo(f'Software already banned: {name}') + return + + if not conn.put_software_ban(name, reason, note): + click.echo(f'Failed to ban software: {name}') + return - if app.config.ban_software(name): - app.config.save() click.echo(f'Banned software: {name}') - return - - click.echo(f'Software already banned: {name}') @cli_software.command('unban') -@click.option( - '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False, - help = 'Treat NAME like a domain and try to fet the software name from nodeinfo' -) @click.argument('name') -def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None: +@click.option('--reason', '-r') +@click.option('--note', '-n') +@click.option( + '--fetch-nodeinfo', '-f', + is_flag = True, + help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo' +) +@click.pass_context +def 
cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool, reason: Optional[str] = None, note: Optional[str] = None) -> None: 'Ban software. Use RELAYS for NAME to unban relays' - if name == 'RELAYS': - for software in RELAY_SOFTWARE: - app.config.unban_software(software) + with ctx.obj.database.connection() as conn: + if name == 'RELAYS': + for software in RELAY_SOFTWARE: + if not conn.del_software_ban(software): + click.echo(f'Relay was not banned: {software}') - app.config.save() - click.echo('Unbanned all relay software') - return + click.echo('Unbanned all relay software') + return - if fetch_nodeinfo: - nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) + if fetch_nodeinfo: + nodeinfo = asyncio.run(http.fetch_nodeinfo(name)) - if not nodeinfo: - click.echo(f'Failed to fetch software name from domain: {name}') + if not nodeinfo: + click.echo(f'Failed to fetch software name from domain: {name}') + return - name = nodeinfo.sw_name + name = nodeinfo.sw_name + + if not conn.del_software_ban(name): + click.echo(f'Software was not banned: {name}') + return - if app.config.unban_software(name): - app.config.save() click.echo(f'Unbanned software: {name}') - return - click.echo(f'Software wasn\'t banned: {name}') + +@cli_software.command('update') +@click.argument('name') +@click.option('--reason', '-r') +@click.option('--note', '-n') +@click.pass_context +def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -> None: + 'Update the public reason or internal note for a software ban' + + if not (reason or note): + ctx.fail('Must pass --reason or --note') + + with ctx.obj.database.connection() as conn: + if not (row := conn.update_software_ban(name, reason, note)): + click.echo(f'Failed to update software ban: {name}') + return + + click.echo(f'Updated software ban: {name}') + + if row['reason']: + click.echo(f'- {row["name"]} ({row["reason"]})') + + else: + click.echo(f'- {row["name"]}') @cli.group('whitelist') @@ -407,52 +680,64 @@ def cli_whitelist() -> None:
@cli_whitelist.command('list') -def cli_whitelist_list() -> None: +@click.pass_context +def cli_whitelist_list(ctx: click.Context) -> None: 'List all the instances in the whitelist' - click.echo('Current whitelisted domains') + click.echo('Current whitelisted domains:') - for domain in app.config.whitelist: - click.echo(f'- {domain}') + with ctx.obj.database.connection() as conn: + for domain in conn.execute('SELECT * FROM whitelist'): + click.echo(f'- {domain["domain"]}') @cli_whitelist.command('add') -@click.argument('instance') -def cli_whitelist_add(instance: str) -> None: - 'Add an instance to the whitelist' +@click.argument('domain') +@click.pass_context +def cli_whitelist_add(ctx: click.Context, domain: str) -> None: + 'Add a domain to the whitelist' - if not app.config.add_whitelist(instance): - click.echo(f'Instance already in the whitelist: {instance}') - return + with ctx.obj.database.connection() as conn: + if conn.get_domain_whitelist(domain): + click.echo(f'Instance already in the whitelist: {domain}') + return - app.config.save() - click.echo(f'Instance added to the whitelist: {instance}') + conn.put_domain_whitelist(domain) + click.echo(f'Instance added to the whitelist: {domain}') @cli_whitelist.command('remove') -@click.argument('instance') -def cli_whitelist_remove(instance: str) -> None: +@click.argument('domain') +@click.pass_context +def cli_whitelist_remove(ctx: click.Context, domain: str) -> None: 'Remove an instance from the whitelist' - if not app.config.del_whitelist(instance): - click.echo(f'Instance not in the whitelist: {instance}') - return + with ctx.obj.database.connection() as conn: + if not conn.del_domain_whitelist(domain): + click.echo(f'Domain not in the whitelist: {domain}') + return - app.config.save() + if conn.get_config('whitelist-enabled'): + if conn.del_inbox(domain): + click.echo(f'Removed inbox for domain: {domain}') - if app.config.whitelist_enabled: - if app.database.del_inbox(instance): - app.database.save() - - 
click.echo(f'Removed instance from the whitelist: {instance}') + click.echo(f'Removed domain from the whitelist: {domain}') @cli_whitelist.command('import') -def cli_whitelist_import() -> None: +@click.pass_context +def cli_whitelist_import(ctx: click.Context) -> None: 'Add all current inboxes to the whitelist' - for domain in app.database.hostnames: - cli_whitelist_add.callback(domain) + with ctx.obj.database.connection() as conn: + for inbox in conn.execute('SELECT * FROM inboxes').all(): + if conn.get_domain_whitelist(inbox['domain']): + click.echo(f'Domain already in whitelist: {inbox["domain"]}') + continue + + conn.put_domain_whitelist(inbox['domain']) + + click.echo('Imported whitelist from inboxes') def main() -> None: diff --git a/relay/misc.py b/relay/misc.py index 7244eaa..2d7117d 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -1,32 +1,27 @@ from __future__ import annotations import json +import os import socket -import traceback import typing from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS -from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse +from aiohttp.web import Response as AiohttpResponse from aiohttp.web_exceptions import HTTPMethodNotAllowed -from aputils.errors import SignatureFailureError -from aputils.misc import Digest, HttpDate, Signature from aputils.message import Message as ApMessage from functools import cached_property -from json.decoder import JSONDecodeError from uuid import uuid4 -from . 
import logger as logging - if typing.TYPE_CHECKING: from typing import Any, Coroutine, Generator, Optional, Type - from aputils.signer import Signer from .application import Application - from .config import RelayConfig - from .database import RelayDatabase + from .config import Config + from .database import Database from .http_client import HttpClient +IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING')) MIMETYPES = { 'activity': 'application/activity+json', 'html': 'text/html', @@ -77,91 +72,13 @@ def check_open_port(host: str, port: int) -> bool: return False -class DotDict(dict): - def __init__(self, _data: dict[str, Any], **kwargs: Any): - dict.__init__(self) +def get_app() -> Application: + from .application import Application # pylint: disable=import-outside-toplevel - self.update(_data, **kwargs) + if not Application.DEFAULT: + raise ValueError('No default application set') - - def __getattr__(self, key: str) -> str: - try: - return self[key] - - except KeyError: - raise AttributeError( - f'{self.__class__.__name__} object has no attribute {key}' - ) from None - - - def __setattr__(self, key: str, value: Any) -> None: - if key.startswith('_'): - super().__setattr__(key, value) - - else: - self[key] = value - - - def __setitem__(self, key: str, value: Any) -> None: - if type(value) is dict: # pylint: disable=unidiomatic-typecheck - value = DotDict(value) - - super().__setitem__(key, value) - - - def __delattr__(self, key: str) -> None: - try: - dict.__delitem__(self, key) - - except KeyError: - raise AttributeError( - f'{self.__class__.__name__} object has no attribute {key}' - ) from None - - - @classmethod - def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]: - if not data: - raise JSONDecodeError('Empty body', data, 1) - - try: - return cls(json.loads(data)) - - except ValueError: - raise JSONDecodeError('Invalid body', data, 1) from None - - - @classmethod - def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, 
Any]: - data = cls({}) - - for chunk in sig.strip().split(','): - key, value = chunk.split('=', 1) - value = value.strip('\"') - - if key == 'headers': - value = value.split() - - data[key.lower()] = value - - return data - - - def to_json(self, indent: Optional[int | str] = None) -> str: - return json.dumps(self, indent=indent) - - - def update(self, _data: dict[str, Any], **kwargs: Any) -> None: - if isinstance(_data, dict): - for key, value in _data.items(): - self[key] = value - - elif isinstance(_data, (list, tuple, set)): - for key, value in _data: - self[key] = value - - for key, value in kwargs.items(): - self[key] = value + return Application.DEFAULT class Message(ApMessage): @@ -181,7 +98,7 @@ class Message(ApMessage): 'followers': f'https://{host}/followers', 'following': f'https://{host}/following', 'inbox': f'https://{host}/inbox', - 'url': f'https://{host}/inbox', + 'url': f'https://{host}/', 'endpoints': { 'sharedInbox': f'https://{host}/inbox' }, @@ -310,16 +227,6 @@ class Response(AiohttpResponse): class View(AbstractView): - def __init__(self, request: AiohttpRequest): - AbstractView.__init__(self, request) - - self.signature: Signature = None - self.message: Message = None - self.actor: Message = None - self.instance: dict[str, str] = None - self.signer: Signer = None - - def __await__(self) -> Generator[Response]: method = self.request.method.upper() @@ -363,94 +270,10 @@ class View(AbstractView): @property - def config(self) -> RelayConfig: + def config(self) -> Config: return self.app.config @property - def database(self) -> RelayDatabase: + def database(self) -> Database: return self.app.database - - - # todo: move to views.ActorView - async def get_post_data(self) -> Response | None: - try: - self.signature = Signature.new_from_signature(self.request.headers['signature']) - - except KeyError: - logging.verbose('Missing signature header') - return Response.new_error(400, 'missing signature header', 'json') - - try: - self.message = await 
self.request.json(loads = Message.parse) - - except Exception: - traceback.print_exc() - logging.verbose('Failed to parse inbox message') - return Response.new_error(400, 'failed to parse message', 'json') - - if self.message is None: - logging.verbose('empty message') - return Response.new_error(400, 'missing message', 'json') - - if 'actor' not in self.message: - logging.verbose('actor not in message') - return Response.new_error(400, 'no actor in message', 'json') - - self.actor = await self.client.get(self.signature.keyid, sign_headers = True) - - if self.actor is None: - # ld signatures aren't handled atm, so just ignore it - if self.message.type == 'Delete': - logging.verbose('Instance sent a delete which cannot be handled') - return Response.new(status=202) - - logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') - return Response.new_error(400, 'failed to fetch actor', 'json') - - try: - self.signer = self.actor.signer - - except KeyError: - logging.verbose('Actor missing public key: %s', self.signature.keyid) - return Response.new_error(400, 'actor missing public key', 'json') - - try: - self.validate_signature(await self.request.read()) - - except SignatureFailureError as e: - logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) - return Response.new_error(401, str(e), 'json') - - self.instance = self.database.get_inbox(self.actor.inbox) - - - def validate_signature(self, body: bytes) -> None: - headers = {key.lower(): value for key, value in self.request.headers.items()} - headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path]) - - if (digest := Digest.new_from_digest(headers.get("digest"))): - if not body: - raise SignatureFailureError("Missing body for digest verification") - - if not digest.validate(body): - raise SignatureFailureError("Body digest does not match") - - if self.signature.algorithm_type == "hs2019": - if "(created)" not in self.signature.headers: - raise 
SignatureFailureError("'(created)' header not used") - - current_timestamp = HttpDate.new_utc().timestamp() - - if self.signature.created > current_timestamp: - raise SignatureFailureError("Creation date after current date") - - if current_timestamp > self.signature.expires: - raise SignatureFailureError("Expiration date before current date") - - headers["(created)"] = self.signature.created - headers["(expires)"] = self.signature.expires - - # pylint: disable=protected-access - if not self.signer._validate_signature(headers, self.signature): - raise SignatureFailureError("Signature does not match") diff --git a/relay/processors.py b/relay/processors.py index b9b32bc..4d85607 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -1,5 +1,6 @@ from __future__ import annotations +import tinysql import typing from cachetools import LRUCache @@ -8,7 +9,7 @@ from . import logger as logging from .misc import Message if typing.TYPE_CHECKING: - from .misc import View + from .views import ActorView cache = LRUCache(1024) @@ -16,128 +17,141 @@ cache = LRUCache(1024) def person_check(actor: str, software: str) -> bool: # pleroma and akkoma may use Person for the actor type for some reason + # akkoma changed this in a 3.6.0 if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': return False - ## make sure the actor is an application + # make sure the actor is an application if actor.type != 'Application': return True return False -async def handle_relay(view: View) -> None: +async def handle_relay(view: ActorView) -> None: if view.message.object_id in cache: logging.verbose('already relayed %s', view.message.object_id) return - message = Message.new_announce(view.config.host, view.message.object_id) + message = Message.new_announce(view.config.domain, view.message.object_id) cache[view.message.object_id] = message.id logging.debug('>> relay: %s', message) - inboxes = view.database.distill_inboxes(view.message) - - for inbox in inboxes: - 
view.app.push_message(inbox, message) + with view.database.connection() as conn: + for inbox in conn.distill_inboxes(view.message): + view.app.push_message(inbox, message) -async def handle_forward(view: View) -> None: +async def handle_forward(view: ActorView) -> None: if view.message.id in cache: logging.verbose('already forwarded %s', view.message.id) return - message = Message.new_announce(view.config.host, view.message) + message = Message.new_announce(view.config.domain, view.message) cache[view.message.id] = message.id logging.debug('>> forward: %s', message) - inboxes = view.database.distill_inboxes(view.message) - - for inbox in inboxes: - view.app.push_message(inbox, message) + with view.database.connection() as conn: + for inbox in conn.distill_inboxes(view.message): + view.app.push_message(inbox, message) -async def handle_follow(view: View) -> None: +async def handle_follow(view: ActorView) -> None: nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) software = nodeinfo.sw_name if nodeinfo else None - ## reject if software used by actor is banned - if view.config.is_banned_software(software): + with view.database.connection() as conn: + # reject if software used by actor is banned + if view.config.is_banned_software(software): + view.app.push_message( + view.actor.shared_inbox, + Message.new_response( + host = view.config.domain, + actor = view.actor.id, + followid = view.message.id, + accept = False + ) + ) + + logging.verbose( + 'Rejected follow from actor for using specific software: actor=%s, software=%s', + view.actor.id, + software + ) + + return + + ## reject if the actor is not an instance actor + if person_check(view.actor, software): + view.app.push_message( + view.actor.shared_inbox, + Message.new_response( + host = view.config.domain, + actor = view.actor.id, + followid = view.message.id, + accept = False + ) + ) + + logging.verbose('Non-application actor tried to follow: %s', view.actor.id) + return + + if 
conn.get_inbox(view.actor.shared_inbox): + data = {'followid': view.message.id} + statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox) + + with conn.query(statement): + pass + + else: + conn.put_inbox( + view.actor.domain, + view.actor.shared_inbox, + view.actor.id, + view.message.id, + software + ) + view.app.push_message( view.actor.shared_inbox, Message.new_response( - host = view.config.host, + host = view.config.domain, actor = view.actor.id, followid = view.message.id, - accept = False + accept = True ) ) - return logging.verbose( - 'Rejected follow from actor for using specific software: actor=%s, software=%s', - view.actor.id, - software - ) - - ## reject if the actor is not an instance actor - if person_check(view.actor, software): - view.app.push_message( - view.actor.shared_inbox, - Message.new_response( - host = view.config.host, - actor = view.actor.id, - followid = view.message.id, - accept = False + # Are Akkoma and Pleroma the only two that expect a follow back? + # Ignoring only Mastodon for now + if software != 'mastodon': + view.app.push_message( + view.actor.shared_inbox, + Message.new_follow( + host = view.config.domain, + actor = view.actor.id + ) ) - ) - - logging.verbose('Non-application actor tried to follow: %s', view.actor.id) - return - - view.database.add_inbox(view.actor.shared_inbox, view.message.id, software) - view.database.save() - - view.app.push_message( - view.actor.shared_inbox, - Message.new_response( - host = view.config.host, - actor = view.actor.id, - followid = view.message.id, - accept = True - ) - ) - - # Are Akkoma and Pleroma the only two that expect a follow back? 
- # Ignoring only Mastodon for now - if software != 'mastodon': - view.app.push_message( - view.actor.shared_inbox, - Message.new_follow( - host = view.config.host, - actor = view.actor.id - ) - ) -async def handle_undo(view: View) -> None: +async def handle_undo(view: ActorView) -> None: ## If the object is not a Follow, forward it if view.message.object['type'] != 'Follow': - return await handle_forward(view) - - if not view.database.del_inbox(view.actor.domain, view.message.object['id']): - logging.verbose( - 'Failed to delete "%s" with follow ID "%s"', - view.actor.id, - view.message.object['id'] - ) - + await handle_forward(view) return - view.database.save() + with view.database.connection() as conn: + if not conn.del_inbox(view.actor.inbox): + logging.verbose( + 'Failed to delete "%s" with follow ID "%s"', + view.actor.id, + view.message.object['id'] + ) view.app.push_message( view.actor.shared_inbox, Message.new_unfollow( - host = view.config.host, + host = view.config.domain, actor = view.actor.id, follow = view.message ) @@ -154,7 +168,7 @@ processors = { } -async def run_processor(view: View) -> None: +async def run_processor(view: ActorView) -> None: if view.message.type not in processors: logging.verbose( 'Message type "%s" from actor cannot be handled: %s', @@ -164,12 +178,21 @@ async def run_processor(view: View) -> None: return - if view.instance and not view.instance.get('software'): - nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain']) + if view.instance: + if not view.instance['software']: + if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): + with view.database.connection() as conn: + view.instance = conn.update_inbox( + view.instance['inbox'], + software = nodeinfo.sw_name + ) - if nodeinfo: - view.instance['software'] = nodeinfo.sw_name - view.database.save() + if not view.instance['actor']: + with view.database.connection() as conn: + view.instance = conn.update_inbox( + view.instance['inbox'], + 
actor = view.actor.id + ) logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) await processors[view.message.type](view) diff --git a/relay/views.py b/relay/views.py index e1bed64..df06e81 100644 --- a/relay/views.py +++ b/relay/views.py @@ -2,8 +2,11 @@ from __future__ import annotations import asyncio import subprocess +import traceback import typing +from aputils.errors import SignatureFailureError +from aputils.misc import Digest, HttpDate, Signature from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo from pathlib import Path @@ -14,6 +17,7 @@ from .processors import run_processor if typing.TYPE_CHECKING: from aiohttp.web import Request + from aputils.signer import Signer from typing import Callable @@ -71,12 +75,16 @@ def register_route(*paths: str) -> Callable: @register_route('/') class HomeView(View): async def get(self, request: Request) -> Response: - text = HOME_TEMPLATE.format( - host = self.config.host, - note = self.config.note, - count = len(self.database.hostnames), - targets = '
'.join(self.database.hostnames) - ) + with self.database.connection() as conn: + config = conn.get_config_all() + inboxes = conn.execute('SELECT * FROM inboxes').all() + + text = HOME_TEMPLATE.format( + host = self.config.domain, + note = config['note'], + count = len(inboxes), + targets = '
'.join(inbox['domain'] for inbox in inboxes) + ) return Response.new(text, ctype='html') @@ -84,44 +92,137 @@ class HomeView(View): @register_route('/actor', '/inbox') class ActorView(View): + def __init__(self, request: Request): + View.__init__(self, request) + + self.signature: Signature = None + self.message: Message = None + self.actor: Message = None + self.instance: dict[str, str] = None + self.signer: Signer = None + + async def get(self, request: Request) -> Response: data = Message.new_actor( - host = self.config.host, - pubkey = self.database.signer.pubkey + host = self.config.domain, + pubkey = self.app.signer.pubkey ) return Response.new(data, ctype='activity') async def post(self, request: Request) -> Response: - response = await self.get_post_data() - - if response is not None: + if (response := await self.get_post_data()): return response - ## reject if the actor isn't whitelisted while the whiltelist is enabled - if self.config.whitelist_enabled and not self.config.is_whitelisted(self.actor.domain): - logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + with self.database.connection() as conn: + self.instance = conn.get_inbox(self.actor.inbox) + config = conn.get_config_all() - ## reject if actor is banned - if self.config.is_banned(self.actor.domain): - logging.verbose('Ignored request from banned actor: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if the actor isn't whitelisted while the whiltelist is enabled + if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): + logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if activity type isn't 'Follow' and the actor isn't following - if self.message.type != 'Follow' and not self.database.get_inbox(self.actor.domain): - logging.verbose( 
- 'Rejected actor for trying to post while not following: %s', - self.actor.id - ) + ## reject if actor is banned + if conn.get_domain_ban(self.actor.domain): + logging.verbose('Ignored request from banned actor: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - return Response.new_error(401, 'access denied', 'json') + ## reject if activity type isn't 'Follow' and the actor isn't following + if self.message.type != 'Follow' and not self.instance: + logging.verbose( + 'Rejected actor for trying to post while not following: %s', + self.actor.id + ) - logging.debug('>> payload %s', self.message.to_json(4)) + return Response.new_error(401, 'access denied', 'json') - asyncio.ensure_future(run_processor(self)) - return Response.new(status = 202) + logging.debug('>> payload %s', self.message.to_json(4)) + + asyncio.ensure_future(run_processor(self)) + return Response.new(status = 202) + + + async def get_post_data(self) -> Response | None: + try: + self.signature = Signature.new_from_signature(self.request.headers['signature']) + + except KeyError: + logging.verbose('Missing signature header') + return Response.new_error(400, 'missing signature header', 'json') + + try: + self.message = await self.request.json(loads = Message.parse) + + except Exception: + traceback.print_exc() + logging.verbose('Failed to parse inbox message') + return Response.new_error(400, 'failed to parse message', 'json') + + if self.message is None: + logging.verbose('empty message') + return Response.new_error(400, 'missing message', 'json') + + if 'actor' not in self.message: + logging.verbose('actor not in message') + return Response.new_error(400, 'no actor in message', 'json') + + self.actor = await self.client.get(self.signature.keyid, sign_headers = True) + + if self.actor is None: + # ld signatures aren't handled atm, so just ignore it + if self.message.type == 'Delete': + logging.verbose('Instance sent a delete which cannot be handled') + return 
Response.new(status=202) + + logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') + return Response.new_error(400, 'failed to fetch actor', 'json') + + try: + self.signer = self.actor.signer + + except KeyError: + logging.verbose('Actor missing public key: %s', self.signature.keyid) + return Response.new_error(400, 'actor missing public key', 'json') + + try: + self.validate_signature(await self.request.read()) + + except SignatureFailureError as e: + logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) + return Response.new_error(401, str(e), 'json') + + + def validate_signature(self, body: bytes) -> None: + headers = {key.lower(): value for key, value in self.request.headers.items()} + headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path]) + + if (digest := Digest.new_from_digest(headers.get("digest"))): + if not body: + raise SignatureFailureError("Missing body for digest verification") + + if not digest.validate(body): + raise SignatureFailureError("Body digest does not match") + + if self.signature.algorithm_type == "hs2019": + if "(created)" not in self.signature.headers: + raise SignatureFailureError("'(created)' header not used") + + current_timestamp = HttpDate.new_utc().timestamp() + + if self.signature.created > current_timestamp: + raise SignatureFailureError("Creation date after current date") + + if current_timestamp > self.signature.expires: + raise SignatureFailureError("Expiration date before current date") + + headers["(created)"] = self.signature.created + headers["(expires)"] = self.signature.expires + + # pylint: disable=protected-access + if not self.signer._validate_signature(headers, self.signature): + raise SignatureFailureError("Signature does not match") @register_route('/.well-known/webfinger') @@ -133,12 +234,12 @@ class WebfingerView(View): except KeyError: return Response.new_error(400, 'missing "resource" query key', 'json') - if subject != 
f'acct:relay@{self.config.host}': + if subject != f'acct:relay@{self.config.domain}': return Response.new_error(404, 'user not found', 'json') data = Webfinger.new( handle = 'relay', - domain = self.config.host, + domain = self.config.domain, actor = self.config.actor ) @@ -148,14 +249,17 @@ class WebfingerView(View): @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') class NodeinfoView(View): async def get(self, request: Request, niversion: str) -> Response: - data = { - 'name': 'activityrelay', - 'version': VERSION, - 'protocols': ['activitypub'], - 'open_regs': not self.config.whitelist_enabled, - 'users': 1, - 'metadata': {'peers': self.database.hostnames} - } + with self.database.connection() as conn: + inboxes = conn.execute('SELECT * FROM inboxes').all() + + data = { + 'name': 'activityrelay', + 'version': VERSION, + 'protocols': ['activitypub'], + 'open_regs': not conn.get_config('whitelist-enabled'), + 'users': 1, + 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} + } if niversion == '2.1': data['repo'] = 'https://git.pleroma.social/pleroma/relay' @@ -166,5 +270,5 @@ class NodeinfoView(View): @register_route('/.well-known/nodeinfo') class WellknownNodeinfoView(View): async def get(self, request: Request) -> Response: - data = WellKnownNodeinfo.new_template(self.config.host) + data = WellKnownNodeinfo.new_template(self.config.domain) return Response.new(data, ctype = 'json') diff --git a/requirements.txt b/requirements.txt index 43bf45a..cc6fc4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz cachetools>=5.2.0 click>=8.1.2 pyyaml>=6.0 +tinysql@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.1.tar.gz diff --git a/setup.cfg b/setup.cfg index 8b807e3..65874ff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,26 +10,33 @@ license_file = LICENSE classifiers = Environment :: Console License :: OSI Approved :: AGPLv3 
License - Programming Language :: Python :: 3.6 - Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 project_urls = Source = https://git.pleroma.social/pleroma/relay Tracker = https://git.pleroma.social/pleroma/relay/-/issues [options] zip_safe = False -packages = find: +packages = + relay + relay.database +include_package_data = true install_requires = file: requirements.txt python_requires = >=3.8 [options.extras_require] dev = - flake8 = 3.1.0 - pyinstaller = 6.3.0 - pylint = 3.0 + flake8 == 3.1.0 + pyinstaller == 6.3.0 + pylint == 3.0 + +[options.package_data] +relay = + data/statements.sql [options.entry_points] console_scripts =