From 9bf45a54d1cdc17191aeedb22092ffeca715eab1 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Sun, 14 Jan 2024 14:13:06 -0500 Subject: [PATCH] add annotations and fix linter warnings --- relay/__init__.py | 2 - relay/application.py | 78 +++++++++++---------- relay/config.py | 69 +++++++++---------- relay/database.py | 60 +++++++++++------ relay/http_client.py | 108 +++++++++++++++-------------- relay/logger.py | 20 ++++-- relay/manage.py | 157 ++++++++++++++++++++++++++----------------- relay/misc.py | 137 ++++++++++++++++++++----------------- relay/processors.py | 28 +++----- relay/views.py | 36 ++++++---- 10 files changed, 391 insertions(+), 304 deletions(-) diff --git a/relay/__init__.py b/relay/__init__.py index 426b03e..a6587ae 100644 --- a/relay/__init__.py +++ b/relay/__init__.py @@ -1,3 +1 @@ __version__ = '0.2.4' - -from . import logger diff --git a/relay/application.py b/relay/application.py index be9b136..a01aaec 100644 --- a/relay/application.py +++ b/relay/application.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import asyncio -import os import queue import signal import threading import traceback +import typing from aiohttp import web from datetime import datetime, timedelta @@ -12,20 +14,29 @@ from . import logger as logging from .config import RelayConfig from .database import RelayDatabase from .http_client import HttpClient -from .misc import DotDict, check_open_port, set_app +from .misc import check_open_port from .views import VIEWS +if typing.TYPE_CHECKING: + from typing import Any + from .misc import Message + + +# pylint: disable=unsubscriptable-object + class Application(web.Application): - def __init__(self, cfgpath): + def __init__(self, cfgpath: str): web.Application.__init__(self) - self['starttime'] = None + self['workers'] = [] + self['last_worker'] = 0 + self['start_time'] = None self['running'] = False self['config'] = RelayConfig(cfgpath) - if not self['config'].load(): - self['config'].save() + if not self.config.load(): + self.config.save() if self.config.is_docker: self.config.update({ @@ -34,13 +45,8 @@ class Application(web.Application): 'port': 8080 }) - self['workers'] = [] - self['last_worker'] = 0 - - set_app(self) - - self['database'] = RelayDatabase(self['config']) - self['database'].load() + self['database'] = RelayDatabase(self.config) + self.database.load() self['client'] = HttpClient( database = self.database, @@ -54,33 +60,34 @@ class Application(web.Application): @property - def client(self): + def client(self) -> HttpClient: return self['client'] @property - def config(self): + def config(self) -> RelayConfig: return self['config'] @property - def database(self): + def database(self) -> RelayDatabase: return self['database'] @property - def uptime(self): - if not self['starttime']: + def uptime(self) -> timedelta: + if not self['start_time']: return timedelta(seconds=0) - uptime = datetime.now() - self['starttime'] + uptime = datetime.now() - self['start_time'] return timedelta(seconds=uptime.seconds) - def push_message(self, inbox, message): + def push_message(self, inbox: str, message: Message) -> None: if self.config.workers <= 0: - return asyncio.ensure_future(self.client.post(inbox, message)) + asyncio.ensure_future(self.client.post(inbox, message)) + return worker = self['workers'][self['last_worker']] worker.queue.put((inbox, message)) @@ -91,8 +98,8 @@ class Application(web.Application): self['last_worker'] = 0 - def set_signal_handler(self, startup): - for sig in {'SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'}: + def set_signal_handler(self, startup: bool) -> None: + for sig in ('SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'): try: signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL) @@ -101,9 +108,10 @@ class Application(web.Application): pass - def run(self): + def run(self) -> None: if not check_open_port(self.config.listen, self.config.port): - return logging.error('A server is already running on port %i', self.config.port) + logging.error('A server is already running on port %i', self.config.port) + return for view in VIEWS: self.router.add_view(*view) @@ -118,17 +126,17 @@ class Application(web.Application): asyncio.run(self.handle_run()) - def stop(self, *_): + def stop(self, *_: Any) -> None: self['running'] = False - async def handle_run(self): + async def handle_run(self) -> None: self['running'] = True self.set_signal_handler(True) if self.config.workers > 0: - for i in range(self.config.workers): + for _ in range(self.config.workers): worker = PushWorker(self) worker.start() @@ -137,14 +145,15 @@ class Application(web.Application): runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() - site = web.TCPSite(runner, + site = web.TCPSite( + runner, host = self.config.listen, port = self.config.port, reuse_address = True ) await site.start() - self['starttime'] = datetime.now() + self['start_time'] = datetime.now() while self['running']: await asyncio.sleep(0.25) @@ -152,23 +161,24 @@ class Application(web.Application): await site.stop() await self.client.close() - self['starttime'] = None + self['start_time'] = None self['running'] = False self['workers'].clear() class PushWorker(threading.Thread): - def __init__(self, app): + def __init__(self, app: Application): threading.Thread.__init__(self) self.app = app self.queue = queue.Queue() + self.client = None - def run(self): + def run(self) -> None: asyncio.run(self.handle_queue()) - async def handle_queue(self): + async def handle_queue(self) -> None: self.client = HttpClient( database = self.app.database, limit = self.app.config.push_limit, diff --git a/relay/config.py b/relay/config.py index 996fa9f..e684ead 100644 --- a/relay/config.py +++ b/relay/config.py @@ -1,5 +1,7 @@ -import json +from __future__ import annotations + import os +import typing import yaml from functools import cached_property @@ -8,6 +10,10 @@ from urllib.parse import urlparse from .misc import DotDict, boolean +if typing.TYPE_CHECKING: + from typing import Any + from .database import RelayDatabase + RELAY_SOFTWARE = [ 'activityrelay', # https://git.pleroma.social/pleroma/relay @@ -25,17 +31,19 @@ APKEYS = [ class RelayConfig(DotDict): - def __init__(self, path): + __slots__ = ('path', ) + + def __init__(self, path: str | Path): DotDict.__init__(self, {}) if self.is_docker: path = '/data/config.yaml' - self._path = Path(path).expanduser() + self._path = Path(path).expanduser().resolve() self.reset() - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: if key in ['blocked_instances', 'blocked_software', 'whitelist']: assert isinstance(value, (list, set, tuple)) @@ -51,36 +59,31 @@ class RelayConfig(DotDict): @property - def db(self): + def db(self) -> RelayDatabase: return Path(self['db']).expanduser().resolve() @property - def path(self): - return self._path - - - @property - def actor(self): + def actor(self) -> str: return f'https://{self.host}/actor' @property - def inbox(self): + def inbox(self) -> str: return f'https://{self.host}/inbox' @property - def keyid(self): + def keyid(self) -> str: return f'{self.actor}#main-key' @cached_property - def is_docker(self): + def is_docker(self) -> bool: return bool(os.environ.get('DOCKER_RUNNING')) - def reset(self): + def reset(self) -> None: self.clear() self.update({ 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')), @@ -99,7 +102,7 @@ class RelayConfig(DotDict): }) - def ban_instance(self, instance): + def ban_instance(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname @@ -110,7 +113,7 @@ class RelayConfig(DotDict): return True - def unban_instance(self, instance): + def unban_instance(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname @@ -118,11 +121,11 @@ class RelayConfig(DotDict): self.blocked_instances.remove(instance) return True - except: + except ValueError: return False - def ban_software(self, software): + def ban_software(self, software: str) -> bool: if self.is_banned_software(software): return False @@ -130,16 +133,16 @@ class RelayConfig(DotDict): return True - def unban_software(self, software): + def unban_software(self, software: str) -> bool: try: self.blocked_software.remove(software) return True - except: + except ValueError: return False - def add_whitelist(self, instance): + def add_whitelist(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname @@ -150,7 +153,7 @@ class RelayConfig(DotDict): return True - def del_whitelist(self, instance): + def del_whitelist(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname @@ -158,32 +161,32 @@ class RelayConfig(DotDict): self.whitelist.remove(instance) return True - except: + except ValueError: return False - def is_banned(self, instance): + def is_banned(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname return instance in self.blocked_instances - def is_banned_software(self, software): + def is_banned_software(self, software: str) -> bool: if not software: return False return software.lower() in self.blocked_software - def is_whitelisted(self, instance): + def is_whitelisted(self, instance: str) -> bool: if instance.startswith('http'): instance = urlparse(instance).hostname return instance in self.whitelist - def load(self): + def load(self) -> bool: self.reset() options = {} @@ -195,7 +198,7 @@ class RelayConfig(DotDict): pass try: - with open(self.path) as fd: + with self._path.open('r', encoding = 'UTF-8') as fd: config = yaml.load(fd, **options) except FileNotFoundError: @@ -214,7 +217,7 @@ class RelayConfig(DotDict): continue - elif key not in self: + if key not in self: continue self[key] = value @@ -225,7 +228,7 @@ class RelayConfig(DotDict): return True - def save(self): + def save(self) -> None: config = { # just turning config.db into a string is good enough for now 'db': str(self.db), @@ -239,7 +242,5 @@ class RelayConfig(DotDict): 'ap': {key: self[key] for key in APKEYS} } - with open(self._path, 'w') as fd: + with self._path.open('w', encoding = 'utf-8') as fd: yaml.dump(config, fd, sort_keys=False) - - return config diff --git a/relay/database.py b/relay/database.py index b2c8423..d6c1acc 100644 --- a/relay/database.py +++ b/relay/database.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import aputils -import asyncio import json -import traceback +import typing from urllib.parse import urlparse from . import logger as logging +if typing.TYPE_CHECKING: + from typing import Iterator, Optional + from .config import RelayConfig + from .misc import Message + class RelayDatabase(dict): - def __init__(self, config): + def __init__(self, config: RelayConfig): dict.__init__(self, { 'relay-list': {}, 'private-key': None, @@ -22,16 +28,16 @@ class RelayDatabase(dict): @property - def hostnames(self): + def hostnames(self) -> tuple[str]: return tuple(self['relay-list'].keys()) @property - def inboxes(self): + def inboxes(self) -> tuple[dict[str, str]]: return tuple(data['inbox'] for data in self['relay-list'].values()) - def load(self): + def load(self) -> bool: new_db = True try: @@ -41,7 +47,7 @@ class RelayDatabase(dict): self['version'] = data.get('version', None) self['private-key'] = data.get('private-key') - if self['version'] == None: + if self['version'] is None: self['version'] = 1 if 'actorKeys' in data: @@ -59,7 +65,9 @@ class RelayDatabase(dict): self['relay-list'] = data.get('relay-list', {}) for domain, instance in self['relay-list'].items(): - if self.config.is_banned(domain) or (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)): + if self.config.is_banned(domain) or \ + (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)): + self.del_inbox(domain) continue @@ -87,25 +95,29 @@ class RelayDatabase(dict): return not new_db - def save(self): - with self.config.db.open('w') as fd: + def save(self) -> None: + with self.config.db.open('w', encoding = 'UTF-8') as fd: json.dump(self, fd, indent=4) - def get_inbox(self, domain, fail=False): + def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None: if domain.startswith('http'): domain = urlparse(domain).hostname - inbox = self['relay-list'].get(domain) - - if inbox: + if (inbox := self['relay-list'].get(domain)): return inbox if fail: raise KeyError(domain) + return None + + + def add_inbox(self, + inbox: str, + followid: Optional[str] = None, + software: Optional[str] = None) -> dict[str, str]: - def add_inbox(self, inbox, followid=None, software=None): assert inbox.startswith('https'), 'Inbox must be a url' domain = urlparse(inbox).hostname instance = self.get_inbox(domain) @@ -130,7 +142,11 @@ class RelayDatabase(dict): return self['relay-list'][domain] - def del_inbox(self, domain, followid=None, fail=False): + def del_inbox(self, + domain: str, + followid: Optional[str] = None, + fail: Optional[bool] = False) -> bool: + data = self.get_inbox(domain, fail=False) if not data: @@ -151,7 +167,7 @@ class RelayDatabase(dict): return False - def get_request(self, domain, fail=True): + def get_request(self, domain: str, fail: bool = True) -> dict[str, str] | None: if domain.startswith('http'): domain = urlparse(domain).hostname @@ -162,8 +178,10 @@ class RelayDatabase(dict): if fail: raise e + return None - def add_request(self, actor, inbox, followid): + + def add_request(self, actor: str, inbox: str, followid: str) -> None: domain = urlparse(inbox).hostname try: @@ -180,14 +198,14 @@ class RelayDatabase(dict): } - def del_request(self, domain): + def del_request(self, domain: str) -> None: if domain.startswith('http'): - domain = urlparse(inbox).hostname + domain = urlparse(domain).hostname del self['follow-requests'][domain] - def distill_inboxes(self, message): + def distill_inboxes(self, message: Message) -> Iterator[str]: src_domains = { message.domain, urlparse(message.objectid).netloc diff --git a/relay/http_client.py b/relay/http_client.py index 732c9e4..6f2a044 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -1,21 +1,23 @@ +from __future__ import annotations + import traceback +import typing from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from aputils import Nodeinfo, WellKnownNodeinfo -from datetime import datetime +from aputils.objects import Nodeinfo, WellKnownNodeinfo from cachetools import LRUCache from json.decoder import JSONDecodeError from urllib.parse import urlparse from . import __version__ from . import logger as logging -from .misc import ( - MIMETYPES, - DotDict, - Message -) +from .misc import MIMETYPES, Message + +if typing.TYPE_CHECKING: + from typing import Any, Callable, Optional + from .database import RelayDatabase HEADERS = { @@ -24,40 +26,31 @@ HEADERS = { } -class Cache(LRUCache): - def set_maxsize(self, value): - self.__maxsize = int(value) - - class HttpClient: - def __init__(self, database, limit=100, timeout=10, cache_size=1024): + def __init__(self, + database: RelayDatabase, + limit: Optional[int] = 100, + timeout: Optional[int] = 10, + cache_size: Optional[int] = 1024): + self.database = database - self.cache = Cache(cache_size) - self.cfg = {'limit': limit, 'timeout': timeout} + self.cache = LRUCache(cache_size) + self.limit = limit + self.timeout = timeout self._conn = None self._session = None - async def __aenter__(self): + async def __aenter__(self) -> HttpClient: await self.open() return self - async def __aexit__(self, *_): + async def __aexit__(self, *_: Any) -> None: await self.close() - @property - def limit(self): - return self.cfg['limit'] - - - @property - def timeout(self): - return self.cfg['timeout'] - - - async def open(self): + async def open(self) -> None: if self._session: return @@ -74,7 +67,7 @@ class HttpClient: ) - async def close(self): + async def close(self) -> None: if not self._session: return @@ -85,11 +78,19 @@ class HttpClient: self._session = None - async def get(self, url, sign_headers=False, loads=None, force=False): + async def get(self, # pylint: disable=too-many-branches + url: str, + sign_headers: Optional[bool] = False, + loads: Optional[Callable] = None, + force: Optional[bool] = False) -> Message | dict | None: + await self.open() - try: url, _ = url.split('#', 1) - except: pass + try: + url, _ = url.split('#', 1) + + except ValueError: + pass if not force and url in self.cache: return self.cache[url] @@ -105,26 +106,26 @@ class HttpClient: async with self._session.get(url, headers=headers) as resp: ## Not expecting a response with 202s, so just return if resp.status == 202: - return + return None - elif resp.status != 200: + if resp.status != 200: logging.verbose('Received error when requesting %s: %i', url, resp.status) logging.debug(await resp.read()) - return + return None if loads: message = await resp.json(loads=loads) elif resp.content_type == MIMETYPES['activity']: - message = await resp.json(loads=Message.parse) + message = await resp.json(loads = Message.parse) elif resp.content_type == MIMETYPES['json']: - message = await resp.json(loads=DotDict.parse) + message = await resp.json() else: - # todo: raise TypeError or something logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type) - return logging.debug('Response: %s', await resp.read()) + logging.debug('Response: %s', await resp.read()) + return None logging.debug('%s >> resp %s', url, message.to_json(4)) @@ -140,11 +141,13 @@ class HttpClient: except (AsyncTimeoutError, ClientConnectionError): logging.verbose('Failed to connect to %s', urlparse(url).netloc) - except Exception as e: + except Exception: traceback.print_exc() + return None - async def post(self, url, message): + + async def post(self, url: str, message: Message) -> None: await self.open() instance = self.database.get_inbox(url) @@ -165,10 +168,12 @@ class HttpClient: async with self._session.post(url, headers=headers, data=message.to_json()) as resp: ## Not expecting a response, so just return if resp.status in {200, 202}: - return logging.verbose('Successfully sent "%s" to %s', message.type, url) + logging.verbose('Successfully sent "%s" to %s', message.type, url) + return logging.verbose('Received error when pushing to %s: %i', url, resp.status) - return logging.verbose(await resp.read()) # change this to debug + logging.debug(await resp.read()) + return except ClientSSLError: logging.warning('SSL error when pushing to %s', urlparse(url).netloc) @@ -177,12 +182,11 @@ class HttpClient: logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) ## prevent workers from being brought down - except Exception as e: + except Exception: traceback.print_exc() - ## Additional methods ## - async def fetch_nodeinfo(self, domain): + async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None: nodeinfo_url = None wk_nodeinfo = await self.get( f'https://{domain}/.well-known/nodeinfo', @@ -191,7 +195,7 @@ class HttpClient: if not wk_nodeinfo: logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) - return False + return None for version in ['20', '21']: try: @@ -202,21 +206,21 @@ class HttpClient: if not nodeinfo_url: logging.verbose('Failed to fetch nodeinfo url for %s', domain) - return False + return None - return await self.get(nodeinfo_url, loads=Nodeinfo.parse) or False + return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None -async def get(database, *args, **kwargs): +async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None: async with HttpClient(database) as client: return await client.get(*args, **kwargs) -async def post(database, *args, **kwargs): +async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None: async with HttpClient(database) as client: return await client.post(*args, **kwargs) -async def fetch_nodeinfo(database, *args, **kwargs): +async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None: async with HttpClient(database) as client: return await client.fetch_nodeinfo(*args, **kwargs) diff --git a/relay/logger.py b/relay/logger.py index 9218af9..0d1d451 100644 --- a/relay/logger.py +++ b/relay/logger.py @@ -1,10 +1,16 @@ +from __future__ import annotations + import logging import os +import typing from pathlib import Path +if typing.TYPE_CHECKING: + from typing import Any, Callable -LOG_LEVELS = { + +LOG_LEVELS: dict[str, int] = { 'DEBUG': logging.DEBUG, 'VERBOSE': 15, 'INFO': logging.INFO, @@ -14,14 +20,14 @@ LOG_LEVELS = { } -debug = logging.debug -info = logging.info -warning = logging.warning -error = logging.error -critical = logging.critical +debug: Callable = logging.debug +info: Callable = logging.info +warning: Callable = logging.warning +error: Callable = logging.error +critical: Callable = logging.critical -def verbose(message, *args, **kwargs): +def verbose(message: str, *args: Any, **kwargs: Any) -> None: if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']): return diff --git a/relay/manage.py b/relay/manage.py index 5f36a79..b0c5cb3 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import Crypto import asyncio import click import platform +import typing from urllib.parse import urlparse @@ -10,6 +13,12 @@ from . import http_client as http from .application import Application from .config import RELAY_SOFTWARE +if typing.TYPE_CHECKING: + from typing import Any + + +# pylint: disable=unsubscriptable-object,unsupported-assignment-operation + app = None CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} @@ -19,7 +28,7 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} @click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config') @click.version_option(version=__version__, prog_name='ActivityRelay') @click.pass_context -def cli(ctx, config): +def cli(ctx: click.Context, config: str) -> None: global app app = Application(config) @@ -32,11 +41,14 @@ def cli(ctx, config): @cli.command('setup') -def cli_setup(): +def cli_setup() -> None: 'Generate a new config' while True: - app.config.host = click.prompt('What domain will the relay be hosted on?', default=app.config.host) + app.config.host = click.prompt( + 'What domain will the relay be hosted on?', + default = app.config.host + ) if not app.config.host.endswith('example.com'): break @@ -44,10 +56,18 @@ def cli_setup(): click.echo('The domain must not be example.com') if not app.config.is_docker: - app.config.listen = click.prompt('Which address should the relay listen on?', default=app.config.listen) + app.config.listen = click.prompt( + 'Which address should the relay listen on?', + default = app.config.listen + ) while True: - app.config.port = click.prompt('What TCP port should the relay listen on?', default=app.config.port, type=int) + app.config.port = click.prompt( + 'What TCP port should the relay listen on?', + default = app.config.port, + type = int + ) + break app.config.save() @@ -57,39 +77,47 @@ def cli_setup(): @cli.command('run') -def cli_run(): +def cli_run() -> None: 'Run the relay' if app.config.host.endswith('example.com'): - return click.echo('Relay is not set up. Please edit your relay config or run "activityrelay setup".') + click.echo( + 'Relay is not set up. Please edit your relay config or run "activityrelay setup".' + ) + + return vers_split = platform.python_version().split('.') pip_command = 'pip3 uninstall pycrypto && pip3 install pycryptodome' if Crypto.__version__ == '2.6.1': if int(vers_split[1]) > 7: - click.echo('Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome before running again. Exiting...') - return click.echo(pip_command) + click.echo( + 'Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome ' + + 'before running again. Exiting...' + ) - else: - click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome') - return click.echo(pip_command) + click.echo(pip_command) + return + + click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome') + click.echo(pip_command) + return if not misc.check_open_port(app.config.listen, app.config.port): - return click.echo(f'Error: A server is already running on port {app.config.port}') + click.echo(f'Error: A server is already running on port {app.config.port}') + return app.run() -# todo: add config default command for resetting config key @cli.group('config') -def cli_config(): +def cli_config() -> None: 'Manage the relay config' - pass @cli_config.command('list') -def cli_config_list(): +def cli_config_list() -> None: 'List the current relay config' click.echo('Relay Config:') @@ -103,7 +131,7 @@ def cli_config_list(): @cli_config.command('set') @click.argument('key') @click.argument('value') -def cli_config_set(key, value): +def cli_config_set(key: str, value: Any) -> None: 'Set a config value' app.config[key] = value @@ -113,13 +141,12 @@ def cli_config_set(key, value): @cli.group('inbox') -def cli_inbox(): +def cli_inbox() -> None: 'Manage the inboxes in the database' - pass @cli_inbox.command('list') -def cli_inbox_list(): +def cli_inbox_list() -> None: 'List the connected instances or relays' click.echo('Connected to the following instances or relays:') @@ -130,11 +157,12 @@ def cli_inbox_list(): @cli_inbox.command('follow') @click.argument('actor') -def cli_inbox_follow(actor): +def cli_inbox_follow(actor: str) -> None: 'Follow an actor (Relay must be running)' if app.config.is_banned(actor): - return click.echo(f'Error: Refusing to follow banned actor: {actor}') + click.echo(f'Error: Refusing to follow banned actor: {actor}') + return if not actor.startswith('http'): domain = actor @@ -151,7 +179,8 @@ def cli_inbox_follow(actor): actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True)) if not actor_data: - return click.echo(f'Failed to fetch actor: {actor}') + click.echo(f'Failed to fetch actor: {actor}') + return inbox = actor_data.shared_inbox @@ -166,7 +195,7 @@ def cli_inbox_follow(actor): @cli_inbox.command('unfollow') @click.argument('actor') -def cli_inbox_unfollow(actor): +def cli_inbox_unfollow(actor: str) -> None: 'Unfollow an actor (Relay must be running)' if not actor.startswith('http'): @@ -204,17 +233,19 @@ def cli_inbox_unfollow(actor): @cli_inbox.command('add') @click.argument('inbox') -def cli_inbox_add(inbox): +def cli_inbox_add(inbox: str) -> None: 'Add an inbox to the database' if not inbox.startswith('http'): inbox = f'https://{inbox}/inbox' if app.config.is_banned(inbox): - return click.echo(f'Error: Refusing to add banned inbox: {inbox}') + click.echo(f'Error: Refusing to add banned inbox: {inbox}') + return if app.database.get_inbox(inbox): - return click.echo(f'Error: Inbox already in database: {inbox}') + click.echo(f'Error: Inbox already in database: {inbox}') + return app.database.add_inbox(inbox) app.database.save() @@ -224,7 +255,7 @@ def cli_inbox_add(inbox): @cli_inbox.command('remove') @click.argument('inbox') -def cli_inbox_remove(inbox): +def cli_inbox_remove(inbox: str) -> None: 'Remove an inbox from the database' try: @@ -241,13 +272,12 @@ def cli_inbox_remove(inbox): @cli.group('instance') -def cli_instance(): +def cli_instance() -> None: 'Manage instance bans' - pass @cli_instance.command('list') -def cli_instance_list(): +def cli_instance_list() -> None: 'List all banned instances' click.echo('Banned instances or relays:') @@ -258,7 +288,7 @@ def cli_instance_list(): @cli_instance.command('ban') @click.argument('target') -def cli_instance_ban(target): +def cli_instance_ban(target: str) -> None: 'Ban an instance and remove the associated inbox if it exists' if target.startswith('http'): @@ -278,7 +308,7 @@ def cli_instance_ban(target): @cli_instance.command('unban') @click.argument('target') -def cli_instance_unban(target): +def cli_instance_unban(target: str) -> None: 'Unban an instance' if app.config.unban_instance(target): @@ -291,13 +321,12 @@ def cli_instance_unban(target): @cli.group('software') -def cli_software(): +def cli_software() -> None: 'Manage banned software' - pass @cli_software.command('list') -def cli_software_list(): +def cli_software_list() -> None: 'List all banned software' click.echo('Banned software:') @@ -307,19 +336,21 @@ def cli_software_list(): @cli_software.command('ban') -@click.option('--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default=False, - help='Treat NAME like a domain and try to fet the software name from nodeinfo' +@click.option( + '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False, + help = 'Treat NAME like a domain and try to fet the software name from nodeinfo' ) @click.argument('name') -def cli_software_ban(name, fetch_nodeinfo): +def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None: 'Ban software. Use RELAYS for NAME to ban relays' if name == 'RELAYS': - for name in RELAY_SOFTWARE: - app.config.ban_software(name) + for software in RELAY_SOFTWARE: + app.config.ban_software(software) app.config.save() - return click.echo('Banned all relay software') + click.echo('Banned all relay software') + return if fetch_nodeinfo: nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) @@ -331,25 +362,28 @@ def cli_software_ban(name, fetch_nodeinfo): if app.config.ban_software(name): app.config.save() - return click.echo(f'Banned software: {name}') + click.echo(f'Banned software: {name}') + return click.echo(f'Software already banned: {name}') @cli_software.command('unban') -@click.option('--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default=False, - help='Treat NAME like a domain and try to fet the software name from nodeinfo' +@click.option( + '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False, + help = 'Treat NAME like a domain and try to fet the software name from nodeinfo' ) @click.argument('name') -def cli_software_unban(name, fetch_nodeinfo): +def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None: 'Ban software. Use RELAYS for NAME to unban relays' if name == 'RELAYS': - for name in RELAY_SOFTWARE: - app.config.unban_software(name) + for software in RELAY_SOFTWARE: + app.config.unban_software(software) app.config.save() - return click.echo('Unbanned all relay software') + click.echo('Unbanned all relay software') + return if fetch_nodeinfo: nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) @@ -361,19 +395,19 @@ def cli_software_unban(name, fetch_nodeinfo): if app.config.unban_software(name): app.config.save() - return click.echo(f'Unbanned software: {name}') + click.echo(f'Unbanned software: {name}') + return click.echo(f'Software wasn\'t banned: {name}') @cli.group('whitelist') -def cli_whitelist(): +def cli_whitelist() -> None: 'Manage the instance whitelist' - pass @cli_whitelist.command('list') -def cli_whitelist_list(): +def cli_whitelist_list() -> None: 'List all the instances in the whitelist' click.echo('Current whitelisted domains') @@ -384,11 +418,12 @@ def cli_whitelist_list(): @cli_whitelist.command('add') @click.argument('instance') -def cli_whitelist_add(instance): +def cli_whitelist_add(instance: str) -> None: 'Add an instance to the whitelist' if not app.config.add_whitelist(instance): - return click.echo(f'Instance already in the whitelist: {instance}') + click.echo(f'Instance already in the whitelist: {instance}') + return app.config.save() click.echo(f'Instance added to the whitelist: {instance}') @@ -396,11 +431,12 @@ def cli_whitelist_add(instance): @cli_whitelist.command('remove') @click.argument('instance') -def cli_whitelist_remove(instance): +def cli_whitelist_remove(instance: str) -> None: 'Remove an instance from the whitelist' if not app.config.del_whitelist(instance): - return click.echo(f'Instance not in the whitelist: {instance}') + click.echo(f'Instance not in the whitelist: {instance}') + return app.config.save() @@ -412,14 +448,15 @@ def cli_whitelist_remove(instance): @cli_whitelist.command('import') -def cli_whitelist_import(): +def cli_whitelist_import() -> None: 'Add all current inboxes to the whitelist' for domain in app.database.hostnames: cli_whitelist_add.callback(domain) -def main(): +def main() -> None: + # pylint: disable=no-value-for-parameter cli(prog_name='relay') diff --git a/relay/misc.py b/relay/misc.py index c31514d..9325b40 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -12,20 +12,21 @@ from aiohttp.web_exceptions import HTTPMethodNotAllowed from aputils.errors import SignatureFailureError from aputils.misc import Digest, HttpDate, Signature from aputils.message import Message as ApMessage -from datetime import datetime from functools import cached_property from json.decoder import JSONDecodeError -from urllib.parse import urlparse from uuid import uuid4 from . import logger as logging if typing.TYPE_CHECKING: - from typing import Coroutine, Generator + from typing import Any, Coroutine, Generator, Optional, Type + from aputils.signer import Signer + from .application import Application + from .config import RelayConfig + from .database import RelayDatabase + from .http_client import HttpClient -app = None - MIMETYPES = { 'activity': 'application/activity+json', 'html': 'text/html', @@ -39,94 +40,87 @@ NODEINFO_NS = { } -def set_app(new_app): - global app - app = new_app - - -def boolean(value): +def boolean(value: Any) -> bool: if isinstance(value, str): if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']: return True - elif value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']: + if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']: return False - else: - raise TypeError(f'Cannot parse string "{value}" as a boolean') + raise TypeError(f'Cannot parse string "{value}" as a boolean') - elif isinstance(value, int): + if isinstance(value, int): if value == 1: return True - elif value == 0: + if value == 0: return False - else: - raise ValueError('Integer value must be 1 or 0') + raise ValueError('Integer value must be 1 or 0') - elif value == None: + if value is None: return False - try: - return value.__bool__() - - except AttributeError: - raise TypeError(f'Cannot convert object of type "{clsname(value)}"') + return bool(value) -def check_open_port(host, port): +def check_open_port(host: str, port: int) -> bool: if host == '0.0.0.0': host = '127.0.0.1' with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: - return s.connect_ex((host , port)) != 0 + return s.connect_ex((host, port)) != 0 - except socket.error as e: + except socket.error: return False class DotDict(dict): - def __init__(self, _data, **kwargs): + def __init__(self, _data: dict[str, Any], **kwargs: Any): dict.__init__(self) self.update(_data, **kwargs) - def __getattr__(self, k): + def __getattr__(self, key: str) -> str: try: - return self[k] + return self[key] except KeyError: - raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None + raise AttributeError( + f'{self.__class__.__name__} object has no attribute {key}' + ) from None - def __setattr__(self, k, v): - if k.startswith('_'): - super().__setattr__(k, v) + def __setattr__(self, key: str, value: Any) -> None: + if key.startswith('_'): + super().__setattr__(key, value) else: - self[k] = v + self[key] = value - def __setitem__(self, k, v): - if type(v) == dict: - v = DotDict(v) + def __setitem__(self, key: str, value: Any) -> None: + if type(value) is dict: # pylint: disable=unidiomatic-typecheck + value = DotDict(value) - super().__setitem__(k, v) + super().__setitem__(key, value) - def __delattr__(self, k): + def __delattr__(self, key: str) -> None: try: - dict.__delitem__(self, k) + dict.__delitem__(self, key) except KeyError: - raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None + raise AttributeError( + f'{self.__class__.__name__} object has no attribute {key}' + ) from None @classmethod - def new_from_json(cls, data): + def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]: if not data: raise JSONDecodeError('Empty body', data, 1) @@ -134,11 +128,11 @@ class DotDict(dict): return cls(json.loads(data)) except ValueError: - raise JSONDecodeError('Invalid body', data, 1) + raise JSONDecodeError('Invalid body', data, 1) from None @classmethod - def new_from_signature(cls, sig): + def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]: data = cls({}) for chunk in sig.strip().split(','): @@ -153,11 +147,11 @@ class DotDict(dict): return data - def to_json(self, indent=None): + def to_json(self, indent: Optional[int | str] = None) -> str: return json.dumps(self, indent=indent) - def update(self, _data, **kwargs): + def update(self, _data: dict[str, Any], **kwargs: Any) -> None: if isinstance(_data, dict): for key, value in _data.items(): self[key] = value @@ -172,7 +166,11 @@ class DotDict(dict): class Message(ApMessage): @classmethod - def new_actor(cls, host, pubkey, description=None): + def new_actor(cls: Type[Message], # pylint: disable=arguments-differ + host: str, + pubkey: str, + description: Optional[str] = None) -> Message: + return cls({ '@context': 'https://www.w3.org/ns/activitystreams', 'id': f'https://{host}/actor', @@ -196,19 +194,19 @@ class Message(ApMessage): @classmethod - def new_announce(cls, host, object): + def new_announce(cls: Type[Message], host: str, obj: str) -> Message: return cls({ '@context': 'https://www.w3.org/ns/activitystreams', 'id': f'https://{host}/activities/{uuid4()}', 'type': 'Announce', 'to': [f'https://{host}/followers'], 'actor': f'https://{host}/actor', - 'object': object + 'object': obj }) @classmethod - def new_follow(cls, host, actor): + def new_follow(cls: Type[Message], host: str, actor: str) -> Message: return cls({ '@context': 'https://www.w3.org/ns/activitystreams', 'type': 'Follow', @@ -220,7 +218,7 @@ class Message(ApMessage): @classmethod - def new_unfollow(cls, host, actor, follow): + def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message: return cls({ '@context': 'https://www.w3.org/ns/activitystreams', 'id': f'https://{host}/activities/{uuid4()}', @@ -232,7 +230,12 @@ class Message(ApMessage): @classmethod - def new_response(cls, host, actor, followid, accept): + def new_response(cls: Type[Message], + host: str, + actor: str, + followid: str, + accept: bool) -> Message: + return cls({ '@context': 'https://www.w3.org/ns/activitystreams', 'id': f'https://{host}/activities/{uuid4()}', @@ -250,7 +253,12 @@ class Message(ApMessage): class Response(AiohttpResponse): @classmethod - def new(cls, body='', status=200, headers=None, ctype='text'): + def new(cls: Type[Response], + body: Optional[str | bytes | dict] = '', + status: Optional[int] = 200, + headers: Optional[dict[str, str]] = None, + ctype: Optional[str] = 'text') -> Response: + kwargs = { 'status': status, 'headers': headers, @@ -270,7 +278,11 @@ class Response(AiohttpResponse): @classmethod - def new_error(cls, status, body, ctype='text'): + def new_error(cls: Type[Response], + status: int, + body: str | bytes | dict, + ctype: str = 'text') -> Response: + if ctype == 'json': body = json.dumps({'status': status, 'error': body}) @@ -278,12 +290,12 @@ class Response(AiohttpResponse): @property - def location(self): + def location(self) -> str: return self.headers.get('Location') @location.setter - def location(self, value): + def location(self, value: str) -> None: self.headers['Location'] = value @@ -295,6 +307,7 @@ class View(AbstractView): self.message: Message = None self.actor: Message = None self.instance: dict[str, str] = None + self.signer: Signer = None def __await__(self) -> Generator[Response]: @@ -335,7 +348,7 @@ class View(AbstractView): @property - def client(self) -> Client: + def client(self) -> HttpClient: return self.app.client @@ -377,9 +390,9 @@ class View(AbstractView): self.actor = await self.client.get(self.signature.keyid, sign_headers = True) if self.actor is None: - ## ld signatures aren't handled atm, so just ignore it + # ld signatures aren't handled atm, so just ignore it if self.message.type == 'Delete': - logging.verbose(f'Instance sent a delete which cannot be handled') + logging.verbose('Instance sent a delete which cannot be handled') return Response.new(status=202) logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') @@ -409,7 +422,7 @@ class View(AbstractView): if (digest := Digest.new_from_digest(headers.get("digest"))): if not body: raise SignatureFailureError("Missing body for digest verification") - + if not digest.validate(body): raise SignatureFailureError("Body digest does not match") @@ -429,5 +442,5 @@ class View(AbstractView): headers["(expires)"] = self.signature.expires # pylint: disable=protected-access - if not self.actor.signer._validate_signature(headers, self.signature): + if not self.signer._validate_signature(headers, self.signature): raise SignatureFailureError("Signature does not match") diff --git a/relay/processors.py b/relay/processors.py index d315fc0..5276ef6 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -1,10 +1,8 @@ from __future__ import annotations -import asyncio import typing from cachetools import LRUCache -from uuid import uuid4 from . import logger as logging from .misc import Message @@ -16,8 +14,8 @@ if typing.TYPE_CHECKING: cache = LRUCache(1024) -def person_check(actor, software): - ## pleroma and akkoma may use Person for the actor type for some reason +def person_check(actor: str, software: str) -> bool: + # pleroma and akkoma may use Person for the actor type for some reason if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': return False @@ -25,21 +23,19 @@ def person_check(actor, software): if actor.type != 'Application': return True + return False + async def handle_relay(view: View) -> None: if view.message.objectid in cache: logging.verbose('already relayed %s', view.message.objectid) return - message = Message.new_announce( - host = view.config.host, - object = view.message.objectid - ) - + message = Message.new_announce(view.config.host, view.message.objectid) cache[view.message.objectid] = message.id logging.debug('>> relay: %s', message) - inboxes = view.database.distill_inboxes(message) + inboxes = view.database.distill_inboxes(view.message) for inbox in inboxes: view.app.push_message(inbox, message) @@ -50,15 +46,11 @@ async def handle_forward(view: View) -> None: logging.verbose('already forwarded %s', view.message.id) return - message = Message.new_announce( - host = view.config.host, - object = view.message - ) - + message = Message.new_announce(view.config.host, view.message) cache[view.message.id] = message.id logging.debug('>> forward: %s', message) - inboxes = view.database.distill_inboxes(message.message) + inboxes = view.database.distill_inboxes(view.message) for inbox in inboxes: view.app.push_message(inbox, message) @@ -162,7 +154,7 @@ processors = { } -async def run_processor(view: View): +async def run_processor(view: View) -> None: if view.message.type not in processors: logging.verbose( 'Message type "%s" from actor cannot be handled: %s', @@ -180,4 +172,4 @@ async def run_processor(view: View): view.database.save() logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) - return await processors[view.message.type](view) + await processors[view.message.type](view) diff --git a/relay/views.py b/relay/views.py index e04688e..e1bed64 100644 --- a/relay/views.py +++ b/relay/views.py @@ -1,15 +1,13 @@ from __future__ import annotations -import aputils import asyncio import subprocess -import traceback import typing from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo from pathlib import Path -from . import __version__, misc +from . import __version__ from . import logger as logging from .misc import Message, Response, View from .processors import run_processor @@ -35,8 +33,16 @@ HOME_TEMPLATE = """

