diff --git a/.gitignore b/.gitignore index fc8aedd..ecb6570 100644 --- a/.gitignore +++ b/.gitignore @@ -99,3 +99,4 @@ viera.jsonld # config file relay.yaml +relay.jsonld diff --git a/docs/installation.md b/docs/installation.md index 8363faa..a391389 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -15,7 +15,7 @@ the [official pipx docs](https://pypa.github.io/pipx/installation/) for more in- Now simply install ActivityRelay directly from git - pipx install git+https://git.pleroma.social/pleroma/relay@0.2.4 + pipx install git+https://git.pleroma.social/pleroma/relay@0.2.5 Or from a cloned git repo. @@ -39,7 +39,7 @@ be installed via [pyenv](https://github.com/pyenv/pyenv). The instructions for installation via pip are very similar to pipx. Installation can be done from git - python3 -m pip install git+https://git.pleroma.social/pleroma/relay@0.2.4 + python3 -m pip install git+https://git.pleroma.social/pleroma/relay@0.2.5 or a cloned git repo. diff --git a/pyproject.toml b/pyproject.toml index 596d494..2c137b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,37 @@ [build-system] requires = ["setuptools","wheel"] build-backend = 'setuptools.build_meta' + + +[tool.pylint.main] +jobs = 0 +persistent = true + + +[tool.pylint.design] +max-args = 10 +max-attributes = 100 + + +[tool.pylint.format] +indent-str = "\t" +indent-after-paren = 1 +max-line-length = 100 +single-line-if-stmt = true + + +[tool.pylint.messages_control] +disable = [ + "broad-exception-caught", + "cyclic-import", + "global-statement", + "invalid-name", + "missing-module-docstring", + "too-few-public-methods", + "too-many-public-methods", + "too-many-return-statements", + "wrong-import-order", + "wrong-import-position", + "missing-function-docstring", + "missing-class-docstring" +] diff --git a/relay.spec b/relay.spec index c21a829..57fedc7 100644 --- a/relay.spec +++ b/relay.spec @@ -9,13 +9,7 @@ a = Analysis( pathex=[], binaries=[], datas=[], - hiddenimports=[ - 'aputils.enums', - 'aputils.errors', - 'aputils.misc', - 'aputils.objects', - 'aputils.signer' - ], + hiddenimports=[], hookspath=[], hooksconfig={}, runtime_hooks=[], diff --git a/relay/__init__.py b/relay/__init__.py index 426b03e..13a85f7 100644 --- a/relay/__init__.py +++ b/relay/__init__.py @@ -1,3 +1 @@ -__version__ = '0.2.4' - -from . import logger +__version__ = '0.2.5' diff --git a/relay/application.py b/relay/application.py index dbe464f..a01aaec 100644 --- a/relay/application.py +++ b/relay/application.py @@ -1,31 +1,42 @@ +from __future__ import annotations + import asyncio -import logging -import os import queue import signal import threading import traceback +import typing from aiohttp import web from datetime import datetime, timedelta +from . import logger as logging from .config import RelayConfig from .database import RelayDatabase from .http_client import HttpClient -from .misc import DotDict, check_open_port, set_app -from .views import routes +from .misc import check_open_port +from .views import VIEWS + +if typing.TYPE_CHECKING: + from typing import Any + from .misc import Message + + +# pylint: disable=unsubscriptable-object class Application(web.Application): - def __init__(self, cfgpath): + def __init__(self, cfgpath: str): web.Application.__init__(self) - self['starttime'] = None + self['workers'] = [] + self['last_worker'] = 0 + self['start_time'] = None self['running'] = False self['config'] = RelayConfig(cfgpath) - if not self['config'].load(): - self['config'].save() + if not self.config.load(): + self.config.save() if self.config.is_docker: self.config.update({ @@ -34,13 +45,8 @@ class Application(web.Application): 'port': 8080 }) - self['workers'] = [] - self['last_worker'] = 0 - - set_app(self) - - self['database'] = RelayDatabase(self['config']) - self['database'].load() + self['database'] = RelayDatabase(self.config) + self.database.load() self['client'] = HttpClient( database = self.database, @@ -49,37 +55,39 @@ class Application(web.Application): cache_size = self.config.json_cache ) - self.set_signal_handler() + for path, view in VIEWS: + self.router.add_view(path, view) @property - def client(self): + def client(self) -> HttpClient: return self['client'] @property - def config(self): + def config(self) -> RelayConfig: return self['config'] @property - def database(self): + def database(self) -> RelayDatabase: return self['database'] @property - def uptime(self): - if not self['starttime']: + def uptime(self) -> timedelta: + if not self['start_time']: return timedelta(seconds=0) - uptime = datetime.now() - self['starttime'] + uptime = datetime.now() - self['start_time'] return timedelta(seconds=uptime.seconds) - def push_message(self, inbox, message): + def push_message(self, inbox: str, message: Message) -> None: if self.config.workers <= 0: - return asyncio.ensure_future(self.client.post(inbox, message)) + asyncio.ensure_future(self.client.post(inbox, message)) + return worker = self['workers'][self['last_worker']] worker.queue.put((inbox, message)) @@ -90,36 +98,45 @@ class Application(web.Application): self['last_worker'] = 0 - def set_signal_handler(self): - for sig in {'SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'}: + def set_signal_handler(self, startup: bool) -> None: + for sig in ('SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'): try: - signal.signal(getattr(signal, sig), self.stop) + signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL) # some signals don't exist in windows, so skip them except AttributeError: pass - def run(self): + def run(self) -> None: if not check_open_port(self.config.listen, self.config.port): - return logging.error(f'A server is already running on port {self.config.port}') + logging.error('A server is already running on port %i', self.config.port) + return - for route in routes: - self.router.add_route(*route) + for view in VIEWS: + self.router.add_view(*view) + + logging.info( + 'Starting webserver at %s (%s:%i)', + self.config.host, + self.config.listen, + self.config.port + ) - logging.info(f'Starting webserver at {self.config.host} ({self.config.listen}:{self.config.port})') asyncio.run(self.handle_run()) - def stop(self, *_): + def stop(self, *_: Any) -> None: self['running'] = False - async def handle_run(self): + async def handle_run(self) -> None: self['running'] = True + self.set_signal_handler(True) + if self.config.workers > 0: - for i in range(self.config.workers): + for _ in range(self.config.workers): worker = PushWorker(self) worker.start() @@ -128,33 +145,40 @@ class Application(web.Application): runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() - site = web.TCPSite(runner, + site = web.TCPSite( + runner, host = self.config.listen, port = self.config.port, reuse_address = True ) await site.start() - self['starttime'] = datetime.now() + self['start_time'] = datetime.now() while self['running']: await asyncio.sleep(0.25) await site.stop() + await self.client.close() - self['starttime'] = None + self['start_time'] = None self['running'] = False self['workers'].clear() class PushWorker(threading.Thread): - def __init__(self, app): + def __init__(self, app: Application): threading.Thread.__init__(self) self.app = app self.queue = queue.Queue() + self.client = None - def run(self): + def run(self) -> None: + asyncio.run(self.handle_queue()) + + + async def handle_queue(self) -> None: self.client = HttpClient( database = self.app.database, limit = self.app.config.push_limit, @@ -162,15 +186,11 @@ class PushWorker(threading.Thread): cache_size = self.app.config.json_cache ) - asyncio.run(self.handle_queue()) - - - async def handle_queue(self): while self.app['running']: try: inbox, message = self.queue.get(block=True, timeout=0.25) self.queue.task_done() - logging.verbose(f'New push from Thread-{threading.get_ident()}') + logging.verbose('New push from Thread-%i', threading.get_ident()) await self.client.post(inbox, message) except queue.Empty: @@ -181,36 +201,3 @@ class PushWorker(threading.Thread): traceback.print_exc() await self.client.close() - - -## Can't sub-class web.Request, so let's just add some properties -def request_actor(self): - try: return self['actor'] - except KeyError: pass - - -def request_instance(self): - try: return self['instance'] - except KeyError: pass - - -def request_message(self): - try: return self['message'] - except KeyError: pass - - -def request_signature(self): - if 'signature' not in self._state: - try: self['signature'] = DotDict.new_from_signature(self.headers['signature']) - except KeyError: return - - return self['signature'] - - -setattr(web.Request, 'actor', property(request_actor)) -setattr(web.Request, 'instance', property(request_instance)) -setattr(web.Request, 'message', property(request_message)) -setattr(web.Request, 'signature', property(request_signature)) - -setattr(web.Request, 'config', property(lambda self: self.app.config)) -setattr(web.Request, 'database', property(lambda self: self.app.database)) diff --git a/relay/config.py b/relay/config.py index 996fa9f..e684ead 100644 --- a/relay/config.py +++ b/relay/config.py @@ -1,5 +1,7 @@ -import json +from __future__ import annotations + import os +import typing import yaml from functools import cached_property @@ -8,6 +10,10 @@ from urllib.parse import urlparse from .misc import DotDict, boolean +if typing.TYPE_CHECKING: + from typing import Any + from .database import RelayDatabase + RELAY_SOFTWARE = [ 'activityrelay', # https://git.pleroma.social/pleroma/relay @@ -25,17 +31,19 @@ APKEYS = [ class RelayConfig(DotDict): - def __init__(self, path): + __slots__ = ('path', ) + + def __init__(self, path: str | Path): DotDict.__init__(self, {}) if self.is_docker: path = '/data/config.yaml' - self._path = Path(path).expanduser() + self._path = Path(path).expanduser().resolve() self.reset() - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: if key in ['blocked_instances', 'blocked_software', 'whitelist']: assert isinstance(value, (list, set, tuple)) @@ -51,36 +59,31 @@ class RelayConfig(DotDict): @property - def db(self): + def db(self) -> RelayDatabase: return Path(self['db']).expanduser().resolve() @property - def path(self): - return self._path - - - @property - def actor(self): + def actor(self) -> str: return f'https://{self.host}/actor' @property - def inbox(self): + def inbox(self) -> str: return f'https://{self.host}/inbox' @property - def keyid(self): + def keyid(self) -> str: return f'{self.actor}#main-key' @cached_property - def is_docker(self): + def is_docker(self) -> bool: return bool(os.environ.get('DOCKER_RUNNING')) - def reset(self): + def reset(self) -> None: self.clear() self.update({ 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')), @@ -99,7 +102,7 @@ class RelayConfig(DotDict): }) - def ban_instance(self, instance): + def ban_instance(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname @@ -110,7 +113,7 @@ class RelayConfig(DotDict): return True - def unban_instance(self, instance): + def unban_instance(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname @@ -118,11 +121,11 @@ class RelayConfig(DotDict): self.blocked_instances.remove(instance) return True - except: + except ValueError: return False - def ban_software(self, software): + def ban_software(self, software: str) -> bool: if self.is_banned_software(software): return False @@ -130,16 +133,16 @@ class RelayConfig(DotDict): return True - def unban_software(self, software): + def unban_software(self, software: str) -> bool: try: self.blocked_software.remove(software) return True - except: + except ValueError: return False - def add_whitelist(self, instance): + def add_whitelist(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname @@ -150,7 +153,7 @@ class RelayConfig(DotDict): return True - def del_whitelist(self, instance): + def del_whitelist(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname @@ -158,32 +161,32 @@ class RelayConfig(DotDict): self.whitelist.remove(instance) return True - except: + except ValueError: return False - def is_banned(self, instance): + def is_banned(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname return instance in self.blocked_instances - def is_banned_software(self, software): + def is_banned_software(self, software: str) -> bool: if not software: return False return software.lower() in self.blocked_software - def is_whitelisted(self, instance): + def is_whitelisted(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname return instance in self.whitelist - def load(self): + def load(self) -> bool: self.reset() options = {} @@ -195,7 +198,7 @@ class RelayConfig(DotDict): pass try: - with open(self.path) as fd: + with self._path.open('r', encoding = 'UTF-8') as fd: config = yaml.load(fd, **options) except FileNotFoundError: @@ -214,7 +217,7 @@ class RelayConfig(DotDict): continue - elif key not in self: + if key not in self: continue self[key] = value @@ -225,7 +228,7 @@ class RelayConfig(DotDict): return True - def save(self): + def save(self) -> None: config = { # just turning config.db into a string is good enough for now 'db': str(self.db), @@ -239,7 +242,5 @@ class RelayConfig(DotDict): 'ap': {key: self[key] for key in APKEYS} } - with open(self._path, 'w') as fd: + with self._path.open('w', encoding = 'utf-8') as fd: yaml.dump(config, fd, sort_keys=False) - - return config diff --git a/relay/database.py b/relay/database.py index ad093cd..5d059dd 100644 --- a/relay/database.py +++ b/relay/database.py @@ -1,14 +1,21 @@ -import aputils -import asyncio -import json -import logging -import traceback +from __future__ import annotations +import json +import typing + +from aputils.signer import Signer from urllib.parse import urlparse +from . import logger as logging + +if typing.TYPE_CHECKING: + from typing import Iterator, Optional + from .config import RelayConfig + from .misc import Message + class RelayDatabase(dict): - def __init__(self, config): + def __init__(self, config: RelayConfig): dict.__init__(self, { 'relay-list': {}, 'private-key': None, @@ -21,16 +28,16 @@ class RelayDatabase(dict): @property - def hostnames(self): + def hostnames(self) -> tuple[str]: return tuple(self['relay-list'].keys()) @property - def inboxes(self): + def inboxes(self) -> tuple[dict[str, str]]: return tuple(data['inbox'] for data in self['relay-list'].values()) - def load(self): + def load(self) -> bool: new_db = True try: @@ -40,7 +47,7 @@ class RelayDatabase(dict): self['version'] = data.get('version', None) self['private-key'] = data.get('private-key') - if self['version'] == None: + if self['version'] is None: self['version'] = 1 if 'actorKeys' in data: @@ -58,7 +65,9 @@ class RelayDatabase(dict): self['relay-list'] = data.get('relay-list', {}) for domain, instance in self['relay-list'].items(): - if self.config.is_banned(domain) or (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)): + if self.config.is_banned(domain) or \ + (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)): + self.del_inbox(domain) continue @@ -75,36 +84,40 @@ class RelayDatabase(dict): raise e from None if not self['private-key']: - logging.info("No actor keys present, generating 4096-bit RSA keypair.") - self.signer = aputils.Signer.new(self.config.keyid, size=4096) + logging.info('No actor keys present, generating 4096-bit RSA keypair.') + self.signer = Signer.new(self.config.keyid, size=4096) self['private-key'] = self.signer.export() else: - self.signer = aputils.Signer(self['private-key'], self.config.keyid) + self.signer = Signer(self['private-key'], self.config.keyid) self.save() return not new_db - def save(self): - with self.config.db.open('w') as fd: + def save(self) -> None: + with self.config.db.open('w', encoding = 'UTF-8') as fd: json.dump(self, fd, indent=4) - def get_inbox(self, domain, fail=False): + def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None: if domain.startswith('http'): domain = urlparse(domain).hostname - inbox = self['relay-list'].get(domain) - - if inbox: + if (inbox := self['relay-list'].get(domain)): return inbox if fail: raise KeyError(domain) + return None + + + def add_inbox(self, + inbox: str, + followid: Optional[str] = None, + software: Optional[str] = None) -> dict[str, str]: - def add_inbox(self, inbox, followid=None, software=None): assert inbox.startswith('https'), 'Inbox must be a url' domain = urlparse(inbox).hostname instance = self.get_inbox(domain) @@ -125,11 +138,15 @@ class RelayDatabase(dict): 'software': software } - logging.verbose(f'Added inbox to database: {inbox}') + logging.verbose('Added inbox to database: %s', inbox) return self['relay-list'][domain] - def del_inbox(self, domain, followid=None, fail=False): + def del_inbox(self, + domain: str, + followid: Optional[str] = None, + fail: Optional[bool] = False) -> bool: + data = self.get_inbox(domain, fail=False) if not data: @@ -140,17 +157,17 @@ class RelayDatabase(dict): if not data['followid'] or not followid or data['followid'] == followid: del self['relay-list'][data['domain']] - logging.verbose(f'Removed inbox from database: {data["inbox"]}') + logging.verbose('Removed inbox from database: %s', data['inbox']) return True if fail: raise ValueError('Follow IDs do not match') - logging.debug(f'Follow ID does not match: db = {data["followid"]}, object = {followid}') + logging.debug('Follow ID does not match: db = %s, object = %s', data['followid'], followid) return False - def get_request(self, domain, fail=True): + def get_request(self, domain: str, fail: bool = True) -> dict[str, str] | None: if domain.startswith('http'): domain = urlparse(domain).hostname @@ -161,8 +178,10 @@ class RelayDatabase(dict): if fail: raise e + return None - def add_request(self, actor, inbox, followid): + + def add_request(self, actor: str, inbox: str, followid: str) -> None: domain = urlparse(inbox).hostname try: @@ -179,17 +198,17 @@ class RelayDatabase(dict): } - def del_request(self, domain): + def del_request(self, domain: str) -> None: if domain.startswith('http'): - domain = urlparse(inbox).hostname + domain = urlparse(domain).hostname del self['follow-requests'][domain] - def distill_inboxes(self, message): + def distill_inboxes(self, message: Message) -> Iterator[str]: src_domains = { message.domain, - urlparse(message.objectid).netloc + urlparse(message.object_id).netloc } for domain, instance in self['relay-list'].items(): diff --git a/relay/http_client.py b/relay/http_client.py index 81fcd46..6f2a044 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -1,21 +1,23 @@ -import logging +from __future__ import annotations + import traceback +import typing from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from aputils import Nodeinfo, WellKnownNodeinfo -from datetime import datetime +from aputils.objects import Nodeinfo, WellKnownNodeinfo from cachetools import LRUCache from json.decoder import JSONDecodeError from urllib.parse import urlparse from . import __version__ -from .misc import ( - MIMETYPES, - DotDict, - Message -) +from . import logger as logging +from .misc import MIMETYPES, Message + +if typing.TYPE_CHECKING: + from typing import Any, Callable, Optional + from .database import RelayDatabase HEADERS = { @@ -24,40 +26,31 @@ HEADERS = { } -class Cache(LRUCache): - def set_maxsize(self, value): - self.__maxsize = int(value) - - class HttpClient: - def __init__(self, database, limit=100, timeout=10, cache_size=1024): + def __init__(self, + database: RelayDatabase, + limit: Optional[int] = 100, + timeout: Optional[int] = 10, + cache_size: Optional[int] = 1024): + self.database = database - self.cache = Cache(cache_size) - self.cfg = {'limit': limit, 'timeout': timeout} + self.cache = LRUCache(cache_size) + self.limit = limit + self.timeout = timeout self._conn = None self._session = None - async def __aenter__(self): + async def __aenter__(self) -> HttpClient: await self.open() return self - async def __aexit__(self, *_): + async def __aexit__(self, *_: Any) -> None: await self.close() - @property - def limit(self): - return self.cfg['limit'] - - - @property - def timeout(self): - return self.cfg['timeout'] - - - async def open(self): + async def open(self) -> None: if self._session: return @@ -74,7 +67,7 @@ class HttpClient: ) - async def close(self): + async def close(self) -> None: if not self._session: return @@ -85,11 +78,19 @@ class HttpClient: self._session = None - async def get(self, url, sign_headers=False, loads=None, force=False): + async def get(self, # pylint: disable=too-many-branches + url: str, + sign_headers: Optional[bool] = False, + loads: Optional[Callable] = None, + force: Optional[bool] = False) -> Message | dict | None: + await self.open() - try: url, _ = url.split('#', 1) - except: pass + try: + url, _ = url.split('#', 1) + + except ValueError: + pass if not force and url in self.cache: return self.cache[url] @@ -100,51 +101,53 @@ class HttpClient: headers.update(self.database.signer.sign_headers('GET', url, algorithm='original')) try: - logging.verbose(f'Fetching resource: {url}') + logging.debug('Fetching resource: %s', url) async with self._session.get(url, headers=headers) as resp: ## Not expecting a response with 202s, so just return if resp.status == 202: - return + return None - elif resp.status != 200: - logging.verbose(f'Received error when requesting {url}: {resp.status}') - logging.verbose(await resp.read()) # change this to debug - return + if resp.status != 200: + logging.verbose('Received error when requesting %s: %i', url, resp.status) + logging.debug(await resp.read()) + return None if loads: message = await resp.json(loads=loads) elif resp.content_type == MIMETYPES['activity']: - message = await resp.json(loads=Message.new_from_json) + message = await resp.json(loads = Message.parse) elif resp.content_type == MIMETYPES['json']: - message = await resp.json(loads=DotDict.new_from_json) + message = await resp.json() else: - # todo: raise TypeError or something - logging.verbose(f'Invalid Content-Type for "{url}": {resp.content_type}') - return logging.debug(f'Response: {resp.read()}') + logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type) + logging.debug('Response: %s', await resp.read()) + return None - logging.debug(f'{url} >> resp {message.to_json(4)}') + logging.debug('%s >> resp %s', url, message.to_json(4)) self.cache[url] = message return message except JSONDecodeError: - logging.verbose(f'Failed to parse JSON') + logging.verbose('Failed to parse JSON') except ClientSSLError: - logging.verbose(f'SSL error when connecting to {urlparse(url).netloc}') + logging.verbose('SSL error when connecting to %s', urlparse(url).netloc) except (AsyncTimeoutError, ClientConnectionError): - logging.verbose(f'Failed to connect to {urlparse(url).netloc}') + logging.verbose('Failed to connect to %s', urlparse(url).netloc) - except Exception as e: + except Exception: traceback.print_exc() + return None - async def post(self, url, message): + + async def post(self, url: str, message: Message) -> None: await self.open() instance = self.database.get_inbox(url) @@ -160,38 +163,39 @@ class HttpClient: headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm)) try: - logging.verbose(f'Sending "{message.type}" to {url}') + logging.verbose('Sending "%s" to %s', message.type, url) async with self._session.post(url, headers=headers, data=message.to_json()) as resp: ## Not expecting a response, so just return if resp.status in {200, 202}: - return logging.verbose(f'Successfully sent "{message.type}" to {url}') + logging.verbose('Successfully sent "%s" to %s', message.type, url) + return - logging.verbose(f'Received error when pushing to {url}: {resp.status}') - return logging.verbose(await resp.read()) # change this to debug + logging.verbose('Received error when pushing to %s: %i', url, resp.status) + logging.debug(await resp.read()) + return except ClientSSLError: - logging.warning(f'SSL error when pushing to {urlparse(url).netloc}') + logging.warning('SSL error when pushing to %s', urlparse(url).netloc) except (AsyncTimeoutError, ClientConnectionError): - logging.warning(f'Failed to connect to {urlparse(url).netloc} for message push') + logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) ## prevent workers from being brought down - except Exception as e: + except Exception: traceback.print_exc() - ## Additional methods ## - async def fetch_nodeinfo(self, domain): + async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None: nodeinfo_url = None wk_nodeinfo = await self.get( f'https://{domain}/.well-known/nodeinfo', - loads = WellKnownNodeinfo.new_from_json + loads = WellKnownNodeinfo.parse ) if not wk_nodeinfo: - logging.verbose(f'Failed to fetch well-known nodeinfo url for domain: {domain}') - return False + logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) + return None for version in ['20', '21']: try: @@ -201,22 +205,22 @@ class HttpClient: pass if not nodeinfo_url: - logging.verbose(f'Failed to fetch nodeinfo url for domain: {domain}') - return False + logging.verbose('Failed to fetch nodeinfo url for %s', domain) + return None - return await self.get(nodeinfo_url, loads=Nodeinfo.new_from_json) or False + return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None -async def get(database, *args, **kwargs): +async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None: async with HttpClient(database) as client: return await client.get(*args, **kwargs) -async def post(database, *args, **kwargs): +async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None: async with HttpClient(database) as client: return await client.post(*args, **kwargs) -async def fetch_nodeinfo(database, *args, **kwargs): +async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None: async with HttpClient(database) as client: return await client.fetch_nodeinfo(*args, **kwargs) diff --git a/relay/logger.py b/relay/logger.py index 166cbf2..0d1d451 100644 --- a/relay/logger.py +++ b/relay/logger.py @@ -1,40 +1,57 @@ +from __future__ import annotations + import logging import os +import typing from pathlib import Path +if typing.TYPE_CHECKING: + from typing import Any, Callable -## Add the verbose logging level -def verbose(message, *args, **kwargs): - if not logging.root.isEnabledFor(logging.VERBOSE): + +LOG_LEVELS: dict[str, int] = { + 'DEBUG': logging.DEBUG, + 'VERBOSE': 15, + 'INFO': logging.INFO, + 'WARNING': logging.WARNING, + 'ERROR': logging.ERROR, + 'CRITICAL': logging.CRITICAL +} + + +debug: Callable = logging.debug +info: Callable = logging.info +warning: Callable = logging.warning +error: Callable = logging.error +critical: Callable = logging.critical + + +def verbose(message: str, *args: Any, **kwargs: Any) -> None: + if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']): return - logging.log(logging.VERBOSE, message, *args, **kwargs) - -setattr(logging, 'verbose', verbose) -setattr(logging, 'VERBOSE', 15) -logging.addLevelName(15, 'VERBOSE') + logging.log(LOG_LEVELS['VERBOSE'], message, *args, **kwargs) -## Get log level and file from environment if possible +logging.addLevelName(LOG_LEVELS['VERBOSE'], 'VERBOSE') env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper() try: - env_log_file = Path(os.environ.get('LOG_FILE')).expanduser().resolve() + env_log_file = Path(os.environ['LOG_FILE']).expanduser().resolve() -except TypeError: +except KeyError: env_log_file = None -## Make sure the level from the environment is valid try: - log_level = getattr(logging, env_log_level) + log_level = LOG_LEVELS[env_log_level] -except AttributeError: +except KeyError: + logging.warning('Invalid log level: %s', env_log_level) log_level = logging.INFO -## Set logging config handlers = [logging.StreamHandler()] if env_log_file: @@ -42,6 +59,6 @@ if env_log_file: logging.basicConfig( level = log_level, - format = "[%(asctime)s] %(levelname)s: %(message)s", + format = '[%(asctime)s] %(levelname)s: %(message)s', handlers = handlers ) diff --git a/relay/manage.py b/relay/manage.py index c36f876..b0c5cb3 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import Crypto import asyncio import click -import logging import platform +import typing from urllib.parse import urlparse @@ -11,6 +13,12 @@ from . import http_client as http from .application import Application from .config import RELAY_SOFTWARE +if typing.TYPE_CHECKING: + from typing import Any + + +# pylint: disable=unsubscriptable-object,unsupported-assignment-operation + app = None CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} @@ -20,7 +28,7 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} @click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config') @click.version_option(version=__version__, prog_name='ActivityRelay') @click.pass_context -def cli(ctx, config): +def cli(ctx: click.Context, config: str) -> None: global app app = Application(config) @@ -33,11 +41,14 @@ def cli(ctx, config): @cli.command('setup') -def cli_setup(): +def cli_setup() -> None: 'Generate a new config' while True: - app.config.host = click.prompt('What domain will the relay be hosted on?', default=app.config.host) + app.config.host = click.prompt( + 'What domain will the relay be hosted on?', + default = app.config.host + ) if not app.config.host.endswith('example.com'): break @@ -45,10 +56,18 @@ def cli_setup(): click.echo('The domain must not be example.com') if not app.config.is_docker: - app.config.listen = click.prompt('Which address should the relay listen on?', default=app.config.listen) + app.config.listen = click.prompt( + 'Which address should the relay listen on?', + default = app.config.listen + ) while True: - app.config.port = click.prompt('What TCP port should the relay listen on?', default=app.config.port, type=int) + app.config.port = click.prompt( + 'What TCP port should the relay listen on?', + default = app.config.port, + type = int + ) + break app.config.save() @@ -58,39 +77,47 @@ def cli_setup(): @cli.command('run') -def cli_run(): +def cli_run() -> None: 'Run the relay' if app.config.host.endswith('example.com'): - return click.echo('Relay is not set up. Please edit your relay config or run "activityrelay setup".') + click.echo( + 'Relay is not set up. Please edit your relay config or run "activityrelay setup".' + ) + + return vers_split = platform.python_version().split('.') pip_command = 'pip3 uninstall pycrypto && pip3 install pycryptodome' if Crypto.__version__ == '2.6.1': if int(vers_split[1]) > 7: - click.echo('Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome before running again. Exiting...') - return click.echo(pip_command) + click.echo( + 'Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome ' + + 'before running again. Exiting...' + ) - else: - click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome') - return click.echo(pip_command) + click.echo(pip_command) + return + + click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome') + click.echo(pip_command) + return if not misc.check_open_port(app.config.listen, app.config.port): - return click.echo(f'Error: A server is already running on port {app.config.port}') + click.echo(f'Error: A server is already running on port {app.config.port}') + return app.run() -# todo: add config default command for resetting config key @cli.group('config') -def cli_config(): +def cli_config() -> None: 'Manage the relay config' - pass @cli_config.command('list') -def cli_config_list(): +def cli_config_list() -> None: 'List the current relay config' click.echo('Relay Config:') @@ -104,7 +131,7 @@ def cli_config_list(): @cli_config.command('set') @click.argument('key') @click.argument('value') -def cli_config_set(key, value): +def cli_config_set(key: str, value: Any) -> None: 'Set a config value' app.config[key] = value @@ -114,13 +141,12 @@ def cli_config_set(key, value): @cli.group('inbox') -def cli_inbox(): +def cli_inbox() -> None: 'Manage the inboxes in the database' - pass @cli_inbox.command('list') -def cli_inbox_list(): +def cli_inbox_list() -> None: 'List the connected instances or relays' click.echo('Connected to the following instances or relays:') @@ -131,11 +157,12 @@ def cli_inbox_list(): @cli_inbox.command('follow') @click.argument('actor') -def cli_inbox_follow(actor): +def cli_inbox_follow(actor: str) -> None: 'Follow an actor (Relay must be running)' if app.config.is_banned(actor): - return click.echo(f'Error: Refusing to follow banned actor: {actor}') + click.echo(f'Error: Refusing to follow banned actor: {actor}') + return if not actor.startswith('http'): domain = actor @@ -152,7 +179,8 @@ def cli_inbox_follow(actor): actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True)) if not actor_data: - return click.echo(f'Failed to fetch actor: {actor}') + click.echo(f'Failed to fetch actor: {actor}') + return inbox = actor_data.shared_inbox @@ -167,7 +195,7 @@ def cli_inbox_follow(actor): @cli_inbox.command('unfollow') @click.argument('actor') -def cli_inbox_unfollow(actor): +def cli_inbox_unfollow(actor: str) -> None: 'Unfollow an actor (Relay must be running)' if not actor.startswith('http'): @@ -205,17 +233,19 @@ def cli_inbox_unfollow(actor): @cli_inbox.command('add') @click.argument('inbox') -def cli_inbox_add(inbox): +def cli_inbox_add(inbox: str) -> None: 'Add an inbox to the database' if not inbox.startswith('http'): inbox = f'https://{inbox}/inbox' if app.config.is_banned(inbox): - return click.echo(f'Error: Refusing to add banned inbox: {inbox}') + click.echo(f'Error: Refusing to add banned inbox: {inbox}') + return if app.database.get_inbox(inbox): - return click.echo(f'Error: Inbox already in database: {inbox}') + click.echo(f'Error: Inbox already in database: {inbox}') + return app.database.add_inbox(inbox) app.database.save() @@ -225,7 +255,7 @@ def cli_inbox_add(inbox): @cli_inbox.command('remove') @click.argument('inbox') -def cli_inbox_remove(inbox): +def cli_inbox_remove(inbox: str) -> None: 'Remove an inbox from the database' try: @@ -242,13 +272,12 @@ def cli_inbox_remove(inbox): @cli.group('instance') -def cli_instance(): +def cli_instance() -> None: 'Manage instance bans' - pass @cli_instance.command('list') -def cli_instance_list(): +def cli_instance_list() -> None: 'List all banned instances' click.echo('Banned instances or relays:') @@ -259,7 +288,7 @@ def cli_instance_list(): @cli_instance.command('ban') @click.argument('target') -def cli_instance_ban(target): +def cli_instance_ban(target: str) -> None: 'Ban an instance and remove the associated inbox if it exists' if target.startswith('http'): @@ -279,7 +308,7 @@ def cli_instance_ban(target): @cli_instance.command('unban') @click.argument('target') -def cli_instance_unban(target): +def cli_instance_unban(target: str) -> None: 'Unban an instance' if app.config.unban_instance(target): @@ -292,13 +321,12 @@ def cli_instance_unban(target): @cli.group('software') -def cli_software(): +def cli_software() -> None: 'Manage banned software' - pass @cli_software.command('list') -def cli_software_list(): +def cli_software_list() -> None: 'List all banned software' click.echo('Banned software:') @@ -308,19 +336,21 @@ def cli_software_list(): @cli_software.command('ban') -@click.option('--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default=False, - help='Treat NAME like a domain and try to fet the software name from nodeinfo' +@click.option( + '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False, + help = 'Treat NAME like a domain and try to fet the software name from nodeinfo' ) @click.argument('name') -def cli_software_ban(name, fetch_nodeinfo): +def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None: 'Ban software. Use RELAYS for NAME to ban relays' if name == 'RELAYS': - for name in RELAY_SOFTWARE: - app.config.ban_software(name) + for software in RELAY_SOFTWARE: + app.config.ban_software(software) app.config.save() - return click.echo('Banned all relay software') + click.echo('Banned all relay software') + return if fetch_nodeinfo: nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) @@ -332,25 +362,28 @@ def cli_software_ban(name, fetch_nodeinfo): if app.config.ban_software(name): app.config.save() - return click.echo(f'Banned software: {name}') + click.echo(f'Banned software: {name}') + return click.echo(f'Software already banned: {name}') @cli_software.command('unban') -@click.option('--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default=False, - help='Treat NAME like a domain and try to fet the software name from nodeinfo' +@click.option( + '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False, + help = 'Treat NAME like a domain and try to fet the software name from nodeinfo' ) @click.argument('name') -def cli_software_unban(name, fetch_nodeinfo): +def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None: 'Ban software. Use RELAYS for NAME to unban relays' if name == 'RELAYS': - for name in RELAY_SOFTWARE: - app.config.unban_software(name) + for software in RELAY_SOFTWARE: + app.config.unban_software(software) app.config.save() - return click.echo('Unbanned all relay software') + click.echo('Unbanned all relay software') + return if fetch_nodeinfo: nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) @@ -362,19 +395,19 @@ def cli_software_unban(name, fetch_nodeinfo): if app.config.unban_software(name): app.config.save() - return click.echo(f'Unbanned software: {name}') + click.echo(f'Unbanned software: {name}') + return click.echo(f'Software wasn\'t banned: {name}') @cli.group('whitelist') -def cli_whitelist(): +def cli_whitelist() -> None: 'Manage the instance whitelist' - pass @cli_whitelist.command('list') -def cli_whitelist_list(): +def cli_whitelist_list() -> None: 'List all the instances in the whitelist' click.echo('Current whitelisted domains') @@ -385,11 +418,12 @@ def cli_whitelist_list(): @cli_whitelist.command('add') @click.argument('instance') -def cli_whitelist_add(instance): +def cli_whitelist_add(instance: str) -> None: 'Add an instance to the whitelist' if not app.config.add_whitelist(instance): - return click.echo(f'Instance already in the whitelist: {instance}') + click.echo(f'Instance already in the whitelist: {instance}') + return app.config.save() click.echo(f'Instance added to the whitelist: {instance}') @@ -397,11 +431,12 @@ def cli_whitelist_add(instance): @cli_whitelist.command('remove') @click.argument('instance') -def cli_whitelist_remove(instance): +def cli_whitelist_remove(instance: str) -> None: 'Remove an instance from the whitelist' if not app.config.del_whitelist(instance): - return click.echo(f'Instance not in the whitelist: {instance}') + click.echo(f'Instance not in the whitelist: {instance}') + return app.config.save() @@ -413,14 +448,15 @@ def cli_whitelist_remove(instance): @cli_whitelist.command('import') -def cli_whitelist_import(): +def cli_whitelist_import() -> None: 'Add all current inboxes to the whitelist' for domain in app.database.hostnames: cli_whitelist_add.callback(domain) -def main(): +def main() -> None: + # pylint: disable=no-value-for-parameter cli(prog_name='relay') diff --git a/relay/misc.py b/relay/misc.py index a98088f..7244eaa 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -1,21 +1,31 @@ -import aputils -import asyncio -import base64 +from __future__ import annotations + import json -import logging import socket import traceback -import uuid +import typing +from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS -from aiohttp.web import Response as AiohttpResponse, View as AiohttpView -from datetime import datetime +from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse +from aiohttp.web_exceptions import HTTPMethodNotAllowed +from aputils.errors import SignatureFailureError +from aputils.misc import Digest, HttpDate, Signature +from aputils.message import Message as ApMessage +from functools import cached_property from json.decoder import JSONDecodeError -from urllib.parse import urlparse from uuid import uuid4 +from . import logger as logging + +if typing.TYPE_CHECKING: + from typing import Any, Coroutine, Generator, Optional, Type + from aputils.signer import Signer + from .application import Application + from .config import RelayConfig + from .database import RelayDatabase + from .http_client import HttpClient -app = None MIMETYPES = { 'activity': 'application/activity+json', @@ -30,94 +40,87 @@ NODEINFO_NS = { } -def set_app(new_app): - global app - app = new_app - - -def boolean(value): +def boolean(value: Any) -> bool: if isinstance(value, str): if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']: return True - elif value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']: + if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']: return False - else: - raise TypeError(f'Cannot parse string "{value}" as a boolean') + raise TypeError(f'Cannot parse string "{value}" as a boolean') - elif isinstance(value, int): + if isinstance(value, int): if value == 1: return True - elif value == 0: + if value == 0: return False - else: - raise ValueError('Integer value must be 1 or 0') + raise ValueError('Integer value must be 1 or 0') - elif value == None: + if value is None: return False - try: - return value.__bool__() - - except AttributeError: - raise TypeError(f'Cannot convert object of type "{clsname(value)}"') + return bool(value) -def check_open_port(host, port): +def check_open_port(host: str, port: int) -> bool: if host == '0.0.0.0': host = '127.0.0.1' with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: - return s.connect_ex((host , port)) != 0 + return s.connect_ex((host, port)) != 0 - except socket.error as e: + except socket.error: return False class DotDict(dict): - def __init__(self, _data, **kwargs): + def __init__(self, _data: dict[str, Any], **kwargs: Any): dict.__init__(self) self.update(_data, **kwargs) - def __getattr__(self, k): + def __getattr__(self, key: str) -> str: try: - return self[k] + return self[key] except KeyError: - raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None + raise AttributeError( + f'{self.__class__.__name__} object has no attribute {key}' + ) from None - def __setattr__(self, k, v): - if k.startswith('_'): - super().__setattr__(k, v) + def __setattr__(self, key: str, value: Any) -> None: + if key.startswith('_'): + super().__setattr__(key, value) else: - self[k] = v + self[key] = value - def __setitem__(self, k, v): - if type(v) == dict: - v = DotDict(v) + def __setitem__(self, key: str, value: Any) -> None: + if type(value) is dict: # pylint: disable=unidiomatic-typecheck + value = DotDict(value) - super().__setitem__(k, v) + super().__setitem__(key, value) - def __delattr__(self, k): + def __delattr__(self, key: str) -> None: try: - dict.__delitem__(self, k) + dict.__delitem__(self, key) except KeyError: - raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None + raise AttributeError( + f'{self.__class__.__name__} object has no attribute {key}' + ) from None @classmethod - def new_from_json(cls, data): + def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]: if not data: raise JSONDecodeError('Empty body', data, 1) @@ -125,11 +128,11 @@ class DotDict(dict): return cls(json.loads(data)) except ValueError: - raise JSONDecodeError('Invalid body', data, 1) + raise JSONDecodeError('Invalid body', data, 1) from None @classmethod - def new_from_signature(cls, sig): + def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]: data = cls({}) for chunk in sig.strip().split(','): @@ -144,11 +147,11 @@ class DotDict(dict): return data - def to_json(self, indent=None): + def to_json(self, indent: Optional[int | str] = None) -> str: return json.dumps(self, indent=indent) - def update(self, _data, **kwargs): + def update(self, _data: dict[str, Any], **kwargs: Any) -> None: if isinstance(_data, dict): for key, value in _data.items(): self[key] = value @@ -161,9 +164,13 @@ class DotDict(dict): self[key] = value -class Message(DotDict): +class Message(ApMessage): @classmethod - def new_actor(cls, host, pubkey, description=None): + def new_actor(cls: Type[Message], # pylint: disable=arguments-differ + host: str, + pubkey: str, + description: Optional[str] = None) -> Message: + return cls({ '@context': 'https://www.w3.org/ns/activitystreams', 'id': f'https://{host}/actor', @@ -187,34 +194,34 @@ class Message(DotDict): @classmethod - def new_announce(cls, host, object): + def new_announce(cls: Type[Message], host: str, obj: str) -> Message: return cls({ '@context': 'https://www.w3.org/ns/activitystreams', - 'id': f'https://{host}/activities/{uuid.uuid4()}', + 'id': f'https://{host}/activities/{uuid4()}', 'type': 'Announce', 'to': [f'https://{host}/followers'], 'actor': f'https://{host}/actor', - 'object': object + 'object': obj }) @classmethod - def new_follow(cls, host, actor): + def new_follow(cls: Type[Message], host: str, actor: str) -> Message: return cls({ '@context': 'https://www.w3.org/ns/activitystreams', 'type': 'Follow', 'to': [actor], 'object': actor, - 'id': f'https://{host}/activities/{uuid.uuid4()}', + 'id': f'https://{host}/activities/{uuid4()}', 'actor': f'https://{host}/actor' }) @classmethod - def new_unfollow(cls, host, actor, follow): + def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message: return cls({ '@context': 'https://www.w3.org/ns/activitystreams', - 'id': f'https://{host}/activities/{uuid.uuid4()}', + 'id': f'https://{host}/activities/{uuid4()}', 'type': 'Undo', 'to': [actor], 'actor': f'https://{host}/actor', @@ -223,10 +230,15 @@ class Message(DotDict): @classmethod - def new_response(cls, host, actor, followid, accept): + def new_response(cls: Type[Message], + host: str, + actor: str, + followid: str, + accept: bool) -> Message: + return cls({ '@context': 'https://www.w3.org/ns/activitystreams', - 'id': f'https://{host}/activities/{uuid.uuid4()}', + 'id': f'https://{host}/activities/{uuid4()}', 'type': 'Accept' if accept else 'Reject', 'to': [actor], 'actor': f'https://{host}/actor', @@ -239,43 +251,24 @@ class Message(DotDict): }) - # misc properties + # todo: remove when fixed in aputils @property - def domain(self): - return urlparse(self.id).hostname + def object_id(self) -> str: + try: + return self["object"]["id"] - - # actor properties - @property - def shared_inbox(self): - return self.get('endpoints', {}).get('sharedInbox', self.inbox) - - - # activity properties - @property - def actorid(self): - if isinstance(self.actor, dict): - return self.actor.id - - return self.actor - - - @property - def objectid(self): - if isinstance(self.object, dict): - return self.object.id - - return self.object - - - @property - def signer(self): - return aputils.Signer.new_from_actor(self) + except (KeyError, TypeError): + return self["object"] class Response(AiohttpResponse): @classmethod - def new(cls, body='', status=200, headers=None, ctype='text'): + def new(cls: Type[Response], + body: Optional[str | bytes | dict] = '', + status: Optional[int] = 200, + headers: Optional[dict[str, str]] = None, + ctype: Optional[str] = 'text') -> Response: + kwargs = { 'status': status, 'headers': headers, @@ -295,7 +288,11 @@ class Response(AiohttpResponse): @classmethod - def new_error(cls, status, body, ctype='text'): + def new_error(cls: Type[Response], + status: int, + body: str | bytes | dict, + ctype: str = 'text') -> Response: + if ctype == 'json': body = json.dumps({'status': status, 'error': body}) @@ -303,38 +300,157 @@ class Response(AiohttpResponse): @property - def location(self): + def location(self) -> str: return self.headers.get('Location') @location.setter - def location(self, value): + def location(self, value: str) -> None: self.headers['Location'] = value -class View(AiohttpView): - async def _iter(self): - if self.request.method not in METHODS: - self._raise_allowed_methods() +class View(AbstractView): + def __init__(self, request: AiohttpRequest): + AbstractView.__init__(self, request) - method = getattr(self, self.request.method.lower(), None) + self.signature: Signature = None + self.message: Message = None + self.actor: Message = None + self.instance: dict[str, str] = None + self.signer: Signer = None - if method is None: - self._raise_allowed_methods() - return await method(**self.request.match_info) + def __await__(self) -> Generator[Response]: + method = self.request.method.upper() + + if method not in METHODS: + raise HTTPMethodNotAllowed(method, self.allowed_methods) + + if not (handler := self.handlers.get(method)): + raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None + + return handler(self.request, **self.request.match_info).__await__() + + + @cached_property + def allowed_methods(self) -> tuple[str]: + return tuple(self.handlers.keys()) + + + @cached_property + def handlers(self) -> dict[str, Coroutine]: + data = {} + + for method in METHODS: + try: + data[method] = getattr(self, method.lower()) + + except AttributeError: + continue + + return data + + + # app components + @property + def app(self) -> Application: + return self.request.app @property - def app(self): - return self._request.app + def client(self) -> HttpClient: + return self.app.client @property - def config(self): + def config(self) -> RelayConfig: return self.app.config @property - def database(self): + def database(self) -> RelayDatabase: return self.app.database + + + # todo: move to views.ActorView + async def get_post_data(self) -> Response | None: + try: + self.signature = Signature.new_from_signature(self.request.headers['signature']) + + except KeyError: + logging.verbose('Missing signature header') + return Response.new_error(400, 'missing signature header', 'json') + + try: + self.message = await self.request.json(loads = Message.parse) + + except Exception: + traceback.print_exc() + logging.verbose('Failed to parse inbox message') + return Response.new_error(400, 'failed to parse message', 'json') + + if self.message is None: + logging.verbose('empty message') + return Response.new_error(400, 'missing message', 'json') + + if 'actor' not in self.message: + logging.verbose('actor not in message') + return Response.new_error(400, 'no actor in message', 'json') + + self.actor = await self.client.get(self.signature.keyid, sign_headers = True) + + if self.actor is None: + # ld signatures aren't handled atm, so just ignore it + if self.message.type == 'Delete': + logging.verbose('Instance sent a delete which cannot be handled') + return Response.new(status=202) + + logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') + return Response.new_error(400, 'failed to fetch actor', 'json') + + try: + self.signer = self.actor.signer + + except KeyError: + logging.verbose('Actor missing public key: %s', self.signature.keyid) + return Response.new_error(400, 'actor missing public key', 'json') + + try: + self.validate_signature(await self.request.read()) + + except SignatureFailureError as e: + logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) + return Response.new_error(401, str(e), 'json') + + self.instance = self.database.get_inbox(self.actor.inbox) + + + def validate_signature(self, body: bytes) -> None: + headers = {key.lower(): value for key, value in self.request.headers.items()} + headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path]) + + if (digest := Digest.new_from_digest(headers.get("digest"))): + if not body: + raise SignatureFailureError("Missing body for digest verification") + + if not digest.validate(body): + raise SignatureFailureError("Body digest does not match") + + if self.signature.algorithm_type == "hs2019": + if "(created)" not in self.signature.headers: + raise SignatureFailureError("'(created)' header not used") + + current_timestamp = HttpDate.new_utc().timestamp() + + if self.signature.created > current_timestamp: + raise SignatureFailureError("Creation date after current date") + + if current_timestamp > self.signature.expires: + raise SignatureFailureError("Expiration date before current date") + + headers["(created)"] = self.signature.created + headers["(expires)"] = self.signature.expires + + # pylint: disable=protected-access + if not self.signer._validate_signature(headers, self.signature): + raise SignatureFailureError("Signature does not match") diff --git a/relay/processors.py b/relay/processors.py index 1dca6c6..b9b32bc 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -1,17 +1,21 @@ -import asyncio -import logging +from __future__ import annotations + +import typing from cachetools import LRUCache -from uuid import uuid4 +from . import logger as logging from .misc import Message +if typing.TYPE_CHECKING: + from .misc import View + cache = LRUCache(1024) -def person_check(actor, software): - ## pleroma and akkoma may use Person for the actor type for some reason +def person_check(actor: str, software: str) -> bool: + # pleroma and akkoma may use Person for the actor type for some reason if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': return False @@ -19,86 +23,85 @@ def person_check(actor, software): if actor.type != 'Application': return True + return False -async def handle_relay(request): - if request.message.objectid in cache: - logging.verbose(f'already relayed {request.message.objectid}') + +async def handle_relay(view: View) -> None: + if view.message.object_id in cache: + logging.verbose('already relayed %s', view.message.object_id) return - message = Message.new_announce( - host = request.config.host, - object = request.message.objectid - ) + message = Message.new_announce(view.config.host, view.message.object_id) + cache[view.message.object_id] = message.id + logging.debug('>> relay: %s', message) - cache[request.message.objectid] = message.id - logging.debug(f'>> relay: {message}') - - inboxes = request.database.distill_inboxes(request.message) + inboxes = view.database.distill_inboxes(view.message) for inbox in inboxes: - request.app.push_message(inbox, message) + view.app.push_message(inbox, message) -async def handle_forward(request): - if request.message.id in cache: - logging.verbose(f'already forwarded {request.message.id}') +async def handle_forward(view: View) -> None: + if view.message.id in cache: + logging.verbose('already forwarded %s', view.message.id) return - message = Message.new_announce( - host = request.config.host, - object = request.message - ) + message = Message.new_announce(view.config.host, view.message) + cache[view.message.id] = message.id + logging.debug('>> forward: %s', message) - cache[request.message.id] = message.id - logging.debug(f'>> forward: {message}') - - inboxes = request.database.distill_inboxes(request.message) + inboxes = view.database.distill_inboxes(view.message) for inbox in inboxes: - request.app.push_message(inbox, message) + view.app.push_message(inbox, message) -async def handle_follow(request): - nodeinfo = await request.app.client.fetch_nodeinfo(request.actor.domain) +async def handle_follow(view: View) -> None: + nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) software = nodeinfo.sw_name if nodeinfo else None ## reject if software used by actor is banned - if request.config.is_banned_software(software): - request.app.push_message( - request.actor.shared_inbox, + if view.config.is_banned_software(software): + view.app.push_message( + view.actor.shared_inbox, Message.new_response( - host = request.config.host, - actor = request.actor.id, - followid = request.message.id, + host = view.config.host, + actor = view.actor.id, + followid = view.message.id, accept = False ) ) - return logging.verbose(f'Rejected follow from actor for using specific software: actor={request.actor.id}, software={software}') + return logging.verbose( + 'Rejected follow from actor for using specific software: actor=%s, software=%s', + view.actor.id, + software + ) ## reject if the actor is not an instance actor - if person_check(request.actor, software): - request.app.push_message( - request.actor.shared_inbox, + if person_check(view.actor, software): + view.app.push_message( + view.actor.shared_inbox, Message.new_response( - host = request.config.host, - actor = request.actor.id, - followid = request.message.id, + host = view.config.host, + actor = view.actor.id, + followid = view.message.id, accept = False ) ) - return logging.verbose(f'Non-application actor tried to follow: {request.actor.id}') + logging.verbose('Non-application actor tried to follow: %s', view.actor.id) + return - request.database.add_inbox(request.actor.shared_inbox, request.message.id, software) - request.database.save() + view.database.add_inbox(view.actor.shared_inbox, view.message.id, software) + view.database.save() - request.app.push_message( - request.actor.shared_inbox, + view.app.push_message( + view.actor.shared_inbox, Message.new_response( - host = request.config.host, - actor = request.actor.id, - followid = request.message.id, + host = view.config.host, + actor = view.actor.id, + followid = view.message.id, accept = True ) ) @@ -106,31 +109,37 @@ async def handle_follow(request): # Are Akkoma and Pleroma the only two that expect a follow back? # Ignoring only Mastodon for now if software != 'mastodon': - request.app.push_message( - request.actor.shared_inbox, + view.app.push_message( + view.actor.shared_inbox, Message.new_follow( - host = request.config.host, - actor = request.actor.id + host = view.config.host, + actor = view.actor.id ) ) -async def handle_undo(request): +async def handle_undo(view: View) -> None: ## If the object is not a Follow, forward it - if request.message.object.type != 'Follow': - return await handle_forward(request) + if view.message.object['type'] != 'Follow': + return await handle_forward(view) + + if not view.database.del_inbox(view.actor.domain, view.message.object['id']): + logging.verbose( + 'Failed to delete "%s" with follow ID "%s"', + view.actor.id, + view.message.object['id'] + ) - if not request.database.del_inbox(request.actor.domain, request.message.id): return - request.database.save() + view.database.save() - request.app.push_message( - request.actor.shared_inbox, + view.app.push_message( + view.actor.shared_inbox, Message.new_unfollow( - host = request.config.host, - actor = request.actor.id, - follow = request.message + host = view.config.host, + actor = view.actor.id, + follow = view.message ) ) @@ -145,16 +154,22 @@ processors = { } -async def run_processor(request): - if request.message.type not in processors: +async def run_processor(view: View) -> None: + if view.message.type not in processors: + logging.verbose( + 'Message type "%s" from actor cannot be handled: %s', + view.message.type, + view.actor.id + ) + return - if request.instance and not request.instance.get('software'): - nodeinfo = await request.app.client.fetch_nodeinfo(request.instance['domain']) + if view.instance and not view.instance.get('software'): + nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain']) if nodeinfo: - request.instance['software'] = nodeinfo.sw_name - request.database.save() + view.instance['software'] = nodeinfo.sw_name + view.database.save() - logging.verbose(f'New "{request.message.type}" from actor: {request.actor.id}') - return await processors[request.message.type](request) + logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) + await processors[view.message.type](view) diff --git a/relay/views.py b/relay/views.py index 9cea1ef..e1bed64 100644 --- a/relay/views.py +++ b/relay/views.py @@ -1,191 +1,170 @@ -import aputils -import asyncio -import logging -import subprocess -import traceback +from __future__ import annotations +import asyncio +import subprocess +import typing + +from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo from pathlib import Path -from . import __version__, misc -from .misc import DotDict, Message, Response +from . import __version__ +from . import logger as logging +from .misc import Message, Response, View from .processors import run_processor +if typing.TYPE_CHECKING: + from aiohttp.web import Request + from typing import Callable -routes = [] -version = __version__ + +VIEWS = [] +VERSION = __version__ +HOME_TEMPLATE = """ +
+This is an Activity Relay for fediverse instances.
+{note}
++ You may subscribe to this relay with the address: + https://{host}/actor +
++ To host your own relay, you may download the code at this address: + + https://git.pleroma.social/pleroma/relay + +
+List of {count} registered instances:
{targets}
This is an Activity Relay for fediverse instances.
-{note}
-You may subscribe to this relay with the address: https://{host}/actor
-To host your own relay, you may download the code at this address: https://git.pleroma.social/pleroma/relay
-List of {count} registered instances:
{targets}