This is an Activity Relay for fediverse instances.

{note}

-

You may subscribe to this relay with the address: https://{host}/actor

-

To host your own relay, you may download the code at this address: https://git.pleroma.social/pleroma/relay

+

+ You may subscribe to this relay with the address: + https://{host}/actor +

+

+ To host your own relay, you may download the code at this address: + + https://git.pleroma.social/pleroma/relay + +


List of {count} registered instances:
{targets}

""" @@ -60,6 +66,8 @@ def register_route(*paths: str) -> Callable: return wrapper +# pylint: disable=unused-argument + @register_route('/') class HomeView(View): async def get(self, request: Request) -> Response: @@ -78,7 +86,7 @@ class HomeView(View): class ActorView(View): async def get(self, request: Request) -> Response: data = Message.new_actor( - host = self.config.host, + host = self.config.host, pubkey = self.database.signer.pubkey ) @@ -140,14 +148,14 @@ class WebfingerView(View): @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') class NodeinfoView(View): async def get(self, request: Request, niversion: str) -> Response: - data = dict( - name = 'activityrelay', - version = VERSION, - protocols = ['activitypub'], - open_regs = not self.config.whitelist_enabled, - users = 1, - metadata = {'peers': self.database.hostnames} - ) + data = { + 'name': 'activityrelay', + 'version': VERSION, + 'protocols': ['activitypub'], + 'open_regs': not self.config.whitelist_enabled, + 'users': 1, + 'metadata': {'peers': self.database.hostnames} + } if niversion == '2.1': data['repo'] = 'https://git.pleroma.social/pleroma/relay'