diff --git a/.gitignore b/.gitignore index 737b9a4..eeebd0a 100644 --- a/.gitignore +++ b/.gitignore @@ -98,3 +98,5 @@ ENV/ *.yaml *.jsonld *.sqlite3 + +test*.py diff --git a/dev-requirements.txt b/dev-requirements.txt index f0fb91f..aa8a793 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,6 @@ flake8 == 7.0.0 +mypy == 1.9.0 pyinstaller == 6.3.0 -pylint == 3.0 watchdog == 4.0.0 + +typing_extensions >= 4.10.0; python_version < '3.11.0' diff --git a/docs/configuration.md b/docs/configuration.md index 2fad0af..9cc8cdb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,15 +1,19 @@ # Configuration -## General +## Config File -### Domain +These options are stored in the configuration file (usually relay.yaml) + +### General + +#### Domain Hostname the relay will be hosted on. domain: relay.example.com -### Listener +#### Listener The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc) is running on the same host, it is recommended to change `listen` to `localhost` if the reverse @@ -19,7 +23,7 @@ proxy is on the same host. port: 8080 -### Push Workers +#### Push Workers The number of processes to spawn for pushing messages to subscribed instances. Leave it at 0 to automatically detect how many processes should be spawned. @@ -27,21 +31,21 @@ automatically detect how many processes should be spawned. workers: 0 -### Database type +#### Database type SQL database backend to use. Valid values are `sqlite` or `postgres`. database_type: sqlite -### Cache type +#### Cache type Cache backend to use. Valid values are `database` or `redis` cache_type: database -### Sqlite File Path +#### Sqlite File Path Path to the sqlite database file. If the path is not absolute, it is relative to the config file. directory. @@ -49,7 +53,7 @@ directory. sqlite_path: relay.jsonld -## Postgresql +### Postgresql In order to use the Postgresql backend, the user and database need to be created first. @@ -57,80 +61,130 @@ In order to use the Postgresql backend, the user and database need to be created sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay" -### Database Name +#### Database Name Name of the database to use. name: activityrelay -### Host +#### Host Hostname, IP address, or unix socket the server is hosted on. host: /var/run/postgresql -### Port +#### Port Port number the server is listening on. port: 5432 -### Username +#### Username User to use when logging into the server. user: null -### Password +#### Password Password for the specified user. pass: null -## Redis +### Redis -### Host +#### Host Hostname, IP address, or unix socket the server is hosted on. host: /var/run/postgresql -### Port +#### Port Port number the server is listening on. port: 5432 -### Username +#### Username User to use when logging into the server. user: null -### Password +#### Password Password for the specified user. pass: null -### Database Number +#### Database Number Number of the database to use. database: 0 -### Prefix +#### Prefix Text to prefix every key with. It cannot contain a `:` character. prefix: activityrelay + +## Database Config + +These options are stored in the database and can be changed via CLI, API, or the web interface. + +### Approval Required + +When enabled, instances that try to follow the relay will have to be manually approved by an admin. + + approval-required: false + + +### Log Level + +Maximum level of messages to log. + +Valid values: `DEBUG`, `VERBOSE`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` + + log-level: INFO + +### Name + +Name of your relay's instance. It will be displayed at the top of web pages and in API endpoints. + + name: ActivityRelay + + +### Note + +Short blurb that will be displayed on the relay's home and in API endpoints if set. Can be in +markdown format. + + note: null + + +### Theme + +Color theme to use for the web pages. + +Valid values: `Default`, `Pink`, `Blue` + + theme: Default + + +### Whitelist Enabled + +When enabled, only instances on the whitelist can join. Any instances currently subscribed and not +in the whitelist when this is enabled can still post. + + whitelist-enabled: False diff --git a/installation/relay.caddy b/installation/relay.caddy index 8cd9b28..94fdd53 100644 --- a/installation/relay.caddy +++ b/installation/relay.caddy @@ -1,6 +1,3 @@ -relay.example.org { - gzip - proxy / 127.0.0.1:8080 { - transparent - } +relay.example.com { + reverse_proxy / http://localhost:8080 } diff --git a/installation/relay.service b/installation/relay.service index 0325316..cf9fa6a 100644 --- a/installation/relay.service +++ b/installation/relay.service @@ -4,6 +4,7 @@ Description=ActivityPub Relay [Service] WorkingDirectory=/home/relay/relay ExecStart=/usr/bin/python3 -m relay run +Environment="IS_SYSTEMD=1" [Install] WantedBy=multi-user.target diff --git a/pyproject.toml b/pyproject.toml index d98eab8..b1bde52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,54 +3,13 @@ requires = ["setuptools","wheel"] build-backend = 'setuptools.build_meta' -[tool.pylint.main] -jobs = 0 -persistent = true -load-plugins = [ - "pylint.extensions.code_style", - "pylint.extensions.comparison_placement", - "pylint.extensions.confusing_elif", - "pylint.extensions.for_any_all", - "pylint.extensions.consider_ternary_expression", - "pylint.extensions.bad_builtin", - "pylint.extensions.dict_init_mutate", - "pylint.extensions.check_elif", - "pylint.extensions.empty_comment", - "pylint.extensions.private_import", - "pylint.extensions.redefined_variable_type", - "pylint.extensions.no_self_use", - "pylint.extensions.overlapping_exceptions", - "pylint.extensions.set_membership", - "pylint.extensions.typing" -] - - -[tool.pylint.design] -max-args = 10 -max-attributes = 100 - - -[tool.pylint.format] -indent-str = "\t" -indent-after-paren = 1 -max-line-length = 100 -single-line-if-stmt = true - - -[tool.pylint.messages_control] -disable = [ - "fixme", - "broad-exception-caught", - "cyclic-import", - "global-statement", - "invalid-name", - "missing-module-docstring", - "too-few-public-methods", - "too-many-public-methods", - "too-many-return-statements", - "wrong-import-order", - "missing-function-docstring", - "missing-class-docstring", - "consider-using-namedtuple-or-dataclass", - "confusing-consecutive-elif" -] +[tool.mypy] +show_traceback = true +install_types = true +pretty = true +disallow_untyped_decorators = true +warn_redundant_casts = true +warn_unreachable = true +warn_unused_ignores = true +ignore_missing_imports = true +follow_imports = "silent" diff --git a/relay/__init__.py b/relay/__init__.py index 0404d81..e1424ed 100644 --- a/relay/__init__.py +++ b/relay/__init__.py @@ -1 +1 @@ -__version__ = '0.3.0' +__version__ = '0.3.1' diff --git a/relay/application.py b/relay/application.py index dfa3861..628d9e5 100644 --- a/relay/application.py +++ b/relay/application.py @@ -8,9 +8,12 @@ import traceback import typing from aiohttp import web +from aiohttp.web import StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer from datetime import datetime, timedelta +from mimetypes import guess_type +from pathlib import Path from queue import Empty from threading import Event, Thread @@ -26,16 +29,31 @@ from .views.api import handle_api_path from .views.frontend import handle_frontend_path if typing.TYPE_CHECKING: - from collections.abc import Coroutine - from tinysql import Database, Row + from collections.abc import Callable + from bsql import Database, Row from .cache import Cache from .misc import Message, Response -# pylint: disable=unsubscriptable-object +def get_csp(request: web.Request) -> str: + data = [ + "default-src 'none'", + f"script-src 'nonce-{request['hash']}'", + f"style-src 'self' 'nonce-{request['hash']}'", + "form-action 'self'", + "connect-src 'self'", + "img-src 'self'", + "object-src 'none'", + "frame-ancestors 'none'", + f"manifest-src 'self' https://{request.app['config'].domain}" + ] + + return '; '.join(data) + ';' + class Application(web.Application): - DEFAULT: Application = None + DEFAULT: Application | None = None + def __init__(self, cfgpath: str | None, dev: bool = False): web.Application.__init__(self, @@ -48,7 +66,7 @@ class Application(web.Application): Application.DEFAULT = self - self['running'] = None + self['running'] = False self['signer'] = None self['start_time'] = None self['cleanup_thread'] = None @@ -64,14 +82,13 @@ class Application(web.Application): self['workers'] = [] self.cache.setup() - - # self.on_response_prepare.append(handle_access_log) - self.on_cleanup.append(handle_cleanup) + self.on_cleanup.append(handle_cleanup) # type: ignore for path, view in VIEWS: self.router.add_view(path, view) - setup_swagger(self, + setup_swagger( + self, ui_version = 3, swagger_from_file = get_resource('data/swagger.yaml') ) @@ -111,6 +128,11 @@ class Application(web.Application): self['signer'] = Signer(value, self.config.keyid) + @property + def template(self) -> Template: + return self['template'] + + @property def uptime(self) -> timedelta: if not self['start_time']: @@ -121,10 +143,20 @@ class Application(web.Application): return timedelta(seconds=uptime.seconds) - def push_message(self, inbox: str, message: Message, instance: Row) -> None: + def push_message(self, inbox: str, message: Message | bytes, instance: Row) -> None: self['push_queue'].put((inbox, message, instance)) + def register_static_routes(self) -> None: + if self['dev']: + static = StaticResource('/static', get_resource('frontend/static')) + + else: + static = CachedStaticResource('/static', get_resource('frontend/static')) + + self.router.register_resource(static) + + def run(self) -> None: if self["running"]: return @@ -137,6 +169,8 @@ class Application(web.Application): logging.error(f'A server is already running on {host}:{port}') return + self.register_static_routes() + logging.info(f'Starting webserver at {domain} ({host}:{port})') asyncio.run(self.handle_run()) @@ -160,6 +194,7 @@ class Application(web.Application): self.set_signal_handler(True) + self['client'].open() self['database'].connect() self['cache'].setup() self['cleanup_thread'] = CacheCleanupThread(self) @@ -174,7 +209,8 @@ class Application(web.Application): runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() - site = web.TCPSite(runner, + site = web.TCPSite( + runner, host = self.config.listen, port = self.config.port, reuse_address = True @@ -188,7 +224,7 @@ class Application(web.Application): await site.stop() - for worker in self['workers']: # pylint: disable=not-an-iterable + for worker in self['workers']: worker.stop() self.set_signal_handler(False) @@ -201,6 +237,39 @@ class Application(web.Application): self['cache'].close() +class CachedStaticResource(StaticResource): + def __init__(self, prefix: str, path: Path): + StaticResource.__init__(self, prefix, path) + + self.cache: dict[str, bytes] = {} + + for filename in path.rglob('*'): + if filename.is_dir(): + continue + + rel_path = str(filename.relative_to(path)) + + with filename.open('rb') as fd: + logging.debug('Loading static resource "%s"', rel_path) + self.cache[rel_path] = fd.read() + + + async def _handle(self, request: web.Request) -> web.StreamResponse: + rel_url = request.match_info['filename'] + + if Path(rel_url).anchor: + raise web.HTTPForbidden() + + try: + return web.Response( + body = self.cache[rel_url], + content_type = guess_type(rel_url)[0] + ) + + except KeyError: + raise web.HTTPNotFound() + + class CacheCleanupThread(Thread): def __init__(self, app: Application): Thread.__init__(self) @@ -242,16 +311,17 @@ class PushWorker(multiprocessing.Process): async def handle_queue(self) -> None: client = HttpClient() + client.open() while not self.shutdown.is_set(): try: - inbox, message, instance = self.queue.get(block=True, timeout=0.25) - await client.post(inbox, message, instance) + inbox, message, instance = self.queue.get(block=True, timeout=0.1) + asyncio.create_task(client.post(inbox, message, instance)) except Empty: - pass + await asyncio.sleep(0) - ## make sure an exception doesn't bring down the worker + # make sure an exception doesn't bring down the worker except Exception: traceback.print_exc() @@ -259,10 +329,14 @@ class PushWorker(multiprocessing.Process): @web.middleware -async def handle_response_headers(request: web.Request, handler: Coroutine) -> Response: +async def handle_response_headers(request: web.Request, handler: Callable) -> Response: resp = await handler(request) resp.headers['Server'] = 'ActivityRelay' + # Still have to figure out how csp headers work + if resp.content_type == 'text/html' and not request.path.startswith("/api"): + resp.headers['Content-Security-Policy'] = get_csp(request) + if not request.app['dev'] and request.path.endswith(('.css', '.js')): # cache for 2 weeks resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' diff --git a/relay/cache.py b/relay/cache.py index 5647106..9ea3d2b 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -13,15 +13,16 @@ from .database import get_database from .misc import Message, boolean if typing.TYPE_CHECKING: - from typing import Any + from blib import Database from collections.abc import Callable, Iterator + from typing import Any from .application import Application # todo: implement more caching backends -BACKENDS: dict[str, Cache] = {} +BACKENDS: dict[str, type[Cache]] = {} CONVERTERS: dict[str, tuple[Callable, Callable]] = { 'str': (str, str), 'int': (str, int), @@ -71,7 +72,7 @@ class Item: data.value = deserialize_value(data.value, data.value_type) if not isinstance(data.updated, datetime): - data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) + data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore return data @@ -143,7 +144,7 @@ class Cache(ABC): item.namespace, item.key, item.value, - item.type + item.value_type ) @@ -158,7 +159,7 @@ class SqlCache(Cache): def __init__(self, app: Application): Cache.__init__(self, app) - self._db = None + self._db: Database = None def get(self, namespace: str, key: str) -> Item: @@ -257,7 +258,7 @@ class RedisCache(Cache): def __init__(self, app: Application): Cache.__init__(self, app) - self._rd = None + self._rd: Redis = None # type: ignore @property @@ -275,7 +276,7 @@ class RedisCache(Cache): if not (raw_value := self._rd.get(key_name)): raise KeyError(f'{namespace}:{key}') - value_type, updated, value = raw_value.split(':', 2) + value_type, updated, value = raw_value.split(':', 2) # type: ignore return Item.from_data( namespace, key, @@ -302,7 +303,7 @@ class RedisCache(Cache): yield namespace - def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> None: + def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: date = datetime.now(tz = timezone.utc).timestamp() value = serialize_value(value, value_type) @@ -311,6 +312,8 @@ class RedisCache(Cache): f'{value_type}:{date}:{value}' ) + return self.get(namespace, key) + def delete(self, namespace: str, key: str) -> None: self._rd.delete(self.get_key_name(namespace, key)) @@ -350,7 +353,7 @@ class RedisCache(Cache): options['host'] = self.app.config.rd_host options['port'] = self.app.config.rd_port - self._rd = Redis(**options) + self._rd = Redis(**options) # type: ignore def close(self) -> None: @@ -358,4 +361,4 @@ class RedisCache(Cache): return self._rd.close() - self._rd = None + self._rd = None # type: ignore diff --git a/relay/compat.py b/relay/compat.py index cc19226..9884b25 100644 --- a/relay/compat.py +++ b/relay/compat.py @@ -9,16 +9,12 @@ from functools import cached_property from pathlib import Path from urllib.parse import urlparse -from . import logger as logging -from .misc import Message, boolean +from .misc import boolean if typing.TYPE_CHECKING: - from collections.abc import Iterator from typing import Any -# pylint: disable=duplicate-code - class RelayConfig(dict): def __init__(self, path: str): dict.__init__(self, {}) @@ -46,7 +42,7 @@ class RelayConfig(dict): @property - def db(self) -> RelayDatabase: + def db(self) -> Path: return Path(self['db']).expanduser().resolve() @@ -184,121 +180,3 @@ class RelayDatabase(dict): except json.decoder.JSONDecodeError as e: if self.config.db.stat().st_size > 0: raise e from None - - - def save(self) -> None: - with self.config.db.open('w', encoding = 'UTF-8') as fd: - json.dump(self, fd, indent=4) - - - def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None: - if domain.startswith('http'): - domain = urlparse(domain).hostname - - if (inbox := self['relay-list'].get(domain)): - return inbox - - if fail: - raise KeyError(domain) - - return None - - - def add_inbox(self, - inbox: str, - followid: str | None = None, - software: str | None = None) -> dict[str, str]: - - assert inbox.startswith('https'), 'Inbox must be a url' - domain = urlparse(inbox).hostname - - if (instance := self.get_inbox(domain)): - if followid: - instance['followid'] = followid - - if software: - instance['software'] = software - - return instance - - self['relay-list'][domain] = { - 'domain': domain, - 'inbox': inbox, - 'followid': followid, - 'software': software - } - - logging.verbose('Added inbox to database: %s', inbox) - return self['relay-list'][domain] - - - def del_inbox(self, - domain: str, - followid: str = None, - fail: bool = False) -> bool: - - if not (data := self.get_inbox(domain, fail=False)): - if fail: - raise KeyError(domain) - - return False - - if not data['followid'] or not followid or data['followid'] == followid: - del self['relay-list'][data['domain']] - logging.verbose('Removed inbox from database: %s', data['inbox']) - return True - - if fail: - raise ValueError('Follow IDs do not match') - - logging.debug('Follow ID does not match: db = %s, object = %s', data['followid'], followid) - return False - - - def get_request(self, domain: str, fail: bool = True) -> dict[str, str] | None: - if domain.startswith('http'): - domain = urlparse(domain).hostname - - try: - return self['follow-requests'][domain] - - except KeyError as e: - if fail: - raise e - - return None - - - def add_request(self, actor: str, inbox: str, followid: str) -> None: - domain = urlparse(inbox).hostname - - try: - request = self.get_request(domain) - request['followid'] = followid - - except KeyError: - pass - - self['follow-requests'][domain] = { - 'actor': actor, - 'inbox': inbox, - 'followid': followid - } - - - def del_request(self, domain: str) -> None: - if domain.startswith('http'): - domain = urlparse(domain).hostname - - del self['follow-requests'][domain] - - - def distill_inboxes(self, message: Message) -> Iterator[str]: - src_domains = { - message.domain, - urlparse(message.object_id).netloc - } - - for domain, instance in self['relay-list'].items(): - if domain not in src_domains: - yield instance['inbox'] diff --git a/relay/config.py b/relay/config.py index 84faab1..7e95c29 100644 --- a/relay/config.py +++ b/relay/config.py @@ -6,6 +6,7 @@ import platform import typing import yaml +from dataclasses import asdict, dataclass, fields from pathlib import Path from platformdirs import user_config_dir @@ -14,6 +15,12 @@ from .misc import IS_DOCKER if typing.TYPE_CHECKING: from typing import Any + try: + from typing import Self + + except ImportError: + from typing_extensions import Self + if platform.system() == 'Windows': import multiprocessing @@ -23,61 +30,44 @@ else: CORE_COUNT = len(os.sched_getaffinity(0)) -DEFAULTS: dict[str, Any] = { +DOCKER_VALUES = { 'listen': '0.0.0.0', 'port': 8080, - 'domain': 'relay.example.com', - 'workers': CORE_COUNT, - 'db_type': 'sqlite', - 'ca_type': 'database', - 'sq_path': 'relay.sqlite3', - - 'pg_host': '/var/run/postgresql', - 'pg_port': 5432, - 'pg_user': getpass.getuser(), - 'pg_pass': None, - 'pg_name': 'activityrelay', - - 'rd_host': 'localhost', - 'rd_port': 6379, - 'rd_user': None, - 'rd_pass': None, - 'rd_database': 0, - 'rd_prefix': 'activityrelay' + 'sq_path': '/data/relay.jsonld' } -if IS_DOCKER: - DEFAULTS['sq_path'] = '/data/relay.jsonld' + +class NOVALUE: + pass +@dataclass(init = False) class Config: - def __init__(self, path: str, load: bool = False): - if path: - self.path = Path(path).expanduser().resolve() + listen: str = '0.0.0.0' + port: int = 8080 + domain: str = 'relay.example.com' + workers: int = CORE_COUNT + db_type: str = 'sqlite' + ca_type: str = 'database' + sq_path: str = 'relay.sqlite3' - else: - self.path = Config.get_config_dir() + pg_host: str = '/var/run/postgresql' + pg_port: int = 5432 + pg_user: str = getpass.getuser() + pg_pass: str | None = None + pg_name: str = 'activityrelay' - self.listen = None - self.port = None - self.domain = None - self.workers = None - self.db_type = None - self.ca_type = None - self.sq_path = None + rd_host: str = 'localhost' + rd_port: int = 6470 + rd_user: str | None = None + rd_pass: str | None = None + rd_database: int = 0 + rd_prefix: str = 'activityrelay' - self.pg_host = None - self.pg_port = None - self.pg_user = None - self.pg_pass = None - self.pg_name = None - self.rd_host = None - self.rd_port = None - self.rd_user = None - self.rd_pass = None - self.rd_database = None - self.rd_prefix = None + def __init__(self, path: str | None = None, load: bool = False): + self.path = Config.get_config_dir(path) + self.reset() if load: try: @@ -87,22 +77,36 @@ class Config: self.save() + @classmethod + def KEYS(cls: type[Self]) -> list[str]: + return list(cls.__dataclass_fields__) + + + @classmethod + def DEFAULT(cls: type[Self], key: str) -> str | int | None: + for field in fields(cls): + if field.name == key: + return field.default # type: ignore + + raise KeyError(key) + + @staticmethod def get_config_dir(path: str | None = None) -> Path: if path: return Path(path).expanduser().resolve() - dirs = ( + paths = ( Path("relay.yaml").resolve(), Path(user_config_dir("activityrelay"), "relay.yaml"), Path("/etc/activityrelay/relay.yaml") ) - for directory in dirs: - if directory.exists(): - return directory + for cfgfile in paths: + if cfgfile.exists(): + return cfgfile - return dirs[0] + return paths[0] @property @@ -130,7 +134,6 @@ class Config: def load(self) -> None: self.reset() - options = {} try: @@ -141,95 +144,85 @@ class Config: with self.path.open('r', encoding = 'UTF-8') as fd: config = yaml.load(fd, **options) - pgcfg = config.get('postgresql', {}) - rdcfg = config.get('redis', {}) if not config: raise ValueError('Config is empty') - if IS_DOCKER: - self.listen = '0.0.0.0' - self.port = 8080 - self.sq_path = '/data/relay.jsonld' + pgcfg = config.get('postgresql', {}) + rdcfg = config.get('redis', {}) - else: - self.set('listen', config.get('listen', DEFAULTS['listen'])) - self.set('port', config.get('port', DEFAULTS['port'])) - self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path'])) + for key in type(self).KEYS(): + if IS_DOCKER and key in {'listen', 'port', 'sq_path'}: + self.set(key, DOCKER_VALUES[key]) + continue - self.set('workers', config.get('workers', DEFAULTS['workers'])) - self.set('domain', config.get('domain', DEFAULTS['domain'])) - self.set('db_type', config.get('database_type', DEFAULTS['db_type'])) - self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type'])) - - for key in DEFAULTS: if key.startswith('pg'): - try: - self.set(key, pgcfg[key[3:]]) - - except KeyError: - continue + self.set(key, pgcfg.get(key[3:], NOVALUE)) + continue elif key.startswith('rd'): - try: - self.set(key, rdcfg[key[3:]]) + self.set(key, rdcfg.get(key[3:], NOVALUE)) + continue - except KeyError: - continue + cfgkey = key + + if key == 'db_type': + cfgkey = 'database_type' + + elif key == 'ca_type': + cfgkey = 'cache_type' + + elif key == 'sq_path': + cfgkey = 'sqlite_path' + + self.set(key, config.get(cfgkey, NOVALUE)) def reset(self) -> None: - for key, value in DEFAULTS.items(): - setattr(self, key, value) + for field in fields(self): + setattr(self, field.name, field.default) def save(self) -> None: self.path.parent.mkdir(exist_ok = True, parents = True) + data: dict[str, Any] = {} + + for key, value in asdict(self).items(): + if key.startswith('pg_'): + if 'postgres' not in data: + data['postgres'] = {} + + data['postgres'][key[3:]] = value + continue + + if key.startswith('rd_'): + if 'redis' not in data: + data['redis'] = {} + + data['redis'][key[3:]] = value + continue + + if key == 'db_type': + key = 'database_type' + + elif key == 'ca_type': + key = 'cache_type' + + elif key == 'sq_path': + key = 'sqlite_path' + + data[key] = value + with self.path.open('w', encoding = 'utf-8') as fd: - yaml.dump(self.to_dict(), fd, sort_keys = False) + yaml.dump(data, fd, sort_keys = False) def set(self, key: str, value: Any) -> None: - if key not in DEFAULTS: + if key not in type(self).KEYS(): raise KeyError(key) - if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int): - if (value := int(value)) < 1: - if key == 'port': - value = 8080 - - elif key == 'pg_port': - value = 5432 - - elif key == 'workers': - value = len(os.sched_getaffinity(0)) + if value is NOVALUE: + return setattr(self, key, value) - - - def to_dict(self) -> dict[str, Any]: - return { - 'listen': self.listen, - 'port': self.port, - 'domain': self.domain, - 'workers': self.workers, - 'database_type': self.db_type, - 'cache_type': self.ca_type, - 'sqlite_path': self.sq_path, - 'postgres': { - 'host': self.pg_host, - 'port': self.pg_port, - 'user': self.pg_user, - 'pass': self.pg_pass, - 'name': self.pg_name - }, - 'redis': { - 'host': self.rd_host, - 'port': self.rd_port, - 'user': self.rd_user, - 'pass': self.rd_pass, - 'database': self.rd_database, - 'refix': self.rd_prefix - } - } diff --git a/relay/data/statements.sql b/relay/data/statements.sql index bc06d25..f06d4b5 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -23,17 +23,26 @@ SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value; -- name: put-inbox -INSERT INTO inboxes (domain, actor, inbox, followid, software, created) -VALUES (:domain, :actor, :inbox, :followid, :software, :created) -ON CONFLICT (domain) DO UPDATE SET followid = :followid +INSERT INTO inboxes (domain, actor, inbox, followid, software, accepted, created) +VALUES (:domain, :actor, :inbox, :followid, :software, :accepted, :created) +ON CONFLICT (domain) DO +UPDATE SET followid = :followid, inbox = :inbox, software = :software, created = :created RETURNING *; +-- name: put-inbox-accept +UPDATE inboxes SET accepted = :accepted WHERE domain = :domain RETURNING *; + + -- name: del-inbox DELETE FROM inboxes WHERE domain = :value or inbox = :value or actor = :value; +-- name: get-request +SELECT * FROM inboxes WHERE accepted = 0 and domain = :domain; + + -- name: get-user SELECT * FROM users WHERE username = :value or handle = :value; diff --git a/relay/data/swagger.yaml b/relay/data/swagger.yaml index 9c313ae..a2a51dc 100644 --- a/relay/data/swagger.yaml +++ b/relay/data/swagger.yaml @@ -13,6 +13,10 @@ schemes: - https securityDefinitions: + Cookie: + type: apiKey + in: cookie + name: user-token Bearer: type: apiKey name: Authorization @@ -285,6 +289,50 @@ paths: schema: $ref: "#/definitions/Error" + /v1/request: + get: + tags: + - Follow Request + description: Get the list of follow requests + produces: + - application/json + responses: + "200": + description: List of instances + schema: + type: array + items: + $ref: "#/definitions/Instance" + + post: + tags: + - Follow Request + description: Approve or deny a follow request + parameters: + - in: formData + name: domain + required: true + type: string + - in: formData + name: accept + required: true + type: boolean + consumes: + - application/json + - multipart/form-data + - application/x-www-form-urlencoded + produces: + - application/json + responses: + "200": + description: Follow request successfully accepted or denied + schema: + $ref: "#/definitions/Message" + "500": + description: Follow request does not exist + schema: + $ref: "#/definitions/Error" + /v1/domain_ban: get: tags: @@ -505,6 +553,104 @@ paths: schema: $ref: "#/definitions/Error" + /v1/user: + get: + tags: + - User + description: Get a list of all local users + produces: + - application/json + responses: + "200": + description: List of users + schema: + type: array + items: + $ref: "#/definitions/User" + + post: + tags: + - User + description: Create a new user + parameters: + - in: formData + name: username + required: true + type: string + - in: formData + name: password + required: true + type: string + format: password + - in: formData + name: handle + required: false + type: string + format: email + produces: + - application/json + responses: + "200": + description: Newly created user + schema: + $ref: "#/definitions/User" + "404": + description: User already exists + schema: + $ref: "#/definitions/Error" + + patch: + tags: + - User + description: Update a user's password or handle + parameters: + - in: formData + name: username + required: true + type: string + - in: formData + name: password + required: false + type: string + format: password + - in: formData + name: handle + required: false + type: string + format: email + produces: + - application/json + responses: + "200": + description: Updated user data + schema: + $ref: "#/definitions/User" + "404": + description: User does not exist + schema: + $ref: "#/definitions/Error" + + delete: + tags: + - User + description: Delete a user + parameters: + - in: formData + name: username + required: true + type: string + produces: + - application/json + responses: + "202": + description: Successfully deleted user + schema: + $ref: "#/definitions/Message" + "404": + description: User not found + schema: + $ref: "#/definitions/Error" + /v1/whitelist: get: tags: @@ -672,6 +818,9 @@ definitions: software: description: Nodeinfo-formatted name of the instance's software type: string + accepted: + description: Whether or not the follow request has been accepted + type: boolean created: description: Date the instance joined or was added type: string @@ -701,6 +850,21 @@ definitions: description: Character string used for authenticating with the api type: string + User: + type: object + properties: + username: + description: Username of the account + type: string + handle: + description: Fediverse handle associated with the account + type: string + format: email + created: + description: Date the account was created + type: string + format: date-time + Whitelist: type: object properties: diff --git a/relay/database/__init__.py b/relay/database/__init__.py index d248713..08dbec6 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations import bsql import typing -from .config import CONFIG_DEFAULTS, THEMES, get_default_value +from .config import THEMES, ConfigData from .connection import RELAY_SOFTWARE, Connection from .schema import TABLES, VERSIONS, migrate_0 @@ -11,7 +11,7 @@ from .. import logger as logging from ..misc import get_resource if typing.TYPE_CHECKING: - from .config import Config + from ..config import Config def get_database(config: Config, migrate: bool = True) -> bsql.Database: @@ -46,13 +46,14 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database: migrate_0(conn) return db - if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'): + if (schema_ver := conn.get_config('schema-version')) < ConfigData.DEFAULT('schema-version'): logging.info("Migrating database from version '%i'", schema_ver) for ver, func in VERSIONS.items(): if schema_ver < ver: func(conn) conn.put_config('schema-version', ver) + logging.info("Updated database to %i", ver) if (privkey := conn.get_config('private-key')): conn.app.signer = privkey diff --git a/relay/database/config.py b/relay/database/config.py index 82e2e69..3922f62 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -1,15 +1,23 @@ from __future__ import annotations -import json import typing +from dataclasses import Field, asdict, dataclass, fields + from .. import logger as logging from ..misc import boolean if typing.TYPE_CHECKING: - from collections.abc import Callable + from bsql import Row + from collections.abc import Callable, Sequence from typing import Any + try: + from typing import Self + + except ImportError: + from typing_extensions import Self + THEMES = { 'default': { @@ -59,39 +67,101 @@ THEMES = { } } -CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = { - 'schema-version': ('int', 20240206), - 'private-key': ('str', None), - 'log-level': ('loglevel', logging.LogLevel.INFO), - 'name': ('str', 'ActivityRelay'), - 'note': ('str', 'Make a note about your instance here.'), - 'theme': ('str', 'default'), - 'whitelist-enabled': ('bool', False) -} - # serializer | deserializer -CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = { +CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { 'str': (str, str), 'int': (str, int), 'bool': (str, boolean), - 'json': (json.dumps, json.loads), - 'loglevel': (lambda x: x.name, logging.LogLevel.parse) + 'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse) } -def get_default_value(key: str) -> Any: - return CONFIG_DEFAULTS[key][1] +@dataclass() +class ConfigData: + schema_version: int = 20240310 + private_key: str = '' + approval_required: bool = False + log_level: logging.LogLevel = logging.LogLevel.INFO + name: str = 'ActivityRelay' + note: str = '' + theme: str = 'default' + whitelist_enabled: bool = False -def get_default_type(key: str) -> str: - return CONFIG_DEFAULTS[key][0] + def __getitem__(self, key: str) -> Any: + if (value := getattr(self, key.replace('-', '_'), None)) is None: + raise KeyError(key) + + return value -def serialize(key: str, value: Any) -> str: - type_name = get_default_type(key) - return CONFIG_CONVERT[type_name][0](value) + def __setitem__(self, key: str, value: Any) -> None: + self.set(key, value) -def deserialize(key: str, value: str) -> Any: - type_name = get_default_type(key) - return CONFIG_CONVERT[type_name][1](value) + @classmethod + def KEYS(cls: type[Self]) -> Sequence[str]: + return list(cls.__dataclass_fields__) + + + @staticmethod + def SYSTEM_KEYS() -> Sequence[str]: + return ('schema-version', 'schema_version', 'private-key', 'private_key') + + + @classmethod + def USER_KEYS(cls: type[Self]) -> Sequence[str]: + return tuple(key for key in cls.KEYS() if key not in cls.SYSTEM_KEYS()) + + + @classmethod + def DEFAULT(cls: type[Self], key: str) -> str | int | bool: + return cls.FIELD(key.replace('-', '_')).default # type: ignore + + + @classmethod + def FIELD(cls: type[Self], key: str) -> Field: + for field in fields(cls): + if field.name == key.replace('-', '_'): + return field + + raise KeyError(key) + + + @classmethod + def from_rows(cls: type[Self], rows: Sequence[Row]) -> Self: + data = cls() + set_schema_version = False + + for row in rows: + data.set(row['key'], row['value']) + + if row['key'] == 'schema-version': + set_schema_version = True + + if not set_schema_version: + data.schema_version = 0 + + return data + + + def get(self, key: str, default: Any = None, serialize: bool = False) -> Any: + field = type(self).FIELD(key) + value = getattr(self, field.name, None) + + if not serialize: + return value + + converter = CONFIG_CONVERT[str(field.type)][0] + return converter(value) + + + def set(self, key: str, value: Any) -> None: + field = type(self).FIELD(key) + converter = CONFIG_CONVERT[str(field.type)][1] + + setattr(self, field.name, converter(value)) + + + def to_dict(self) -> dict[str, Any]: + return {key.replace('_', '-'): value for key, value in asdict(self).items()} diff --git a/relay/database/connection.py b/relay/database/connection.py index 2792111..f8de1c0 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -9,22 +9,18 @@ from urllib.parse import urlparse from uuid import uuid4 from .config import ( - CONFIG_DEFAULTS, THEMES, - get_default_type, - get_default_value, - serialize, - deserialize + ConfigData ) from .. import logger as logging from ..misc import boolean, get_app if typing.TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterator, Sequence from bsql import Row from typing import Any - from .application import Application + from ..application import Application from ..misc import Message @@ -46,54 +42,37 @@ class Connection(SqlConnection): return get_app() - def distill_inboxes(self, message: Message) -> Iterator[str]: + def distill_inboxes(self, message: Message) -> Iterator[Row]: src_domains = { message.domain, urlparse(message.object_id).netloc } - for inbox in self.execute('SELECT * FROM inboxes'): - if inbox['domain'] not in src_domains: - yield inbox['inbox'] + for instance in self.get_inboxes(): + if instance['domain'] not in src_domains: + yield instance def get_config(self, key: str) -> Any: - if key not in CONFIG_DEFAULTS: - raise KeyError(key) + key = key.replace('_', '-') with self.run('get-config', {'key': key}) as cur: if not (row := cur.one()): - return get_default_value(key) + return ConfigData.DEFAULT(key) - if row['value']: - return deserialize(row['key'], row['value']) - - return None + data = ConfigData() + data.set(row['key'], row['value']) + return data.get(key) - def get_config_all(self) -> dict[str, Any]: + def get_config_all(self) -> ConfigData: with self.run('get-config-all', None) as cur: - db_config = {row['key']: row['value'] for row in cur} - - config = {} - - for key, data in CONFIG_DEFAULTS.items(): - try: - config[key] = deserialize(key, db_config[key]) - - except KeyError: - if key == 'schema-version': - config[key] = 0 - - else: - config[key] = data[1] - - return config + return ConfigData.from_rows(tuple(cur.all())) def put_config(self, key: str, value: Any) -> Any: - if key not in CONFIG_DEFAULTS: - raise KeyError(key) + field = ConfigData.FIELD(key) + key = field.name.replace('_', '-') if key == 'private-key': self.app.signer = value @@ -102,73 +81,70 @@ class Connection(SqlConnection): value = logging.LogLevel.parse(value) logging.set_level(value) - elif key == 'whitelist-enabled': + elif key in {'approval-required', 'whitelist-enabled'}: value = boolean(value) elif key == 'theme': if value not in THEMES: raise ValueError(f'"{value}" is not a valid theme') + data = ConfigData() + data.set(key, value) + params = { 'key': key, - 'value': serialize(key, value) if value is not None else None, - 'type': get_default_type(key) + 'value': data.get(key, serialize = True), + 'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type } with self.run('put-config', params): - return value + pass + + return data.get(key) def get_inbox(self, value: str) -> Row: with self.run('get-inbox', {'value': value}) as cur: - return cur.one() + return cur.one() # type: ignore + + + def get_inboxes(self) -> Sequence[Row]: + with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur: + return tuple(cur.all()) def put_inbox(self, domain: str, - inbox: str, + inbox: str | None = None, actor: str | None = None, followid: str | None = None, - software: str | None = None) -> Row: + software: str | None = None, + accepted: bool = True) -> Row: - params = { - 'domain': domain, + params: dict[str, Any] = { 'inbox': inbox, 'actor': actor, 'followid': followid, 'software': software, - 'created': datetime.now(tz = timezone.utc) + 'accepted': accepted } - with self.run('put-inbox', params) as cur: - return cur.one() + if not self.get_inbox(domain): + if not inbox: + raise ValueError("Missing inbox") + params['domain'] = domain + params['created'] = datetime.now(tz = timezone.utc) - def update_inbox(self, - inbox: str, - actor: str | None = None, - followid: str | None = None, - software: str | None = None) -> Row: + with self.run('put-inbox', params) as cur: + return cur.one() # type: ignore - if not (actor or followid or software): - raise ValueError('Missing "actor", "followid", and/or "software"') + for key, value in tuple(params.items()): + if value is None: + del params[key] - data = {} - - if actor: - data['actor'] = actor - - if followid: - data['followid'] = followid - - if software: - data['software'] = software - - statement = Update('inboxes', data) - statement.set_where("inbox", inbox) - - with self.query(statement): - return self.get_inbox(inbox) + with self.update('inboxes', params, domain = domain) as cur: + return cur.one() # type: ignore def del_inbox(self, value: str) -> bool: @@ -179,17 +155,64 @@ class Connection(SqlConnection): return cur.row_count == 1 + def get_request(self, domain: str) -> Row: + with self.run('get-request', {'domain': domain}) as cur: + if not (row := cur.one()): + raise KeyError(domain) + + return row + + + def get_requests(self) -> Sequence[Row]: + with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur: + return tuple(cur.all()) + + + def put_request_response(self, domain: str, accepted: bool) -> Row: + instance = self.get_request(domain) + + if not accepted: + self.del_inbox(domain) + return instance + + params = { + 'domain': domain, + 'accepted': accepted + } + + with self.run('put-inbox-accept', params) as cur: + return cur.one() # type: ignore + + def get_user(self, value: str) -> Row: with self.run('get-user', {'value': value}) as cur: - return cur.one() + return cur.one() # type: ignore def get_user_by_token(self, code: str) -> Row: with self.run('get-user-by-token', {'code': code}) as cur: - return cur.one() + return cur.one() # type: ignore - def put_user(self, username: str, password: str, handle: str | None = None) -> Row: + def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row: + if self.get_user(username): + data: dict[str, str | datetime | None] = {} + + if password: + data['hash'] = self.hasher.hash(password) + + if handle: + data['handle'] = handle + + stmt = Update("users", data) + stmt.set_where("username", username) + + with self.query(stmt) as cur: + return cur.one() # type: ignore + + if password is None: + raise ValueError('Password cannot be empty') + data = { 'username': username, 'hash': self.hasher.hash(password), @@ -198,7 +221,7 @@ class Connection(SqlConnection): } with self.run('put-user', data) as cur: - return cur.one() + return cur.one() # type: ignore def del_user(self, username: str) -> None: @@ -213,7 +236,7 @@ class Connection(SqlConnection): def get_token(self, code: str) -> Row: with self.run('get-token', {'code': code}) as cur: - return cur.one() + return cur.one() # type: ignore def put_token(self, username: str) -> Row: @@ -224,7 +247,7 @@ class Connection(SqlConnection): } with self.run('put-token', data) as cur: - return cur.one() + return cur.one() # type: ignore def del_token(self, code: str) -> None: @@ -237,7 +260,7 @@ class Connection(SqlConnection): domain = urlparse(domain).netloc with self.run('get-domain-ban', {'domain': domain}) as cur: - return cur.one() + return cur.one() # type: ignore def put_domain_ban(self, @@ -253,7 +276,7 @@ class Connection(SqlConnection): } with self.run('put-domain-ban', params) as cur: - return cur.one() + return cur.one() # type: ignore def update_domain_ban(self, @@ -292,7 +315,7 @@ class Connection(SqlConnection): def get_software_ban(self, name: str) -> Row: with self.run('get-software-ban', {'name': name}) as cur: - return cur.one() + return cur.one() # type: ignore def put_software_ban(self, @@ -308,7 +331,7 @@ class Connection(SqlConnection): } with self.run('put-software-ban', params) as cur: - return cur.one() + return cur.one() # type: ignore def update_software_ban(self, @@ -347,7 +370,7 @@ class Connection(SqlConnection): def get_domain_whitelist(self, domain: str) -> Row: with self.run('get-domain-whitelist', {'domain': domain}) as cur: - return cur.one() + return cur.one() # type: ignore def put_domain_whitelist(self, domain: str) -> Row: @@ -357,7 +380,7 @@ class Connection(SqlConnection): } with self.run('put-domain-whitelist', params) as cur: - return cur.one() + return cur.one() # type: ignore def del_domain_whitelist(self, domain: str) -> bool: diff --git a/relay/database/schema.py b/relay/database/schema.py index e3a0303..ba39ed2 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -2,12 +2,13 @@ from __future__ import annotations import typing -from bsql import Column, Connection, Table, Tables +from bsql import Column, Table, Tables -from .config import get_default_value +from .config import ConfigData if typing.TYPE_CHECKING: from collections.abc import Callable + from .connection import Connection VERSIONS: dict[int, Callable] = {} @@ -25,6 +26,7 @@ TABLES: Tables = Tables( Column('inbox', 'text', unique = True, nullable = False), Column('followid', 'text'), Column('software', 'text'), + Column('accepted', 'boolean'), Column('created', 'timestamp', nullable = False) ), Table( @@ -70,9 +72,15 @@ def migration(func: Callable) -> Callable: def migrate_0(conn: Connection) -> None: conn.create_tables() - conn.put_config('schema-version', get_default_value('schema-version')) + conn.put_config('schema-version', ConfigData.DEFAULT('schema-version')) @migration def migrate_20240206(conn: Connection) -> None: conn.create_tables() + + +@migration +def migrate_20240310(conn: Connection) -> None: + conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN") + conn.execute("UPDATE inboxes SET accepted = 1") diff --git a/relay/dev.py b/relay/dev.py index 6407068..7517946 100644 --- a/relay/dev.py +++ b/relay/dev.py @@ -4,18 +4,20 @@ import subprocess import sys import time -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from tempfile import TemporaryDirectory +from typing import Sequence from . import __version__ +from . import logger as logging try: from watchdog.observers import Observer from watchdog.events import PatternMatchingEventHandler except ImportError: - class PatternMatchingEventHandler: + class PatternMatchingEventHandler: # type: ignore pass @@ -45,9 +47,25 @@ def cli_install(): @cli.command('lint') @click.argument('path', required = False, default = 'relay') -def cli_lint(path): - subprocess.run([sys.executable, '-m', 'flake8', path], check = False) - subprocess.run([sys.executable, '-m', 'pylint', path], check = False) +@click.option('--strict', '-s', is_flag = True, help = 'Enable strict mode for mypy') +@click.option('--watch', '-w', is_flag = True, + help = 'Automatically, re-run the linters on source change') +def cli_lint(path: str, strict: bool, watch: bool) -> None: + flake8 = [sys.executable, '-m', 'flake8', path] + mypy = [sys.executable, '-m', 'mypy', path] + + if strict: + mypy.append('--strict') + + if watch: + handle_run_watcher(mypy, flake8, wait = True) + return + + click.echo('----- flake8 -----') + subprocess.run(flake8) + + click.echo('\n\n----- mypy -----') + subprocess.run(mypy) @cli.command('build') @@ -80,11 +98,21 @@ def cli_build(): @cli.command('run') -def cli_run(): +@click.option('--dev', '-d', is_flag = True) +def cli_run(dev: bool): print('Starting process watcher') - handler = WatchHandler() - handler.run_proc() + cmd = [sys.executable, '-m', 'relay', 'run'] + + if dev: + cmd.append('-d') + + handle_run_watcher(cmd) + + +def handle_run_watcher(*commands: Sequence[str], wait: bool = False): + handler = WatchHandler(*commands, wait = wait) + handler.run_procs() watcher = Observer() watcher.schedule(handler, str(SCRIPT), recursive=True) @@ -92,13 +120,12 @@ def cli_run(): try: while True: - handler.proc.stdin.write(sys.stdin.read().encode('UTF-8')) - handler.proc.stdin.flush() + time.sleep(1) except KeyboardInterrupt: pass - handler.kill_proc() + handler.kill_procs() watcher.stop() watcher.join() @@ -106,58 +133,65 @@ def cli_run(): class WatchHandler(PatternMatchingEventHandler): patterns = ['*.py'] - cmd = [sys.executable, '-m', 'relay', 'run', '-d'] - def __init__(self): + def __init__(self, *commands: Sequence[str], wait: bool = False): PatternMatchingEventHandler.__init__(self) - self.proc = None - self.last_restart = None + self.commands: Sequence[Sequence[str]] = commands + self.wait: bool = wait + self.procs: list[subprocess.Popen] = [] + self.last_restart: datetime = datetime.now() - def kill_proc(self): - if self.proc.poll() is not None: - return + def kill_procs(self): + for proc in self.procs: + if proc.poll() is not None: + continue - print(f'Terminating process {self.proc.pid}') - self.proc.terminate() - sec = 0.0 + logging.info(f'Terminating process {proc.pid}') + proc.terminate() + sec = 0.0 - while self.proc.poll() is None: - time.sleep(0.1) - sec += 0.1 + while proc.poll() is None: + time.sleep(0.1) + sec += 0.1 - if sec >= 5: - print('Failed to terminate. Killing process...') - self.proc.kill() - break + if sec >= 5: + logging.error('Failed to terminate. Killing process...') + proc.kill() + break - print('Process terminated') + logging.info('Process terminated') - def run_proc(self, restart=False): - timestamp = datetime.timestamp(datetime.now()) - self.last_restart = timestamp if not self.last_restart else 0 - - if restart and self.proc.pid != '': - if timestamp - 3 < self.last_restart: + def run_procs(self, restart: bool = False): + if restart: + if datetime.now() - timedelta(seconds = 3) < self.last_restart: return - self.kill_proc() + self.kill_procs() - # pylint: disable=consider-using-with - self.proc = subprocess.Popen(self.cmd, stdin = subprocess.PIPE) - self.last_restart = timestamp + self.last_restart = datetime.now() - print(f'Started process with PID {self.proc.pid}') + if self.wait: + self.procs = [] + + for cmd in self.commands: + logging.info('Running command: %s', ' '.join(cmd)) + subprocess.run(cmd) + + else: + self.procs = list(subprocess.Popen(cmd) for cmd in self.commands) + pids = (str(proc.pid) for proc in self.procs) + logging.info('Started processes with PIDs: %s', ', '.join(pids)) def on_any_event(self, event): if event.event_type not in ['modified', 'created', 'deleted']: return - self.run_proc(restart = True) + self.run_procs(restart = True) if __name__ == '__main__': diff --git a/relay/frontend/base.haml b/relay/frontend/base.haml index e1da33c..5c4bc1b 100644 --- a/relay/frontend/base.haml +++ b/relay/frontend/base.haml @@ -11,8 +11,10 @@ %title << {{config.name}}: {{page}} %meta(charset="UTF-8") %meta(name="viewport" content="width=device-width, initial-scale=1") - %link(rel="stylesheet" type="text/css" href="/theme/{{theme_name}}.css") - %link(rel="stylesheet" type="text/css" href="/style.css") + %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme") + %link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}") + %link(rel="manifest" href="/manifest.json") + %script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer) -block head %body @@ -35,22 +37,23 @@ -else {{menu_item("Login", "/login")}} + %ul#notifications + #container #header.section %span#menu-open << ⁞ - %span.title-container - %a.title(href="/") -> =config.name - - -if view.request.path not in ["/", "/login"] - .page -> =page - + %a.title(href="/") -> =config.name .empty -if error - .error.section -> =error + %fieldset.error.section + %legend << Error + =error -if message - .message.section -> =message + %fieldset.message.section + %legend << Message + =message #content(class="page-{{page.lower().replace(' ', '_')}}") -block content @@ -69,26 +72,3 @@ .version %a(href="https://git.pleroma.social/pleroma/relay") ActivityRelay/{{version}} - - %script(type="application/javascript") - const body = document.getElementById("container") - const menu = document.getElementById("menu"); - const menu_open = document.getElementById("menu-open"); - const menu_close = document.getElementById("menu-close"); - - menu_open.addEventListener("click", (event) => { - var new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true"; - menu.attributes.visible.nodeValue = new_value; - }); - - menu_close.addEventListener("click", (event) => { - menu.attributes.visible.nodeValue = "false" - }); - - body.addEventListener("click", (event) => { - if (event.target === menu_open) { - return; - } - - menu.attributes.visible.nodeValue = "false"; - }); diff --git a/relay/frontend/functions.haml b/relay/frontend/functions.haml new file mode 100644 index 0000000..fe44db6 --- /dev/null +++ b/relay/frontend/functions.haml @@ -0,0 +1,16 @@ +-macro new_checkbox(name, checked) + -if checked + %input(id="{{name}}" type="checkbox" checked) + + -else + %input(id="{{name}}" type="checkbox") + + +-macro new_select(name, selected, items) + %select(id="{{name}}") + -for item in items + -if item == selected + %option(value="{{item}}" selected) -> =item.title() + + -else + %option(value="{{item}}") -> =item.title() diff --git a/relay/frontend/page/admin-config.haml b/relay/frontend/page/admin-config.haml index 4028eb1..e5df986 100644 --- a/relay/frontend/page/admin-config.haml +++ b/relay/frontend/page/admin-config.haml @@ -1,37 +1,29 @@ -extends "base.haml" -set page="Config" + +-block head + %script(type="application/javascript" src="/static/config.js" nonce="{{view.request['hash']}}" defer) + +-import "functions.haml" as func -block content - %form.section(action="/admin/config" method="POST") + %fieldset.section + %legend << Config + .grid-2col %label(for="name") << Name - %input(id = "name" name="name" placeholder="Relay Name" value="{{config.name or ''}}") + %input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}") - %label(for="description") << Description - %textarea(id="description" name="note" value="{{config.note}}") << {{config.note}} + %label(for="note") << Description + %textarea(id="note" value="{{config.note or ''}}") << {{config.note}} %label(for="theme") << Color Theme - %select(id="theme" name="theme") - -for theme in themes - -if theme == config.theme - %option(value="{{theme}}" selected) -> =theme.title() - - -else - %option(value="{{theme}}") -> =theme.title() + =func.new_select("theme", config.theme, themes) %label(for="log-level") << Log Level - %select(id="log-level" name="log-level") - -for level in LogLevel - -if level == config["log-level"] - %option(value="{{level.name}}" selected) -> =level.name.title() - - -else - %option(value="{{level.name}}") -> =level.name.title() + =func.new_select("log-level", config.log_level.name, levels) %label(for="whitelist-enabled") << Whitelist - -if config["whitelist-enabled"] - %input(id="whitelist-enabled" name="whitelist-enabled" type="checkbox" checked) + =func.new_checkbox("whitelist-enabled", config.whitelist_enabled) - -else - %input(id="whitelist-enabled" name="whitelist-enabled" type="checkbox") - - %input(type="submit" value="Save") + %label(for="approval-required") << Approval Required + =func.new_checkbox("approval-required", config.approval_required) diff --git a/relay/frontend/page/admin-domain_bans.haml b/relay/frontend/page/admin-domain_bans.haml index fbee683..19ae4ae 100644 --- a/relay/frontend/page/admin-domain_bans.haml +++ b/relay/frontend/page/admin-domain_bans.haml @@ -1,48 +1,53 @@ -extends "base.haml" -set page="Domain Bans" + +-block head + %script(type="application/javascript" src="/static/domain_ban.js" nonce="{{view.request['hash']}}" defer) + -block content %details.section %summary << Ban Domain - %form(action="/admin/domain_bans" method="POST") - #add-item - %label(for="domain") << Domain - %input(type="domain" id="domain" name="domain" placeholder="Domain") + #add-item + %label(for="new-domain") << Domain + %input(type="domain" id="new-domain" placeholder="Domain") - %label(for="reason") << Ban Reason - %textarea(id="reason" name="reason") << {{""}} + %label(for="new-reason") << Ban Reason + %textarea(id="new-reason") << {{""}} - %label(for="note") << Admin Note - %textarea(id="note" name="note") << {{""}} + %label(for="new-note") << Admin Note + %textarea(id="new-note") << {{""}} - %input(type="submit" value="Ban Domain") + %input#new-ban(type="button" value="Ban Domain") - #data-table.section - %table - %thead - %tr - %td.domain << Instance - %td << Date - %td.remove + %fieldset.section + %legend << Domain Bans - %tbody - -for ban in bans + .data-table + %table + %thead %tr - %td.domain - %details - %summary -> =ban.domain - %form(action="/admin/domain_bans" method="POST") - .grid-2col - .reason << Reason - %textarea.reason(id="reason" name="reason") << {{ban.reason or ""}} - - .note << Note - %textarea.note(id="note" name="note") << {{ban.note or ""}} - - %input(type="hidden" name="domain" value="{{ban.domain}}") - %input(type="submit" value="Update") - - %td.date - =ban.created.strftime("%Y-%m-%d") - + %td.domain << Domain + %td << Date %td.remove - %a(href="/admin/domain_bans/delete/{{ban.domain}}" title="Unban domain") << ✖ + + %tbody + -for ban in bans + %tr(id="{{ban.domain}}") + %td.domain + %details + %summary -> =ban.domain + + .grid-2col + %label.reason(for="{{ban.domain}}-reason") << Reason + %textarea.reason(id="{{ban.domain}}-reason") << {{ban.reason or ""}} + + %label.note(for="{{ban.domain}}-note") << Note + %textarea.note(id="{{ban.domain}}-note") << {{ban.note or ""}} + + %input.update-ban(type="button" value="Update") + + %td.date + =ban.created.strftime("%Y-%m-%d") + + %td.remove + %a(href="#" title="Unban domain") << ✖ diff --git a/relay/frontend/page/admin-instances.haml b/relay/frontend/page/admin-instances.haml index 106e31d..2e43f48 100644 --- a/relay/frontend/page/admin-instances.haml +++ b/relay/frontend/page/admin-instances.haml @@ -1,44 +1,81 @@ -extends "base.haml" -set page="Instances" + +-block head + %script(type="application/javascript" src="/static/instance.js" nonce="{{view.request['hash']}}" defer) + -block content %details.section %summary << Add Instance - %form(action="/admin/instances" method="POST") - #add-item - %label(for="domain") << Domain - %input(type="domain" id="domain" name="domain" placeholder="Domain") + #add-item + %label(for="new-actor") << Actor + %input(type="url" id="new-actor" placeholder="Actor URL") - %label(for="actor") << Actor URL - %input(type="url" id="actor" name="actor" placeholder="Actor URL") + %label(for="new-inbox") << Inbox + %input(type="url" id="new-inbox" placeholder="Inbox URL") - %label(for="inbox") << Inbox URL - %input(type="url" id="inbox" name="inbox" placeholder="Inbox URL") + %label(for="new-followid") << Follow ID + %input(type="url" id="new-followid" placeholder="Follow ID URL") - %label(for="software") << Software - %input(name="software" id="software" placeholder="software") + %label(for="new-software") << Software + %input(id="new-software" placeholder="software") - %input(type="submit" value="Add Instance") + %input#add-instance(type="button" value="Add Instance") - #data-table.section - %table - %thead - %tr - %td.instance << Instance - %td.software << Software - %td.date << Joined - %td.remove + -if requests + %fieldset.section.requests + %legend << Follow Requests + .data-table + %table#requests + %thead + %tr + %td.instance << Instance + %td.software << Software + %td.date << Joined + %td.approve + %td.deny - %tbody - -for instance in instances + %tbody + -for request in requests + %tr(id="{{request.domain}}") + %td.instance + %a(href="https://{{request.domain}}" target="_new") -> =request.domain + + %td.software + =request.software or "n/a" + + %td.date + =request.created.strftime("%Y-%m-%d") + + %td.approve + %a(href="#" title="Approve Request") << ✓ + + %td.deny + %a(href="#" title="Deny Request") << ✖ + + %fieldset.section.instances + %legend << Instances + + .data-table + %table#instances + %thead %tr - %td.instance - %a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain - - %td.software - =instance.software or "n/a" - - %td.date - =instance.created.strftime("%Y-%m-%d") - + %td.instance << Instance + %td.software << Software + %td.date << Joined %td.remove - %a(href="/admin/instances/delete/{{instance.domain}}" title="Remove Instance") << ✖ + + %tbody + -for instance in instances + %tr(id="{{instance.domain}}") + %td.instance + %a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain + + %td.software + =instance.software or "n/a" + + %td.date + =instance.created.strftime("%Y-%m-%d") + + %td.remove + %a(href="#" title="Remove Instance") << ✖ diff --git a/relay/frontend/page/admin-software_bans.haml b/relay/frontend/page/admin-software_bans.haml index 9490405..9bda3be 100644 --- a/relay/frontend/page/admin-software_bans.haml +++ b/relay/frontend/page/admin-software_bans.haml @@ -1,48 +1,53 @@ -extends "base.haml" -set page="Software Bans" + +-block head + %script(type="application/javascript" src="/static/software_ban.js" nonce="{{view.request['hash']}}" defer) + -block content %details.section %summary << Ban Software - %form(action="/admin/software_bans" method="POST") - #add-item - %label(for="name") << Name - %input(id="name" name="name" placeholder="Name") + #add-item + %label(for="new-name") << Domain + %input(type="name" id="new-name" placeholder="Domain") - %label(for="reason") << Ban Reason - %textarea(id="reason" name="reason") << {{""}} + %label(for="new-reason") << Ban Reason + %textarea(id="new-reason") << {{""}} - %label(for="note") << Admin Note - %textarea(id="note" name="note") << {{""}} + %label(for="new-note") << Admin Note + %textarea(id="new-note") << {{""}} - %input(type="submit" value="Ban Software") + %input#new-ban(type="button" value="Ban Software") - #data-table.section - %table - %thead - %tr - %td.name << Instance - %td << Date - %td.remove + %fieldset.section + %legend << Software Bans - %tbody - -for ban in bans + .data-table + %table#bans + %thead %tr - %td.name - %details - %summary -> =ban.name - %form(action="/admin/software_bans" method="POST") - .grid-2col - .reason << Reason - %textarea.reason(id="reason" name="reason") << {{ban.reason or ""}} - - .note << Note - %textarea.note(id="note" name="note") << {{ban.note or ""}} - - %input(type="hidden" name="name" value="{{ban.name}}") - %input(type="submit" value="Update") - - %td.date - =ban.created.strftime("%Y-%m-%d") - + %td.name << Name + %td << Date %td.remove - %a(href="/admin/software_bans/delete/{{ban.name}}" title="Unban software") << ✖ + + %tbody + -for ban in bans + %tr(id="{{ban.name}}") + %td.name + %details + %summary -> =ban.name + + .grid-2col + %label.reason(for="{{ban.name}}-reason") << Reason + %textarea.reason(id="{{ban.name}}-reason") << {{ban.reason or ""}} + + %label.note(for="{{ban.name}}-note") << Note + %textarea.note(id="{{ban.name}}-note") << {{ban.note or ""}} + + %input.update-ban(type="button" value="Update") + + %td.date + =ban.created.strftime("%Y-%m-%d") + + %td.remove + %a(href="#" title="Unban name") << ✖ diff --git a/relay/frontend/page/admin-users.haml b/relay/frontend/page/admin-users.haml index 65c268e..50058d7 100644 --- a/relay/frontend/page/admin-users.haml +++ b/relay/frontend/page/admin-users.haml @@ -1,44 +1,50 @@ -extends "base.haml" -set page="Users" + +-block head + %script(type="application/javascript" src="/static/user.js" nonce="{{view.request['hash']}}" defer) + -block content %details.section %summary << Add User - %form(action="/admin/users", method="POST") - #add-item - %label(for="username") << Username - %input(id="username" name="username" placeholder="Username") + #add-item + %label(for="new-username") << Username + %input(id="new-username" name="username" placeholder="Username" autocomplete="off") - %label(for="password") << Password - %input(type="password" id="password" name="password" placeholder="Password") + %label(for="new-password") << Password + %input(id="new-password" type="password" placeholder="Password" autocomplete="off") - %label(for="password2") << Password Again - %input(type="password" id="password2" name="password2" placeholder="Password Again") + %label(for="new-password2") << Password Again + %input(id="new-password2" type="password" placeholder="Password Again" autocomplete="off") - %label(for="handle") << Handle - %input(type="email" name="handle" id="handle" placeholder="handle") + %label(for="new-handle") << Handle + %input(id="new-handle" type="email" placeholder="handle" autocomplete="off") - %input(type="submit" value="Add User") + %input#new-user(type="button" value="Add User") - #data-table.section - %table - %thead - %tr - %td.username << Username - %td.handle << Handle - %td.date << Joined - %td.remove + %fieldset.section + %legend << Users - %tbody - -for user in users + .data-table + %table#users + %thead %tr - %td.username - =user.username - - %td.handle - =user.handle or "n/a" - - %td.date - =user.created.strftime("%Y-%m-%d") - + %td.username << Username + %td.handle << Handle + %td.date << Joined %td.remove - %a(href="/admin/users/delete/{{user.username}}" title="Remove User") << ✖ + + %tbody + -for user in users + %tr(id="{{user.username}}") + %td.username + =user.username + + %td.handle + =user.handle or "n/a" + + %td.date + =user.created.strftime("%Y-%m-%d") + + %td.remove + %a(href="#" title="Remove User") << ✖ diff --git a/relay/frontend/page/admin-whitelist.haml b/relay/frontend/page/admin-whitelist.haml index b294552..f494aef 100644 --- a/relay/frontend/page/admin-whitelist.haml +++ b/relay/frontend/page/admin-whitelist.haml @@ -1,17 +1,22 @@ -extends "base.haml" -set page="Whitelist" + +-block head + %script(type="application/javascript" src="/static/whitelist.js" nonce="{{view.request['hash']}}" defer) + -block content %details.section %summary << Add Domain - %form(action="/admin/whitelist" method="POST") - #add-item - %label(for="domain") << Domain - %input(type="domain" id="domain" name="domain" placeholder="Domain") + #add-item + %label(for="new-domain") << Domain + %input(type="domain" id="new-domain" placeholder="Domain") - %input(type="submit" value="Add Domain") + %input#new-item(type="button" value="Add Domain") - #data-table.section - %table + %fieldset.data-table.section + %legend << Whitelist + + %table#whitelist %thead %tr %td.domain << Domain @@ -20,7 +25,7 @@ %tbody -for item in whitelist - %tr + %tr(id="{{item.domain}}") %td.domain =item.domain @@ -28,4 +33,4 @@ =item.created.strftime("%Y-%m-%d") %td.remove - %a(href="/admin/whitelist/delete/{{item.domain}}" title="Remove whitlisted domain") << ✖ + %a(href="#" title="Remove whitlisted domain") << ✖ diff --git a/relay/frontend/page/home.haml b/relay/frontend/page/home.haml index 7f09644..f9618fc 100644 --- a/relay/frontend/page/home.haml +++ b/relay/frontend/page/home.haml @@ -1,10 +1,9 @@ -extends "base.haml" -set page = "Home" -block content - .section - -for line in config.note.splitlines() - -if line - %p -> =line + -if config.note + .section + -markdown -> =config.note .section %p @@ -14,23 +13,35 @@ You may subscribe to this relay with the address: %a(href="https://{{domain}}/actor") << https://{{domain}}/actor - -if config["whitelist-enabled"] - %p.section.message - Note: The whitelist is enabled on this instance. Ask the admin to add your instance - before joining. + -if config.approval_required + %fieldset.section.message + %legend << Require Approval - #data-table.section - %table - %thead - %tr - %td.instance << Instance - %td.date << Joined + Follow requests require approval. You will need to wait for an admin to accept or deny + your request. - %tbody - -for instance in instances + -elif config.whitelist_enabled + %fieldset.section.message + %legend << Whitelist Enabled + + The whitelist is enabled on this instance. Ask the admin to add your instance before + joining. + + %fieldset.section + %legend << Instances + + .data-table + %table + %thead %tr - %td.instance -> %a(href="https://{{instance.domain}}/" target="_new") - =instance.domain + %td.instance << Instance + %td.date << Joined - %td.date - =instance.created.strftime("%Y-%m-%d") + %tbody + -for instance in instances + %tr + %td.instance -> %a(href="https://{{instance.domain}}/" target="_new") + =instance.domain + + %td.date + =instance.created.strftime("%Y-%m-%d") diff --git a/relay/frontend/page/login.haml b/relay/frontend/page/login.haml index 1e08185..bf1ab1c 100644 --- a/relay/frontend/page/login.haml +++ b/relay/frontend/page/login.haml @@ -1,7 +1,13 @@ -extends "base.haml" -set page="Login" + +-block head + %script(type="application/javascript" src="/static/login.js" nonce="{{view.request['hash']}}" defer) + -block content - %form.section(action="/login" method="POST") + %fieldset.section + %legend << Login + .grid-2col %label(for="username") << Username %input(id="username" name="username" placeholder="Username" value="{{username or ''}}") @@ -9,4 +15,4 @@ %label(for="password") << Password %input(id="password" name="password" placeholder="Password" type="password") - %input(type="submit" value="Login") + %input.submit(type="button" value="Login") diff --git a/relay/frontend/static/api.js b/relay/frontend/static/api.js new file mode 100644 index 0000000..65423ba --- /dev/null +++ b/relay/frontend/static/api.js @@ -0,0 +1,135 @@ +// toast notifications + +const notifications = document.querySelector("#notifications") + + +function remove_toast(toast) { + toast.classList.add("hide"); + + if (toast.timeoutId) { + clearTimeout(toast.timeoutId); + } + + setTimeout(() => toast.remove(), 300); +} + +function toast(text, type="error", timeout=5) { + const toast = document.createElement("li"); + toast.className = `section ${type}` + toast.innerHTML = `${text}✖` + + toast.querySelector("a").addEventListener("click", async (event) => { + event.preventDefault(); + await remove_toast(toast); + }); + + notifications.appendChild(toast); + toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000); +} + + +// menu + +const body = document.getElementById("container") +const menu = document.getElementById("menu"); +const menu_open = document.getElementById("menu-open"); +const menu_close = document.getElementById("menu-close"); + + +menu_open.addEventListener("click", (event) => { + var new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true"; + menu.attributes.visible.nodeValue = new_value; +}); + +menu_close.addEventListener("click", (event) => { + menu.attributes.visible.nodeValue = "false" +}); + +body.addEventListener("click", (event) => { + if (event.target === menu_open) { + return; + } + + menu.attributes.visible.nodeValue = "false"; +}); + + +// misc + +function get_date_string(date) { + var year = date.getFullYear().toString(); + var month = date.getMonth().toString(); + var day = date.getDay().toString(); + + if (month.length === 1) { + month = "0" + month; + } + + if (day.length === 1) { + day = "0" + day + } + + return `${year}-${month}-${day}`; +} + + +function append_table_row(table, row_name, row) { + var table_row = table.insertRow(-1); + table_row.id = row_name; + + index = 0; + + for (var prop in row) { + if (Object.prototype.hasOwnProperty.call(row, prop)) { + var cell = table_row.insertCell(index); + cell.className = prop; + cell.innerHTML = row[prop]; + + index += 1; + } + } + + return table_row; +} + + +async function request(method, path, body = null) { + var headers = { + "Accept": "application/json" + } + + if (body !== null) { + headers["Content-Type"] = "application/json" + body = JSON.stringify(body) + } + + const response = await fetch("/api/" + path, { + method: method, + mode: "cors", + cache: "no-store", + redirect: "follow", + body: body, + headers: headers + }); + + const message = await response.json(); + + if (Object.hasOwn(message, "error")) { + throw new Error(message.error); + } + + if (Array.isArray(message)) { + message.forEach((msg) => { + if (Object.hasOwn(msg, "created")) { + msg.created = new Date(msg.created); + } + }); + + } else { + if (Object.hasOwn(message, "created")) { + message.created = new Date(message.created); + } + } + + return message; +} diff --git a/relay/frontend/static/config.js b/relay/frontend/static/config.js new file mode 100644 index 0000000..417c48a --- /dev/null +++ b/relay/frontend/static/config.js @@ -0,0 +1,40 @@ +const elems = [ + document.querySelector("#name"), + document.querySelector("#note"), + document.querySelector("#theme"), + document.querySelector("#log-level"), + document.querySelector("#whitelist-enabled"), + document.querySelector("#approval-required") +] + + +async function handle_config_change(event) { + params = { + key: event.target.id, + value: event.target.type === "checkbox" ? event.target.checked : event.target.value + } + + try { + await request("POST", "v1/config", params); + + } catch (error) { + toast(error); + return; + } + + if (params.key === "name") { + document.querySelector("#header .title").innerHTML = params.value; + document.querySelector("title").innerHTML = params.value; + } + + if (params.key === "theme") { + document.querySelector("link.theme").href = `/theme/${params.value}.css`; + } + + toast("Updated config", "message"); +} + + +for (const elem of elems) { + elem.addEventListener("change", handle_config_change); +} diff --git a/relay/frontend/static/domain_ban.js b/relay/frontend/static/domain_ban.js new file mode 100644 index 0000000..4de2ebf --- /dev/null +++ b/relay/frontend/static/domain_ban.js @@ -0,0 +1,123 @@ +function create_ban_object(domain, reason, note) { + var text = '
\n'; + text += `${domain}\n`; + text += '
\n'; + text += `\n`; + text += `\n`; + text += `\n`; + text += `\n`; + text += ``; + text += '
'; + + return text; +} + + +function add_row_listeners(row) { + row.querySelector(".update-ban").addEventListener("click", async (event) => { + await update_ban(row.id); + }); + + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await unban(row.id); + }); +} + + +async function ban() { + var table = document.querySelector("table"); + var elems = { + domain: document.getElementById("new-domain"), + reason: document.getElementById("new-reason"), + note: document.getElementById("new-note") + } + + var values = { + domain: elems.domain.value.trim(), + reason: elems.reason.value.trim(), + note: elems.note.value.trim() + } + + if (values.domain === "") { + toast("Domain is required"); + return; + } + + try { + var ban = await request("POST", "v1/domain_ban", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.querySelector("table"), ban.domain, { + domain: create_ban_object(ban.domain, ban.reason, ban.note), + date: get_date_string(ban.created), + remove: `
` + }); + + add_row_listeners(row); + + elems.domain.value = null; + elems.reason.value = null; + elems.note.value = null; + + document.querySelector("details.section").open = false; + toast("Banned domain", "message"); +} + + +async function update_ban(domain) { + var row = document.getElementById(domain); + + var elems = { + "reason": row.querySelector("textarea.reason"), + "note": row.querySelector("textarea.note") + } + + var values = { + "domain": domain, + "reason": elems.reason.value, + "note": elems.note.value + } + + try { + await request("PATCH", "v1/domain_ban", values) + + } catch (error) { + toast(error); + return; + } + + row.querySelector("details").open = false; + toast("Updated baned domain", "message"); +} + + +async function unban(domain) { + try { + await request("DELETE", "v1/domain_ban", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + toast("Unbanned domain", "message"); +} + + +document.querySelector("#new-ban").addEventListener("click", async (event) => { + await ban(); +}); + +for (var row of document.querySelector("fieldset.section table").rows) { + if (!row.querySelector(".update-ban")) { + continue; + } + + add_row_listeners(row); +} diff --git a/relay/frontend/static/instance.js b/relay/frontend/static/instance.js new file mode 100644 index 0000000..a07b647 --- /dev/null +++ b/relay/frontend/static/instance.js @@ -0,0 +1,145 @@ +function add_instance_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_instance(row.id); + }); +} + + +function add_request_listeners(row) { + row.querySelector(".approve a").addEventListener("click", async (event) => { + event.preventDefault(); + await req_response(row.id, true); + }); + + row.querySelector(".deny a").addEventListener("click", async (event) => { + event.preventDefault(); + await req_response(row.id, false); + }); +} + + +async function add_instance() { + var elems = { + actor: document.getElementById("new-actor"), + inbox: document.getElementById("new-inbox"), + followid: document.getElementById("new-followid"), + software: document.getElementById("new-software") + } + + var values = { + actor: elems.actor.value.trim(), + inbox: elems.inbox.value.trim(), + followid: elems.followid.value.trim(), + software: elems.software.value.trim() + } + + if (values.actor === "") { + toast("Actor is required"); + return; + } + + try { + var instance = await request("POST", "v1/instance", values); + + } catch (err) { + toast(err); + return + } + + row = append_table_row(document.getElementById("instances"), instance.domain, { + domain: `${instance.domain}`, + software: instance.software, + date: get_date_string(instance.created), + remove: `` + }); + + add_instance_listeners(row); + + elems.actor.value = null; + elems.inbox.value = null; + elems.followid.value = null; + elems.software.value = null; + + document.querySelector("details.section").open = false; + toast("Added instance", "message"); +} + + +async function del_instance(domain) { + try { + await request("DELETE", "v1/instance", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); +} + + +async function req_response(domain, accept) { + params = { + "domain": domain, + "accept": accept + } + + try { + await request("POST", "v1/request", params); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + + if (document.getElementById("requests").rows.length < 2) { + document.querySelector("fieldset.requests").remove() + } + + if (!accept) { + toast("Denied instance request", "message"); + return; + } + + instances = await request("GET", `v1/instance`, null); + instances.forEach((instance) => { + if (instance.domain === domain) { + row = append_table_row(document.getElementById("instances"), instance.domain, { + domain: `${instance.domain}`, + software: instance.software, + date: get_date_string(instance.created), + remove: `` + }); + + add_instance_listeners(row); + } + }); + + toast("Accepted instance request", "message"); +} + + +document.querySelector("#add-instance").addEventListener("click", async (event) => { + await add_instance(); +}) + +for (var row of document.querySelector("#instances").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_instance_listeners(row); +} + +if (document.querySelector("#requests")) { + for (var row of document.querySelector("#requests").rows) { + if (!row.querySelector(".approve a")) { + continue; + } + + add_request_listeners(row); + } +} diff --git a/relay/frontend/static/login.js b/relay/frontend/static/login.js new file mode 100644 index 0000000..9c68f17 --- /dev/null +++ b/relay/frontend/static/login.js @@ -0,0 +1,29 @@ +async function login(event) { + fields = { + username: document.querySelector("#username"), + password: document.querySelector("#password") + } + + values = { + username: fields.username.value.trim(), + password: fields.password.value.trim() + } + + if (values.username === "" | values.password === "") { + toast("Username and/or password field is blank"); + return; + } + + try { + await request("POST", "v1/token", values); + + } catch (error) { + toast(error); + return; + } + + document.location = "/"; +} + + +document.querySelector(".submit").addEventListener("click", login); diff --git a/relay/frontend/static/software_ban.js b/relay/frontend/static/software_ban.js new file mode 100644 index 0000000..663929a --- /dev/null +++ b/relay/frontend/static/software_ban.js @@ -0,0 +1,122 @@ +function create_ban_object(name, reason, note) { + var text = '
\n'; + text += `${name}\n`; + text += '
\n'; + text += `\n`; + text += `\n`; + text += `\n`; + text += `\n`; + text += ``; + text += '
'; + + return text; +} + + +function add_row_listeners(row) { + row.querySelector(".update-ban").addEventListener("click", async (event) => { + await update_ban(row.id); + }); + + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await unban(row.id); + }); +} + + +async function ban() { + var elems = { + name: document.getElementById("new-name"), + reason: document.getElementById("new-reason"), + note: document.getElementById("new-note") + } + + var values = { + name: elems.name.value.trim(), + reason: elems.reason.value, + note: elems.note.value + } + + if (values.name === "") { + toast("Domain is required"); + return; + } + + try { + var ban = await request("POST", "v1/software_ban", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.getElementById("bans"), ban.name, { + name: create_ban_object(ban.name, ban.reason, ban.note), + date: get_date_string(ban.created), + remove: `` + }); + + add_row_listeners(row); + + elems.name.value = null; + elems.reason.value = null; + elems.note.value = null; + + document.querySelector("details.section").open = false; + toast("Banned software", "message"); +} + + +async function update_ban(name) { + var row = document.getElementById(name); + + var elems = { + "reason": row.querySelector("textarea.reason"), + "note": row.querySelector("textarea.note") + } + + var values = { + "name": name, + "reason": elems.reason.value, + "note": elems.note.value + } + + try { + await request("PATCH", "v1/software_ban", values) + + } catch (error) { + toast(error); + return; + } + + row.querySelector("details").open = false; + toast("Updated software ban", "message"); +} + + +async function unban(name) { + try { + await request("DELETE", "v1/software_ban", {"name": name}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(name).remove(); + toast("Unbanned software", "message"); +} + + +document.querySelector("#new-ban").addEventListener("click", async (event) => { + await ban(); +}); + +for (var row of document.querySelector("#bans").rows) { + if (!row.querySelector(".update-ban")) { + continue; + } + + add_row_listeners(row); +} diff --git a/relay/frontend/style.css b/relay/frontend/static/style.css similarity index 74% rename from relay/frontend/style.css rename to relay/frontend/static/style.css index f2a6fe1..635aa55 100644 --- a/relay/frontend/style.css +++ b/relay/frontend/static/style.css @@ -23,11 +23,29 @@ details summary { cursor: pointer; } +fieldset { + margin-left: 0px; + margin-right: 0px; +} + +fieldset > *:nth-child(2) { + margin-top: 0px !important; +} + form input[type="submit"] { display: block; margin: 0 auto; } +legend { + background-color: var(--table-background); + padding: 5px; + border: 1px solid var(--border); + border-radius: 5px; + font-size: 10pt; + font-weight: bold; +} + p { line-height: 1em; margin: 0px; @@ -91,6 +109,17 @@ textarea { margin: 0px auto; } +#content .title { + font-size: 24px; + text-align: center; + font-weight: bold; + margin-bottom: 10px; +} + +#content .title:not(:first-child) { + margin-top: 10px; +} + #header { display: grid; grid-template-columns: 50px auto 50px; @@ -175,6 +204,37 @@ textarea { text-align: center; } +#notifications { + position: fixed; + top: 40px; + left: 50%; + transform: translateX(-50%); +} + +#notifications li { + position: relative; + overflow: hidden; + list-style: none; + border-radius: 5px; + padding: 5px;; + margin-bottom: var(--spacing); + animation: show_toast 0.3s ease forwards; + display: grid; + grid-template-columns: auto max-content; + grid-gap: 5px; + align-items: center; +} + +#notifications a { + font-size: 1.5em; + line-height: 1em; + text-decoration: none; +} + +#notifications li.hide { + animation: hide_toast 0.3s ease forwards; +} + #footer { display: grid; grid-template-columns: auto auto; @@ -193,15 +253,6 @@ textarea { align-items: center; } -#data-table td:first-child { - width: 100%; -} - -#data-table .date { - width: max-content; - text-align: right; -} - .button { background-color: var(--primary); border: 1px solid var(--primary); @@ -220,6 +271,15 @@ textarea { grid-template-columns: max-content auto; } +.data-table td:first-child { + width: 100%; +} + +.data-table .date { + width: max-content; + text-align: right; +} + .error, .message { text-align: center; } @@ -267,6 +327,44 @@ textarea { } +@keyframes show_toast { + 0% { + transform: translateX(100%); + } + + 40% { + transform: translateX(-5%); + } + + 80% { + transform: translateX(0%); + } + + 100% { + transform: translateX(-10px); + } +} + + +@keyframes hide_toast { + 0% { + transform: translateX(-10px); + } + + 40% { + transform: translateX(0%); + } + + 80% { + transform: translateX(-5%); + } + + 100% { + transform: translateX(calc(100% + 20px)); + } +} + + @media (max-width: 1026px) { body { margin: 0px; diff --git a/relay/frontend/static/user.js b/relay/frontend/static/user.js new file mode 100644 index 0000000..9c74359 --- /dev/null +++ b/relay/frontend/static/user.js @@ -0,0 +1,85 @@ +function add_row_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_user(row.id); + }); +} + + +async function add_user() { + var elems = { + username: document.getElementById("new-username"), + password: document.getElementById("new-password"), + password2: document.getElementById("new-password2"), + handle: document.getElementById("new-handle") + } + + var values = { + username: elems.username.value.trim(), + password: elems.password.value.trim(), + password2: elems.password2.value.trim(), + handle: elems.handle.value.trim() + } + + if (values.username === "" | values.password === "" | values.password2 === "") { + toast("Username, password, and password2 are required"); + return; + } + + if (values.password !== values.password2) { + toast("Passwords do not match"); + return; + } + + try { + var user = await request("POST", "v1/user", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.querySelector("fieldset.section table"), user.username, { + domain: user.username, + handle: user.handle ? self.handle : "n/a", + date: get_date_string(user.created), + remove: `` + }); + + add_row_listeners(row); + + elems.username.value = null; + elems.password.value = null; + elems.password2.value = null; + elems.handle.value = null; + + document.querySelector("details.section").open = false; + toast("Created user", "message"); +} + + +async function del_user(username) { + try { + await request("DELETE", "v1/user", {"username": username}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(username).remove(); + toast("Deleted user", "message"); +} + + +document.querySelector("#new-user").addEventListener("click", async (event) => { + await add_user(); +}); + +for (var row of document.querySelector("#users").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_row_listeners(row); +} diff --git a/relay/frontend/static/whitelist.js b/relay/frontend/static/whitelist.js new file mode 100644 index 0000000..70d4db1 --- /dev/null +++ b/relay/frontend/static/whitelist.js @@ -0,0 +1,64 @@ +function add_row_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_whitelist(row.id); + }); +} + + +async function add_whitelist() { + var domain_elem = document.getElementById("new-domain"); + var domain = domain_elem.value.trim(); + + if (domain === "") { + toast("Domain is required"); + return; + } + + try { + var item = await request("POST", "v1/whitelist", {"domain": domain}); + + } catch (err) { + toast(err); + return; + } + + var row = append_table_row(document.getElementById("whitelist"), item.domain, { + domain: item.domain, + date: get_date_string(item.created), + remove: `` + }); + + add_row_listeners(row); + + domain_elem.value = null; + document.querySelector("details.section").open = false; + toast("Added domain", "message"); +} + + +async function del_whitelist(domain) { + try { + await request("DELETE", "v1/whitelist", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + toast("Removed domain", "message"); +} + + +document.querySelector("#new-item").addEventListener("click", async (event) => { + await add_whitelist(); +}); + +for (var row of document.querySelector("fieldset.section table").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_row_listeners(row); +} diff --git a/relay/http_client.py b/relay/http_client.py index 7e7bbd9..04533c5 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -7,7 +7,7 @@ import typing from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from aputils.objects import Nodeinfo, WellKnownNodeinfo +from aputils import AlgorithmType, JsonBase, Nodeinfo, ObjectType, WellKnownNodeinfo from json.decoder import JSONDecodeError from urllib.parse import urlparse @@ -17,12 +17,13 @@ from .misc import MIMETYPES, Message, get_app if typing.TYPE_CHECKING: from aputils import Signer - from tinysql import Row + from bsql import Row from typing import Any from .application import Application from .cache import Cache +T = typing.TypeVar('T', bound = JsonBase) HEADERS = { 'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9', 'User-Agent': f'ActivityRelay/{__version__}' @@ -33,12 +34,12 @@ class HttpClient: def __init__(self, limit: int = 100, timeout: int = 10): self.limit = limit self.timeout = timeout - self._conn = None - self._session = None + self._conn: TCPConnector | None = None + self._session: ClientSession | None = None async def __aenter__(self) -> HttpClient: - await self.open() + self.open() return self @@ -61,7 +62,7 @@ class HttpClient: return self.app.signer - async def open(self) -> None: + def open(self) -> None: if self._session: return @@ -79,23 +80,19 @@ class HttpClient: async def close(self) -> None: - if not self._session: - return + if self._session: + await self._session.close() - await self._session.close() - await self._conn.close() + if self._conn: + await self._conn.close() self._conn = None self._session = None - async def get(self, # pylint: disable=too-many-branches - url: str, - sign_headers: bool = False, - loads: callable = json.loads, - force: bool = False) -> dict | None: - - await self.open() + async def _get(self, url: str, sign_headers: bool, force: bool) -> dict[str, Any] | None: + if not self._session: + raise RuntimeError('Client not open') try: url, _ = url.split('#', 1) @@ -105,10 +102,8 @@ class HttpClient: if not force: try: - item = self.cache.get('request', url) - - if not item.older_than(48): - return loads(item.value) + if not (item := self.cache.get('request', url)).older_than(48): + return json.loads(item.value) except KeyError: logging.verbose('No cached data for url: %s', url) @@ -116,38 +111,39 @@ class HttpClient: headers = {} if sign_headers: - self.signer.sign_headers('GET', url, algorithm = 'original') + headers = self.signer.sign_headers('GET', url, algorithm = AlgorithmType.HS2019) try: logging.debug('Fetching resource: %s', url) - async with self._session.get(url, headers=headers) as resp: - ## Not expecting a response with 202s, so just return + async with self._session.get(url, headers = headers) as resp: + # Not expecting a response with 202s, so just return if resp.status == 202: return None - data = await resp.read() + data = await resp.text() if resp.status != 200: logging.verbose('Received error when requesting %s: %i', url, resp.status) - logging.debug(await resp.read()) + logging.debug(data) return None - message = loads(data) - self.cache.set('request', url, data.decode('utf-8'), 'str') - logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4)) + self.cache.set('request', url, data, 'str') + logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4)) - return message + return json.loads(data) except JSONDecodeError: logging.verbose('Failed to parse JSON') return None - except ClientSSLError: + except ClientSSLError as e: logging.verbose('SSL error when connecting to %s', urlparse(url).netloc) + logging.warning(str(e)) - except (AsyncTimeoutError, ClientConnectionError): + except (AsyncTimeoutError, ClientConnectionError) as e: logging.verbose('Failed to connect to %s', urlparse(url).netloc) + logging.warning(str(e)) except Exception: traceback.print_exc() @@ -155,39 +151,74 @@ class HttpClient: return None - async def post(self, url: str, message: Message, instance: Row | None = None) -> None: - await self.open() + async def get(self, + url: str, + sign_headers: bool, + cls: type[T], + force: bool = False) -> T | None: - ## Using the old algo by default is probably a better idea right now - # pylint: disable=consider-ternary-expression + if not issubclass(cls, JsonBase): + raise TypeError('cls must be a sub-class of "aputils.JsonBase"') + + if (data := (await self._get(url, sign_headers, force))) is None: + return None + + return cls.parse(data) + + + async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None: + if not self._session: + raise RuntimeError('Client not open') + + # akkoma and pleroma do not support HS2019 and other software still needs to be tested if instance and instance['software'] in {'mastodon'}: - algorithm = 'hs2019' + algorithm = AlgorithmType.HS2019 else: - algorithm = 'original' - # pylint: enable=consider-ternary-expression + algorithm = AlgorithmType.RSASHA256 - headers = {'Content-Type': 'application/activity+json'} - headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm)) + body: bytes + message: Message + + if isinstance(data, bytes): + body = data + message = Message.parse(data) + + else: + body = data.to_json().encode("utf-8") + message = data + + mtype = message.type.value if isinstance(message.type, ObjectType) else message.type + headers = self.signer.sign_headers( + 'POST', + url, + body, + headers = {'Content-Type': 'application/activity+json'}, + algorithm = algorithm + ) try: - logging.verbose('Sending "%s" to %s', message.type, url) + logging.verbose('Sending "%s" to %s', mtype, url) - async with self._session.post(url, headers=headers, data=message.to_json()) as resp: + async with self._session.post(url, headers = headers, data = body) as resp: # Not expecting a response, so just return if resp.status in {200, 202}: - logging.verbose('Successfully sent "%s" to %s', message.type, url) + logging.verbose('Successfully sent "%s" to %s', mtype, url) return logging.verbose('Received error when pushing to %s: %i', url, resp.status) logging.debug(await resp.read()) + logging.debug("message: %s", body.decode("utf-8")) + logging.debug("headers: %s", json.dumps(headers, indent = 4)) return - except ClientSSLError: + except ClientSSLError as e: logging.warning('SSL error when pushing to %s', urlparse(url).netloc) + logging.warning(str(e)) - except (AsyncTimeoutError, ClientConnectionError): + except (AsyncTimeoutError, ClientConnectionError) as e: logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) + logging.warning(str(e)) # prevent workers from being brought down except Exception: @@ -198,10 +229,11 @@ class HttpClient: nodeinfo_url = None wk_nodeinfo = await self.get( f'https://{domain}/.well-known/nodeinfo', - loads = WellKnownNodeinfo.parse + False, + WellKnownNodeinfo ) - if not wk_nodeinfo: + if wk_nodeinfo is None: logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) return None @@ -212,14 +244,14 @@ class HttpClient: except KeyError: pass - if not nodeinfo_url: + if nodeinfo_url is None: logging.verbose('Failed to fetch nodeinfo url for %s', domain) return None - return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None + return await self.get(nodeinfo_url, False, Nodeinfo) -async def get(*args: Any, **kwargs: Any) -> Message | dict | None: +async def get(*args: Any, **kwargs: Any) -> Any: async with HttpClient() as client: return await client.get(*args, **kwargs) diff --git a/relay/logger.py b/relay/logger.py index 8aff62d..916fa71 100644 --- a/relay/logger.py +++ b/relay/logger.py @@ -11,6 +11,12 @@ if typing.TYPE_CHECKING: from collections.abc import Callable from typing import Any + try: + from typing import Self + + except ImportError: + from typing_extensions import Self + class LogLevel(IntEnum): DEBUG = logging.DEBUG @@ -26,7 +32,13 @@ class LogLevel(IntEnum): @classmethod - def parse(cls: type[IntEnum], data: object) -> IntEnum: + def parse(cls: type[Self], data: Any) -> Self: + try: + data = int(data) + + except ValueError: + pass + if isinstance(data, cls): return data @@ -57,10 +69,10 @@ def set_level(level: LogLevel | str) -> None: def verbose(message: str, *args: Any, **kwargs: Any) -> None: - if not logging.root.isEnabledFor(LogLevel['VERBOSE']): + if not logging.root.isEnabledFor(LogLevel.VERBOSE): return - logging.log(LogLevel['VERBOSE'], message, *args, **kwargs) + logging.log(LogLevel.VERBOSE, message, *args, **kwargs) debug: Callable = logging.debug @@ -70,23 +82,27 @@ error: Callable = logging.error critical: Callable = logging.critical -env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper() - try: - env_log_file = Path(os.environ['LOG_FILE']).expanduser().resolve() + env_log_file: Path | None = Path(os.environ['LOG_FILE']).expanduser().resolve() except KeyError: env_log_file = None -handlers = [logging.StreamHandler()] +handlers: list[Any] = [logging.StreamHandler()] if env_log_file: handlers.append(logging.FileHandler(env_log_file)) -logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE') +if os.environ.get('IS_SYSTEMD'): + logging_format = '%(levelname)s: %(message)s' + +else: + logging_format = '[%(asctime)s] %(levelname)s: %(message)s' + +logging.addLevelName(LogLevel.VERBOSE, 'VERBOSE') logging.basicConfig( level = LogLevel.INFO, - format = '[%(asctime)s] %(levelname)s: %(message)s', + format = logging_format, datefmt = '%Y-%m-%d %H:%M:%S', handlers = handlers ) diff --git a/relay/manage.py b/relay/manage.py index 796ec0b..d768284 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -21,19 +21,10 @@ from .database import RELAY_SOFTWARE, get_database from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message if typing.TYPE_CHECKING: - from tinysql import Row + from bsql import Row from typing import Any -# pylint: disable=unsubscriptable-object,unsupported-assignment-operation - - -CONFIG_IGNORE = ( - 'schema-version', - 'private-key' -) - - def check_alphanumeric(text: str) -> str: if not text.isalnum(): raise click.BadParameter('String not alphanumeric') @@ -50,7 +41,7 @@ def cli(ctx: click.Context, config: str | None) -> None: if not ctx.invoked_subcommand: if ctx.obj.config.domain.endswith('example.com'): - cli_setup.callback() + cli_setup.callback() # type: ignore else: click.echo( @@ -58,7 +49,7 @@ def cli(ctx: click.Context, config: str | None) -> None: 'future.' ) - cli_run.callback() + cli_run.callback() # type: ignore @cli.command('setup') @@ -184,7 +175,7 @@ def cli_setup(ctx: click.Context) -> None: conn.put_config(key, value) if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'): - cli_run.callback() + cli_run.callback() # type: ignore @cli.command('run') @@ -257,7 +248,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: conn.put_config('note', config['note']) conn.put_config('whitelist-enabled', config['whitelist_enabled']) - with click.progressbar( + with click.progressbar( # type: ignore database['relay-list'].values(), label = 'Inboxes'.ljust(15), width = 0 @@ -281,7 +272,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: software = inbox['software'] ) - with click.progressbar( + with click.progressbar( # type: ignore config['blocked_software'], label = 'Banned software'.ljust(15), width = 0 @@ -293,7 +284,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: reason = 'relay' if software in RELAY_SOFTWARE else None ) - with click.progressbar( + with click.progressbar( # type: ignore config['blocked_instances'], label = 'Banned domains'.ljust(15), width = 0 @@ -302,7 +293,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: for domain in banned_software: conn.put_domain_ban(domain) - with click.progressbar( + with click.progressbar( # type: ignore config['whitelist'], label = 'Whitelist'.ljust(15), width = 0 @@ -339,10 +330,17 @@ def cli_config_list(ctx: click.Context) -> None: click.echo('Relay Config:') with ctx.obj.database.session() as conn: - for key, value in conn.get_config_all().items(): - if key not in CONFIG_IGNORE: - key = f'{key}:'.ljust(20) - click.echo(f'- {key} {value}') + config = conn.get_config_all() + + for key, value in config.to_dict().items(): + if key in type(config).SYSTEM_KEYS(): + continue + + if key == 'log-level': + value = value.name + + key_str = f'{key}:'.ljust(20) + click.echo(f'- {key_str} {repr(value)}') @cli_config.command('set') @@ -477,7 +475,7 @@ def cli_inbox_list(ctx: click.Context) -> None: click.echo('Connected to the following instances or relays:') with ctx.obj.database.session() as conn: - for inbox in conn.execute('SELECT * FROM inboxes'): + for inbox in conn.get_inboxes(): click.echo(f'- {inbox["inbox"]}') @@ -520,7 +518,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None: def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: 'Unfollow an actor (Relay must be running)' - inbox_data: Row = None + inbox_data: Row | None = None with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): @@ -540,6 +538,11 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: actor = f'https://{actor}/actor' actor_data = asyncio.run(http.get(actor, sign_headers = True)) + + if not actor_data: + click.echo("Failed to fetch actor") + return + inbox = actor_data.shared_inbox message = Message.new_unfollow( host = ctx.obj.config.domain, @@ -618,6 +621,80 @@ def cli_inbox_remove(ctx: click.Context, inbox: str) -> None: click.echo(f'Removed inbox from the database: {inbox}') +@cli.group('request') +def cli_request() -> None: + 'Manage follow requests' + + +@cli_request.command('list') +@click.pass_context +def cli_request_list(ctx: click.Context) -> None: + 'List all current follow requests' + + click.echo('Follow requests:') + + with ctx.obj.database.session() as conn: + for instance in conn.get_requests(): + date = instance['created'].strftime('%Y-%m-%d') + click.echo(f'- [{date}] {instance["domain"]}') + + +@cli_request.command('accept') +@click.argument('domain') +@click.pass_context +def cli_request_accept(ctx: click.Context, domain: str) -> None: + 'Accept a follow request' + + try: + with ctx.obj.database.session() as conn: + instance = conn.put_request_response(domain, True) + + except KeyError: + click.echo('Request not found') + return + + message = Message.new_response( + host = ctx.obj.config.domain, + actor = instance['actor'], + followid = instance['followid'], + accept = True + ) + + asyncio.run(http.post(instance['inbox'], message, instance)) + + if instance['software'] != 'mastodon': + message = Message.new_follow( + host = ctx.obj.config.domain, + actor = instance['actor'] + ) + + asyncio.run(http.post(instance['inbox'], message, instance)) + + +@cli_request.command('deny') +@click.argument('domain') +@click.pass_context +def cli_request_deny(ctx: click.Context, domain: str) -> None: + 'Accept a follow request' + + try: + with ctx.obj.database.session() as conn: + instance = conn.put_request_response(domain, False) + + except KeyError: + click.echo('Request not found') + return + + response = Message.new_response( + host = ctx.obj.config.domain, + actor = instance['actor'], + followid = instance['followid'], + accept = False + ) + + asyncio.run(http.post(instance['inbox'], response, instance)) + + @cli.group('instance') def cli_instance() -> None: 'Manage instance bans' @@ -893,7 +970,6 @@ def cli_whitelist_import(ctx: click.Context) -> None: def main() -> None: - # pylint: disable=no-value-for-parameter cli(prog_name='relay') diff --git a/relay/misc.py b/relay/misc.py index 33e7a06..82b1fd2 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -8,27 +8,44 @@ import typing from aiohttp.web import Response as AiohttpResponse from datetime import datetime +from pathlib import Path from uuid import uuid4 try: from importlib.resources import files as pkgfiles except ImportError: - from importlib_resources import files as pkgfiles + from importlib_resources import files as pkgfiles # type: ignore if typing.TYPE_CHECKING: - from pathlib import Path from typing import Any from .application import Application + try: + from typing import Self + + except ImportError: + from typing_extensions import Self + + +T = typing.TypeVar('T') +ResponseType = typing.TypedDict('ResponseType', { + 'status': int, + 'headers': dict[str, typing.Any] | None, + 'content_type': str, + 'body': bytes | None, + 'text': str | None +}) IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING')) + MIMETYPES = { 'activity': 'application/activity+json', 'css': 'text/css', 'html': 'text/html', 'json': 'application/json', - 'text': 'text/plain' + 'text': 'text/plain', + 'webmanifest': 'application/manifest+json' } NODEINFO_NS = { @@ -92,7 +109,7 @@ def check_open_port(host: str, port: int) -> bool: def get_app() -> Application: - from .application import Application # pylint: disable=import-outside-toplevel + from .application import Application if not Application.DEFAULT: raise ValueError('No default application set') @@ -101,7 +118,7 @@ def get_app() -> Application: def get_resource(path: str) -> Path: - return pkgfiles('relay').joinpath(path) + return Path(str(pkgfiles('relay'))).joinpath(path) class JsonEncoder(json.JSONEncoder): @@ -114,18 +131,18 @@ class JsonEncoder(json.JSONEncoder): class Message(aputils.Message): @classmethod - def new_actor(cls: type[Message], # pylint: disable=arguments-differ + def new_actor(cls: type[Self], # type: ignore host: str, pubkey: str, - description: str | None = None) -> Message: + description: str | None = None, + approves: bool = False) -> Self: - return cls({ - '@context': 'https://www.w3.org/ns/activitystreams', + return cls.new(aputils.ObjectType.APPLICATION, { 'id': f'https://{host}/actor', - 'type': 'Application', 'preferredUsername': 'relay', 'name': 'ActivityRelay', 'summary': description or 'ActivityRelay bot', + 'manuallyApprovesFollowers': approves, 'followers': f'https://{host}/followers', 'following': f'https://{host}/following', 'inbox': f'https://{host}/inbox', @@ -142,11 +159,9 @@ class Message(aputils.Message): @classmethod - def new_announce(cls: type[Message], host: str, obj: str) -> Message: - return cls({ - '@context': 'https://www.w3.org/ns/activitystreams', + def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self: + return cls.new(aputils.ObjectType.ANNOUNCE, { 'id': f'https://{host}/activities/{uuid4()}', - 'type': 'Announce', 'to': [f'https://{host}/followers'], 'actor': f'https://{host}/actor', 'object': obj @@ -154,23 +169,19 @@ class Message(aputils.Message): @classmethod - def new_follow(cls: type[Message], host: str, actor: str) -> Message: - return cls({ - '@context': 'https://www.w3.org/ns/activitystreams', - 'type': 'Follow', + def new_follow(cls: type[Self], host: str, actor: str) -> Self: + return cls.new(aputils.ObjectType.FOLLOW, { + 'id': f'https://{host}/activities/{uuid4()}', 'to': [actor], 'object': actor, - 'id': f'https://{host}/activities/{uuid4()}', 'actor': f'https://{host}/actor' }) @classmethod - def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message: - return cls({ - '@context': 'https://www.w3.org/ns/activitystreams', + def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self: + return cls.new(aputils.ObjectType.UNDO, { 'id': f'https://{host}/activities/{uuid4()}', - 'type': 'Undo', 'to': [actor], 'actor': f'https://{host}/actor', 'object': follow @@ -178,16 +189,9 @@ class Message(aputils.Message): @classmethod - def new_response(cls: type[Message], - host: str, - actor: str, - followid: str, - accept: bool) -> Message: - - return cls({ - '@context': 'https://www.w3.org/ns/activitystreams', + def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self: + return cls.new(aputils.ObjectType.ACCEPT if accept else aputils.ObjectType.REJECT, { 'id': f'https://{host}/activities/{uuid4()}', - 'type': 'Accept' if accept else 'Reject', 'to': [actor], 'actor': f'https://{host}/actor', 'object': { @@ -206,16 +210,18 @@ class Response(AiohttpResponse): @classmethod - def new(cls: type[Response], - body: str | bytes | dict = '', + def new(cls: type[Self], + body: str | bytes | dict | tuple | list | set = '', status: int = 200, headers: dict[str, str] | None = None, - ctype: str = 'text') -> Response: + ctype: str = 'text') -> Self: - kwargs = { + kwargs: ResponseType = { 'status': status, 'headers': headers, - 'content_type': MIMETYPES[ctype] + 'content_type': MIMETYPES[ctype], + 'body': None, + 'text': None } if isinstance(body, bytes): @@ -231,10 +237,10 @@ class Response(AiohttpResponse): @classmethod - def new_error(cls: type[Response], + def new_error(cls: type[Self], status: int, body: str | bytes | dict, - ctype: str = 'text') -> Response: + ctype: str = 'text') -> Self: if ctype == 'json': body = {'error': body} @@ -243,14 +249,14 @@ class Response(AiohttpResponse): @classmethod - def new_redir(cls: type[Response], path: str) -> Response: + def new_redir(cls: type[Self], path: str) -> Self: body = f'Redirect to {path}' return cls.new(body, 302, {'Location': path}) @property def location(self) -> str: - return self.headers.get('Location') + return self.headers.get('Location', '') @location.setter diff --git a/relay/processors.py b/relay/processors.py index 824a975..910ecf3 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -7,10 +7,10 @@ from .database import Connection from .misc import Message if typing.TYPE_CHECKING: - from .views import ActorView + from .views.activitypub import ActorView -def person_check(actor: str, software: str) -> bool: +def person_check(actor: Message, software: str | None) -> bool: # pleroma and akkoma may use Person for the actor type for some reason # akkoma changed this in 3.6.0 if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': @@ -35,8 +35,8 @@ async def handle_relay(view: ActorView, conn: Connection) -> None: message = Message.new_announce(view.config.domain, view.message.object_id) logging.debug('>> relay: %s', message) - for inbox in conn.distill_inboxes(view.message): - view.app.push_message(inbox, message, view.instance) + for instance in conn.distill_inboxes(view.message): + view.app.push_message(instance["inbox"], message, instance) view.cache.set('handle-relay', view.message.object_id, message.id, 'str') @@ -53,8 +53,8 @@ async def handle_forward(view: ActorView, conn: Connection) -> None: message = Message.new_announce(view.config.domain, view.message) logging.debug('>> forward: %s', message) - for inbox in conn.distill_inboxes(view.message): - view.app.push_message(inbox, message, view.instance) + for instance in conn.distill_inboxes(view.message): + view.app.push_message(instance["inbox"], await view.request.read(), instance) view.cache.set('handle-relay', view.message.id, message.id, 'str') @@ -62,9 +62,12 @@ async def handle_forward(view: ActorView, conn: Connection) -> None: async def handle_follow(view: ActorView, conn: Connection) -> None: nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) software = nodeinfo.sw_name if nodeinfo else None + config = conn.get_config_all() # reject if software used by actor is banned - if conn.get_software_ban(software): + if software and conn.get_software_ban(software): + logging.verbose('Rejected banned actor: %s', view.actor.id) + view.app.push_message( view.actor.shared_inbox, Message.new_response( @@ -72,7 +75,8 @@ async def handle_follow(view: ActorView, conn: Connection) -> None: actor = view.actor.id, followid = view.message.id, accept = False - ) + ), + view.instance ) logging.verbose( @@ -83,8 +87,10 @@ async def handle_follow(view: ActorView, conn: Connection) -> None: return - ## reject if the actor is not an instance actor + # reject if the actor is not an instance actor if person_check(view.actor, software): + logging.verbose('Non-application actor tried to follow: %s', view.actor.id) + view.app.push_message( view.actor.shared_inbox, Message.new_response( @@ -92,23 +98,54 @@ async def handle_follow(view: ActorView, conn: Connection) -> None: actor = view.actor.id, followid = view.message.id, accept = False - ) + ), + view.instance ) - logging.verbose('Non-application actor tried to follow: %s', view.actor.id) return - with conn.transaction(): - if conn.get_inbox(view.actor.shared_inbox): - view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id) + if not conn.get_domain_whitelist(view.actor.domain): + # add request if approval-required is enabled + if config.approval_required: + logging.verbose('New follow request fromm actor: %s', view.actor.id) - else: - view.instance = conn.put_inbox( - view.actor.domain, + with conn.transaction(): + view.instance = conn.put_inbox( + domain = view.actor.domain, + inbox = view.actor.shared_inbox, + actor = view.actor.id, + followid = view.message.id, + software = software, + accepted = False + ) + + return + + # reject if the actor isn't whitelisted while the whiltelist is enabled + if config.whitelist_enabled: + logging.verbose('Rejected actor for not being in the whitelist: %s', view.actor.id) + + view.app.push_message( view.actor.shared_inbox, - view.actor.id, - view.message.id, - software + Message.new_response( + host = view.config.domain, + actor = view.actor.id, + followid = view.message.id, + accept = False + ), + view.instance + ) + + return + + with conn.transaction(): + view.instance = conn.put_inbox( + domain = view.actor.domain, + inbox = view.actor.shared_inbox, + actor = view.actor.id, + followid = view.message.id, + software = software, + accepted = True ) view.app.push_message( @@ -136,7 +173,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None: async def handle_undo(view: ActorView, conn: Connection) -> None: - ## If the object is not a Follow, forward it + # If the object is not a Follow, forward it if view.message.object['type'] != 'Follow': await handle_forward(view, conn) return @@ -150,7 +187,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None: logging.verbose( 'Failed to delete "%s" with follow ID "%s"', view.actor.id, - view.message.object['id'] + view.message.object_id ) view.app.push_message( @@ -189,15 +226,15 @@ async def run_processor(view: ActorView) -> None: if not view.instance['software']: if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): with conn.transaction(): - view.instance = conn.update_inbox( - view.instance['inbox'], + view.instance = conn.put_inbox( + domain = view.instance['domain'], software = nodeinfo.sw_name ) if not view.instance['actor']: with conn.transaction(): - view.instance = conn.update_inbox( - view.instance['inbox'], + view.instance = conn.put_inbox( + domain = view.instance['domain'], actor = view.actor.id ) diff --git a/relay/template.py b/relay/template.py index 64738e0..1335fab 100644 --- a/relay/template.py +++ b/relay/template.py @@ -1,15 +1,22 @@ from __future__ import annotations +import textwrap import typing -from hamlish_jinja.extension import HamlishExtension +from collections.abc import Callable +from hamlish_jinja import HamlishExtension from jinja2 import Environment, FileSystemLoader +from jinja2.ext import Extension +from jinja2.nodes import CallBlock +from markdown import Markdown + from . import __version__ -from .database.config import THEMES from .misc import get_resource if typing.TYPE_CHECKING: + from jinja2.nodes import Node + from jinja2.parser import Parser from typing import Any from .application import Application from .views.base import View @@ -22,7 +29,8 @@ class Template(Environment): trim_blocks = True, lstrip_blocks = True, extensions = [ - HamlishExtension + HamlishExtension, + MarkdownExtension ], loader = FileSystemLoader([ get_resource('frontend'), @@ -36,16 +44,52 @@ class Template(Environment): def render(self, path: str, view: View | None = None, **context: Any) -> str: - with self.app.database.session(False) as s: - config = s.get_config_all() + with self.app.database.session(False) as conn: + config = conn.get_config_all() new_context = { 'view': view, 'domain': self.app.config.domain, 'version': __version__, 'config': config, - 'theme_name': config['theme'] or 'Default', **(context or {}) } return self.get_template(path).render(new_context) + + + def render_markdown(self, text: str) -> str: + return self._render_markdown(text) # type: ignore + + +class MarkdownExtension(Extension): + tags = {'markdown'} + extensions = ( + 'attr_list', + 'smarty', + 'tables' + ) + + + def __init__(self, environment: Environment): + Extension.__init__(self, environment) + self._markdown = Markdown(extensions = MarkdownExtension.extensions) + environment.extend( + _render_markdown = self._render_markdown + ) + + + def parse(self, parser: Parser) -> Node | list[Node]: + lineno = next(parser.stream).lineno + body = parser.parse_statements( + ('name:endmarkdown',), + drop_needle = True + ) + + output = CallBlock(self.call_method('_render_markdown'), [], [], body) + return output.set_lineno(lineno) + + + def _render_markdown(self, caller: Callable[[], str] | str) -> str: + text = caller if isinstance(caller, str) else caller() + return self._markdown.convert(textwrap.dedent(text.strip('\n'))) diff --git a/relay/views/__init__.py b/relay/views/__init__.py index 6366592..25a7a62 100644 --- a/relay/views/__init__.py +++ b/relay/views/__init__.py @@ -1,4 +1,4 @@ from __future__ import annotations from . import activitypub, api, frontend, misc -from .base import VIEWS +from .base import VIEWS, View diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index 31266f6..f2eff48 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -12,27 +12,31 @@ from ..processors import run_processor if typing.TYPE_CHECKING: from aiohttp.web import Request - from tinysql import Row + from bsql import Row -# pylint: disable=unused-argument - @register_route('/actor', '/inbox') class ActorView(View): + signature: aputils.Signature + message: Message + actor: Message + instancce: Row + signer: aputils.Signer + + def __init__(self, request: Request): View.__init__(self, request) - self.signature: aputils.Signature = None - self.message: Message = None - self.actor: Message = None - self.instance: Row = None - self.signer: aputils.Signer = None - async def get(self, request: Request) -> Response: + with self.database.session(False) as conn: + config = conn.get_config_all() + data = Message.new_actor( host = self.config.domain, - pubkey = self.app.signer.pubkey + pubkey = self.app.signer.pubkey, + description = self.app.template.render_markdown(config.note), + approves = config.approval_required ) return Response.new(data, ctype='activity') @@ -44,19 +48,13 @@ class ActorView(View): with self.database.session() as conn: self.instance = conn.get_inbox(self.actor.shared_inbox) - config = conn.get_config_all() - ## reject if the actor isn't whitelisted while the whiltelist is enabled - if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): - logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') - - ## reject if actor is banned + # reject if actor is banned if conn.get_domain_ban(self.actor.domain): logging.verbose('Ignored request from banned actor: %s', self.actor.id) return Response.new_error(403, 'access denied', 'json') - ## reject if activity type isn't 'Follow' and the actor isn't following + # reject if activity type isn't 'Follow' and the actor isn't following if self.message.type != 'Follow' and not self.instance: logging.verbose( 'Rejected actor for trying to post while not following: %s', @@ -73,35 +71,33 @@ class ActorView(View): async def get_post_data(self) -> Response | None: try: - self.signature = aputils.Signature.new_from_signature(self.request.headers['signature']) + self.signature = aputils.Signature.parse(self.request.headers['signature']) except KeyError: logging.verbose('Missing signature header') return Response.new_error(400, 'missing signature header', 'json') try: - self.message = await self.request.json(loads = Message.parse) + message: Message | None = await self.request.json(loads = Message.parse) except Exception: traceback.print_exc() logging.verbose('Failed to parse inbox message') return Response.new_error(400, 'failed to parse message', 'json') - if self.message is None: + if message is None: logging.verbose('empty message') return Response.new_error(400, 'missing message', 'json') + self.message = message + if 'actor' not in self.message: logging.verbose('actor not in message') return Response.new_error(400, 'no actor in message', 'json') - self.actor = await self.client.get( - self.signature.keyid, - sign_headers = True, - loads = Message.parse - ) + actor: Message | None = await self.client.get(self.signature.keyid, True, Message) - if not self.actor: + if actor is None: # ld signatures aren't handled atm, so just ignore it if self.message.type == 'Delete': logging.verbose('Instance sent a delete which cannot be handled') @@ -110,6 +106,8 @@ class ActorView(View): logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') return Response.new_error(400, 'failed to fetch actor', 'json') + self.actor = actor + try: self.signer = self.actor.signer @@ -118,42 +116,13 @@ class ActorView(View): return Response.new_error(400, 'actor missing public key', 'json') try: - self.validate_signature(await self.request.read()) + await self.signer.validate_request_async(self.request) except aputils.SignatureFailureError as e: logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) return Response.new_error(401, str(e), 'json') - - def validate_signature(self, body: bytes) -> None: - headers = {key.lower(): value for key, value in self.request.headers.items()} - headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path]) - - if (digest := aputils.Digest.new_from_digest(headers.get("digest"))): - if not body: - raise aputils.SignatureFailureError("Missing body for digest verification") - - if not digest.validate(body): - raise aputils.SignatureFailureError("Body digest does not match") - - if self.signature.algorithm_type == "hs2019": - if "(created)" not in self.signature.headers: - raise aputils.SignatureFailureError("'(created)' header not used") - - current_timestamp = aputils.HttpDate.new_utc().timestamp() - - if self.signature.created > current_timestamp: - raise aputils.SignatureFailureError("Creation date after current date") - - if current_timestamp > self.signature.expires: - raise aputils.SignatureFailureError("Expiration date before current date") - - headers["(created)"] = self.signature.created - headers["(expires)"] = self.signature.expires - - # pylint: disable=protected-access - if not self.signer._validate_signature(headers, self.signature): - raise aputils.SignatureFailureError("Signature does not match") + return None @register_route('/.well-known/webfinger') diff --git a/relay/views/api.py b/relay/views/api.py index 5a32cac..04b9af8 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -9,23 +9,22 @@ from urllib.parse import urlparse from .base import View, register_route from .. import __version__ -from .. import logger as logging -from ..database.config import CONFIG_DEFAULTS -from ..misc import Message, Response +from ..database import ConfigData +from ..misc import Message, Response, boolean, get_app if typing.TYPE_CHECKING: from aiohttp.web import Request - from collections.abc import Coroutine + from collections.abc import Callable, Sequence + from typing import Any -CONFIG_IGNORE = ( - 'schema-version', - 'private-key' -) +ALLOWED_HEADERS = { + 'accept', + 'authorization', + 'content-type' +} -CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE} - -PUBLIC_API_PATHS: tuple[tuple[str, str]] = ( +PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( ('GET', '/api/v1/relay'), ('GET', '/api/v1/instance'), ('POST', '/api/v1/token') @@ -40,28 +39,36 @@ def check_api_path(method: str, path: str) -> bool: @web.middleware -async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Response: +async def handle_api_path(request: Request, handler: Callable) -> Response: try: - request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() + if (token := request.cookies.get('user-token')): + request['token'] = token - with request.app.database.session() as conn: + else: + request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() + + with get_app().database.session() as conn: request['user'] = conn.get_user_by_token(request['token']) except (KeyError, ValueError): request['token'] = None request['user'] = None - if check_api_path(request.method, request.path): + if request.method != "OPTIONS" and check_api_path(request.method, request.path): if not request['token']: return Response.new_error(401, 'Missing token', 'json') if not request['user']: return Response.new_error(401, 'Invalid token', 'json') - return await handler(request) + response = await handler(request) + if request.path.startswith('/api'): + response.headers['Access-Control-Allow-Origin'] = '*' + response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) + + return response -# pylint: disable=no-self-use,unused-argument @register_route('/api/v1/token') class Login(View): @@ -87,7 +94,19 @@ class Login(View): token = conn.put_token(data['username']) - return Response.new({'token': token['code']}, ctype = 'json') + resp = Response.new({'token': token['code']}, ctype = 'json') + resp.set_cookie( + 'user-token', + token['code'], + max_age = 60 * 60 * 24 * 365, + domain = self.config.domain, + path = '/', + secure = True, + httponly = False, + samesite = 'lax' + ) + + return resp async def delete(self, request: Request) -> Response: @@ -102,14 +121,14 @@ class RelayInfo(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: config = conn.get_config_all() - inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] + inboxes = [row['domain'] for row in conn.get_inboxes()] data = { 'domain': self.config.domain, - 'name': config['name'], - 'description': config['note'], + 'name': config.name, + 'description': config.note, 'version': __version__, - 'whitelist_enabled': config['whitelist-enabled'], + 'whitelist_enabled': config.whitelist_enabled, 'email': None, 'admin': None, 'icon': None, @@ -122,12 +141,17 @@ class RelayInfo(View): @register_route('/api/v1/config') class Config(View): async def get(self, request: Request) -> Response: - with self.database.session() as conn: - data = conn.get_config_all() - data['log-level'] = data['log-level'].name + data = {} - for key in CONFIG_IGNORE: - del data[key] + with self.database.session() as conn: + for key, value in conn.get_config_all().to_dict().items(): + if key in ConfigData.SYSTEM_KEYS(): + continue + + if key == 'log-level': + value = value.name + + data[key] = value return Response.new(data, ctype = 'json') @@ -138,7 +162,9 @@ class Config(View): if isinstance(data, Response): return data - if data['key'] not in CONFIG_VALID: + data['key'] = data['key'].replace('-', '_') + + if data['key'] not in ConfigData.USER_KEYS(): return Response.new_error(400, 'Invalid key', 'json') with self.database.session() as conn: @@ -153,11 +179,11 @@ class Config(View): if isinstance(data, Response): return data - if data['key'] not in CONFIG_VALID: + if data['key'] not in ConfigData.USER_KEYS(): return Response.new_error(400, 'Invalid key', 'json') with self.database.session() as conn: - conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) + conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) return Response.new({'message': 'Updated config'}, ctype = 'json') @@ -166,7 +192,7 @@ class Config(View): class Inbox(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - data = tuple(conn.execute('SELECT * FROM inboxes').all()) + data = conn.get_inboxes() return Response.new(data, ctype = 'json') @@ -184,19 +210,19 @@ class Inbox(View): return Response.new_error(404, 'Instance already in database', 'json') if not data.get('inbox'): - try: - actor_data = await self.client.get( - data['actor'], - sign_headers = True, - loads = Message.parse - ) + actor_data: Message | None = await self.client.get(data['actor'], True, Message) - data['inbox'] = actor_data.shared_inbox - - except Exception as e: - logging.error('Failed to fetch actor: %s', str(e)) + if actor_data is None: return Response.new_error(500, 'Failed to fetch actor', 'json') + data['inbox'] = actor_data.shared_inbox + + if not data.get('software'): + nodeinfo = await self.client.fetch_nodeinfo(data['domain']) + + if nodeinfo is not None: + data['software'] = nodeinfo.sw_name + row = conn.put_inbox(**data) return Response.new(row, ctype = 'json') @@ -212,12 +238,12 @@ class Inbox(View): if not (instance := conn.get_inbox(data['domain'])): return Response.new_error(404, 'Instance with domain not found', 'json') - instance = conn.update_inbox(instance['inbox'], **data) + instance = conn.put_inbox(instance['domain'], **data) return Response.new(instance, ctype = 'json') - async def delete(self, request: Request, domain: str) -> Response: + async def delete(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], []) @@ -232,6 +258,47 @@ class Inbox(View): return Response.new({'message': 'Deleted instance'}, ctype = 'json') +@register_route('/api/v1/request') +class RequestView(View): + async def get(self, request: Request) -> Response: + with self.database.session() as conn: + instances = conn.get_requests() + + return Response.new(instances, ctype = 'json') + + + async def post(self, request: Request) -> Response: + data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], []) + data['accept'] = boolean(data['accept']) + + try: + with self.database.session(True) as conn: + instance = conn.put_request_response(data['domain'], data['accept']) + + except KeyError: + return Response.new_error(404, 'Request not found', 'json') + + message = Message.new_response( + host = self.config.domain, + actor = instance['actor'], + followid = instance['followid'], + accept = data['accept'] + ) + + self.app.push_message(instance['inbox'], message, instance) + + if data['accept'] and instance['software'] != 'mastodon': + message = Message.new_follow( + host = self.config.domain, + actor = instance['actor'] + ) + + self.app.push_message(instance['inbox'], message, instance) + + resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'} + return Response.new(resp_message, ctype = 'json') + + @register_route('/api/v1/domain_ban') class DomainBan(View): async def get(self, request: Request) -> Response: @@ -269,7 +336,7 @@ class DomainBan(View): if not any([data.get('note'), data.get('reason')]): return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - ban = conn.update_domain_ban(data['domain'], **data) + ban = conn.update_domain_ban(**data) return Response.new(ban, ctype = 'json') @@ -326,7 +393,7 @@ class SoftwareBan(View): if not any([data.get('note'), data.get('reason')]): return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - ban = conn.update_software_ban(data['name'], **data) + ban = conn.update_software_ban(**data) return Response.new(ban, ctype = 'json') @@ -346,6 +413,63 @@ class SoftwareBan(View): return Response.new({'message': 'Unbanned software'}, ctype = 'json') +@register_route('/api/v1/user') +class User(View): + async def get(self, request: Request) -> Response: + with self.database.session() as conn: + items = [] + + for row in conn.execute('SELECT * FROM users'): + del row['hash'] + items.append(row) + + return Response.new(items, ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['username', 'password'], ['handle']) + + if isinstance(data, Response): + return data + + with self.database.session() as conn: + if conn.get_user(data['username']): + return Response.new_error(404, 'User already exists', 'json') + + user = conn.put_user(**data) + del user['hash'] + + return Response.new(user, ctype = 'json') + + + async def patch(self, request: Request) -> Response: + data = await self.get_api_data(['username'], ['password', 'handle']) + + if isinstance(data, Response): + return data + + with self.database.session(True) as conn: + user = conn.put_user(**data) + del user['hash'] + + return Response.new(user, ctype = 'json') + + + async def delete(self, request: Request) -> Response: + data = await self.get_api_data(['username'], []) + + if isinstance(data, Response): + return data + + with self.database.session(True) as conn: + if not conn.get_user(data['username']): + return Response.new_error(404, 'User does not exist', 'json') + + conn.del_user(data['username']) + + return Response.new({'message': 'Deleted user'}, ctype = 'json') + + @register_route('/api/v1/whitelist') class Whitelist(View): async def get(self, request: Request) -> Response: diff --git a/relay/views/base.py b/relay/views/base.py index f568525..93b3e3b 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -2,40 +2,52 @@ from __future__ import annotations import typing +from Crypto.Random import get_random_bytes from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.web import HTTPMethodNotAllowed +from base64 import b64encode from functools import cached_property from json.decoder import JSONDecodeError -from ..misc import Response +from ..misc import Response, get_app if typing.TYPE_CHECKING: from aiohttp.web import Request - from collections.abc import Callable, Coroutine, Generator + from collections.abc import Callable, Generator, Sequence, Mapping from bsql import Database - from typing import Any, Self + from typing import Any from ..application import Application from ..cache import Cache from ..config import Config from ..http_client import HttpClient from ..template import Template + try: + from typing import Self -VIEWS = [] + except ImportError: + from typing_extensions import Self + + +VIEWS: list[tuple[str, type[View]]] = [] + + +def convert_data(data: Mapping[str, Any]) -> dict[str, str]: + return {key: str(value) for key, value in data.items()} def register_route(*paths: str) -> Callable: - def wrapper(view: View) -> View: + def wrapper(view: type[View]) -> type[View]: for path in paths: - VIEWS.append([path, view]) + VIEWS.append((path, view)) return view return wrapper class View(AbstractView): - def __await__(self) -> Generator[Response]: + def __await__(self) -> Generator[Any, None, Response]: if self.request.method not in METHODS: raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) @@ -46,22 +58,27 @@ class View(AbstractView): @classmethod - async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Self: + async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Response: view = cls(request) return await view.handlers[method](request, **kwargs) - async def _run_handler(self, handler: Coroutine, **kwargs: Any) -> Response: + async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response: + self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') return await handler(self.request, **self.request.match_info, **kwargs) + async def options(self, request: Request) -> Response: + return Response.new() + + @cached_property - def allowed_methods(self) -> tuple[str]: + def allowed_methods(self) -> Sequence[str]: return tuple(self.handlers.keys()) @cached_property - def handlers(self) -> dict[str, Coroutine]: + def handlers(self) -> dict[str, Callable[..., Any]]: data = {} for method in METHODS: @@ -74,10 +91,9 @@ class View(AbstractView): return data - # app components @property def app(self) -> Application: - return self.request.app + return get_app() @property @@ -110,17 +126,17 @@ class View(AbstractView): optional: list[str]) -> dict[str, str] | Response: if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: - post_data = await self.request.post() + post_data = convert_data(await self.request.post()) elif self.request.content_type == 'application/json': try: - post_data = await self.request.json() + post_data = convert_data(await self.request.json()) except JSONDecodeError: return Response.new_error(400, 'Invalid JSON data', 'json') else: - post_data = self.request.query + post_data = convert_data(self.request.query) data = {} @@ -132,6 +148,6 @@ class View(AbstractView): return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') for key in optional: - data[key] = post_data.get(key) + data[key] = post_data.get(key, '') return data diff --git a/relay/views/frontend.py b/relay/views/frontend.py index bd63417..2b5bec0 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -3,60 +3,59 @@ from __future__ import annotations import typing from aiohttp import web -from argon2.exceptions import VerifyMismatchError -from urllib.parse import urlparse from .base import View, register_route -from ..database import CONFIG_DEFAULTS, THEMES +from ..database import THEMES from ..logger import LogLevel -from ..misc import ACTOR_FORMATS, Message, Response +from ..misc import Response, get_app if typing.TYPE_CHECKING: from aiohttp.web import Request - from collections.abc import Coroutine + from collections.abc import Callable + from typing import Any -# pylint: disable=no-self-use - UNAUTH_ROUTES = { '/', '/login' } -CONFIG_IGNORE = ( - 'schema-version', - 'private-key' -) - @web.middleware -async def handle_frontend_path(request: web.Request, handler: Coroutine) -> Response: +async def handle_frontend_path(request: web.Request, handler: Callable) -> Response: + app = get_app() + if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): request['token'] = request.cookies.get('user-token') request['user'] = None if request['token']: - with request.app.database.session(False) as conn: + with app.database.session(False) as conn: request['user'] = conn.get_user_by_token(request['token']) if request['user'] and request.path == '/login': return Response.new('', 302, {'Location': '/'}) if not request['user'] and request.path.startswith('/admin'): - return Response.new('', 302, {'Location': f'/login?redir={request.path}'}) + response = Response.new('', 302, {'Location': f'/login?redir={request.path}'}) + response.del_cookie('user-token') + return response - return await handler(request) + response = await handler(request) + if not request.path.startswith('/api') and not request['user'] and request['token']: + response.del_cookie('user-token') + + return response -# pylint: disable=unused-argument @register_route('/') class HomeView(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - context = { - 'instances': tuple(conn.execute('SELECT * FROM inboxes').all()) + context: dict[str, Any] = { + 'instances': tuple(conn.get_inboxes()) } data = self.template.render('page/home.haml', self, **context) @@ -70,47 +69,6 @@ class Login(View): return Response.new(data, ctype = 'html') - async def post(self, request: Request) -> Response: - form = await request.post() - params = {} - - with self.database.session(True) as conn: - if not (user := conn.get_user(form['username'])): - params = { - 'username': form['username'], - 'error': 'User not found' - } - - else: - try: - conn.hasher.verify(user['hash'], form['password']) - - except VerifyMismatchError: - params = { - 'username': form['username'], - 'error': 'Invalid password' - } - - if params: - data = self.template.render('page/login.haml', self, **params) - return Response.new(data, ctype = 'html') - - token = conn.put_token(user['username']) - resp = Response.new_redir(request.query.getone('redir', '/')) - resp.set_cookie( - 'user-token', - token['code'], - max_age = 60 * 60 * 24 * 365, - domain = self.config.domain, - path = '/', - secure = True, - httponly = True, - samesite = 'Strict' - ) - - return resp - - @register_route('/logout') class Logout(View): async def get(self, request: Request) -> Response: @@ -136,8 +94,9 @@ class AdminInstances(View): message: str | None = None) -> Response: with self.database.session() as conn: - context = { - 'instances': tuple(conn.execute('SELECT * FROM inboxes').all()) + context: dict[str, Any] = { + 'instances': tuple(conn.get_inboxes()), + 'requests': tuple(conn.get_requests()) } if error: @@ -150,44 +109,6 @@ class AdminInstances(View): return Response.new(data, ctype = 'html') - async def post(self, request: Request) -> Response: - data = await request.post() - - if not data.get('actor') and not data.get('domain'): - return await self.get(request, error = 'Missing actor and/or domain') - - if not data.get('domain'): - data['domain'] = urlparse(data['actor']).netloc - - if not data.get('software'): - nodeinfo = await self.client.fetch_nodeinfo(data['domain']) - data['software'] = nodeinfo.sw_name - - if not data.get('actor') and data['software'] in ACTOR_FORMATS: - data['actor'] = ACTOR_FORMATS[data['software']].format(domain = data['domain']) - - if not data.get('inbox') and data['actor']: - actor = await self.client.get(data['actor'], sign_headers = True, loads = Message.parse) - data['inbox'] = actor.shared_inbox - - with self.database.session(True) as conn: - conn.put_inbox(**data) - - return await self.get(request, message = "Added new inbox") - - -@register_route('/admin/instances/delete/{domain}') -class AdminInstancesDelete(View): - async def get(self, request: Request, domain: str) -> Response: - with self.database.session() as conn: - if not conn.get_inbox(domain): - return await AdminInstances(request).get(request, message = 'Instance not found') - - conn.del_inbox(domain) - - return await AdminInstances(request).get(request, message = 'Removed instance') - - @register_route('/admin/whitelist') class AdminWhitelist(View): async def get(self, @@ -196,8 +117,8 @@ class AdminWhitelist(View): message: str | None = None) -> Response: with self.database.session() as conn: - context = { - 'whitelist': tuple(conn.execute('SELECT * FROM whitelist').all()) + context: dict[str, Any] = { + 'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC')) } if error: @@ -210,34 +131,6 @@ class AdminWhitelist(View): return Response.new(data, ctype = 'html') - async def post(self, request: Request) -> Response: - data = await request.post() - - if not data['domain']: - return await self.get(request, error = 'Missing domain') - - with self.database.session(True) as conn: - if conn.get_domain_whitelist(data['domain']): - return await self.get(request, message = "Domain already in whitelist") - - conn.put_domain_whitelist(data['domain']) - - return await self.get(request, message = "Added/updated domain ban") - - -@register_route('/admin/whitelist/delete/{domain}') -class AdminWhitlistDelete(View): - async def get(self, request: Request, domain: str) -> Response: - with self.database.session() as conn: - if not conn.get_domain_whitelist(domain): - msg = 'Whitelisted domain not found' - return await AdminWhitelist.run("GET", request, message = msg) - - conn.del_domain_whitelist(domain) - - return await AdminWhitelist.run("GET", request, message = 'Removed domain from whitelist') - - @register_route('/admin/domain_bans') class AdminDomainBans(View): async def get(self, @@ -246,8 +139,8 @@ class AdminDomainBans(View): message: str | None = None) -> Response: with self.database.session() as conn: - context = { - 'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC').all()) + context: dict[str, Any] = { + 'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC')) } if error: @@ -260,42 +153,6 @@ class AdminDomainBans(View): return Response.new(data, ctype = 'html') - async def post(self, request: Request) -> Response: - data = await request.post() - - if not data['domain']: - return await self.get(request, error = 'Missing domain') - - with self.database.session(True) as conn: - if conn.get_domain_ban(data['domain']): - conn.update_domain_ban( - data['domain'], - data.get('reason'), - data.get('note') - ) - - else: - conn.put_domain_ban( - data['domain'], - data.get('reason'), - data.get('note') - ) - - return await self.get(request, message = "Added/updated domain ban") - - -@register_route('/admin/domain_bans/delete/{domain}') -class AdminDomainBansDelete(View): - async def get(self, request: Request, domain: str) -> Response: - with self.database.session() as conn: - if not conn.get_domain_ban(domain): - return await AdminDomainBans.run("GET", request, message = 'Domain ban not found') - - conn.del_domain_ban(domain) - - return await AdminDomainBans.run("GET", request, message = 'Unbanned domain') - - @register_route('/admin/software_bans') class AdminSoftwareBans(View): async def get(self, @@ -304,8 +161,8 @@ class AdminSoftwareBans(View): message: str | None = None) -> Response: with self.database.session() as conn: - context = { - 'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC').all()) + context: dict[str, Any] = { + 'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC')) } if error: @@ -318,42 +175,6 @@ class AdminSoftwareBans(View): return Response.new(data, ctype = 'html') - async def post(self, request: Request) -> Response: - data = await request.post() - - if not data['name']: - return await self.get(request, error = 'Missing name') - - with self.database.session(True) as conn: - if conn.get_software_ban(data['name']): - conn.update_software_ban( - data['name'], - data.get('reason'), - data.get('note') - ) - - else: - conn.put_software_ban( - data['name'], - data.get('reason'), - data.get('note') - ) - - return await self.get(request, message = "Added/updated software ban") - - -@register_route('/admin/software_bans/delete/{name}') -class AdminSoftwareBansDelete(View): - async def get(self, request: Request, name: str) -> Response: - with self.database.session() as conn: - if not conn.get_software_ban(name): - return await AdminSoftwareBans.run("GET", request, message = 'Software ban not found') - - conn.del_software_ban(name) - - return await AdminSoftwareBans.run("GET", request, message = 'Unbanned software') - - @register_route('/admin/users') class AdminUsers(View): async def get(self, @@ -362,8 +183,8 @@ class AdminUsers(View): message: str | None = None) -> Response: with self.database.session() as conn: - context = { - 'users': tuple(conn.execute('SELECT * FROM users').all()) + context: dict[str, Any] = { + 'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC')) } if error: @@ -376,82 +197,47 @@ class AdminUsers(View): return Response.new(data, ctype = 'html') - async def post(self, request: Request) -> Response: - data = await request.post() - required_fields = {'username', 'password', 'password2'} - - if not all(data.get(field) for field in required_fields): - return await self.get(request, error = 'Missing username and/or password') - - if data['password'] != data['password2']: - return await self.get(request, error = 'Passwords do not match') - - with self.database.session(True) as conn: - if conn.get_user(data['username']): - return await self.get(request, message = "User already exists") - - conn.put_user(data['username'], data['password'], data['handle']) - - return await self.get(request, message = "Added user") - - -@register_route('/admin/users/delete/{name}') -class AdminUsersDelete(View): - async def get(self, request: Request, name: str) -> Response: - with self.database.session() as conn: - if not conn.get_user(name): - return await AdminUsers.run("GET", request, message = 'User not found') - - conn.del_user(name) - - return await AdminUsers.run("GET", request, message = 'User deleted') - - @register_route('/admin/config') class AdminConfig(View): async def get(self, request: Request, message: str | None = None) -> Response: - context = { + context: dict[str, Any] = { 'themes': tuple(THEMES.keys()), - 'LogLevel': LogLevel, + 'levels': tuple(level.name for level in LogLevel), 'message': message } + data = self.template.render('page/admin-config.haml', self, **context) return Response.new(data, ctype = 'html') - async def post(self, request: Request) -> Response: - form = dict(await request.post()) - - with self.database.session(True) as conn: - for key in CONFIG_DEFAULTS: - value = form.get(key) - - if key == 'whitelist-enabled': - value = bool(value) - - elif key.lower() in CONFIG_IGNORE: - continue - - if value is None: - continue - - conn.put_config(key, value) - - return await self.get(request, message = 'Updated config') - - -@register_route('/style.css') -class StyleCss(View): +@register_route('/manifest.json') +class ManifestJson(View): async def get(self, request: Request) -> Response: - data = self.template.render('style.css', self) - return Response.new(data, ctype = 'css') + with self.database.session(False) as conn: + config = conn.get_config_all() + theme = THEMES[config.theme] + + data = { + 'background_color': theme['background'], + 'categories': ['activitypub'], + 'description': 'Message relay for the ActivityPub network', + 'display': 'standalone', + 'name': config['name'], + 'orientation': 'portrait', + 'scope': f"https://{self.config.domain}/", + 'short_name': 'ActivityRelay', + 'start_url': f"https://{self.config.domain}/", + 'theme_color': theme['primary'] + } + + return Response.new(data, ctype = 'webmanifest') @register_route('/theme/{theme}.css') class ThemeCss(View): async def get(self, request: Request, theme: str) -> Response: try: - context = { + context: dict[str, Any] = { 'theme': THEMES[theme] } diff --git a/relay/views/misc.py b/relay/views/misc.py index 65025e3..f10a877 100644 --- a/relay/views/misc.py +++ b/relay/views/misc.py @@ -27,28 +27,26 @@ if Path(__file__).parent.parent.joinpath('.git').exists(): pass -# pylint: disable=unused-argument - @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') class NodeinfoView(View): - # pylint: disable=no-self-use async def get(self, request: Request, niversion: str) -> Response: with self.database.session() as conn: - inboxes = conn.execute('SELECT * FROM inboxes').all() + inboxes = conn.get_inboxes() - data = { - 'name': 'activityrelay', - 'version': VERSION, - 'protocols': ['activitypub'], - 'open_regs': not conn.get_config('whitelist-enabled'), - 'users': 1, - 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} - } + nodeinfo = aputils.Nodeinfo.new( + name = 'activityrelay', + version = VERSION, + protocols = ['activitypub'], + open_regs = not conn.get_config('whitelist-enabled'), + users = 1, + repo = 'https://git.pleroma.social/pleroma/relay' if niversion == '2.1' else None, + metadata = { + 'approval_required': conn.get_config('approval-required'), + 'peers': [inbox['domain'] for inbox in inboxes] + } + ) - if niversion == '2.1': - data['repo'] = 'https://git.pleroma.social/pleroma/relay' - - return Response.new(aputils.Nodeinfo.new(**data), ctype = 'json') + return Response.new(nodeinfo, ctype = 'json') @register_route('/.well-known/nodeinfo') diff --git a/requirements.txt b/requirements.txt index 4c43b87..5649873 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,14 @@ -aiohttp>=3.9.1 -aiohttp-swagger[performance]==1.0.16 -aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.7.tar.gz -argon2-cffi==23.1.0 -barkshark-sql@https://git.barkshark.xyz/barkshark/bsql/archive/0.1.2.tar.gz -click>=8.1.2 -hamlish-jinja@https://git.barkshark.xyz/barkshark/hamlish-jinja/archive/0.3.5.tar.gz -hiredis==2.3.2 -platformdirs==4.2.0 -pyyaml>=6.0 -redis==5.0.1 +activitypub-utils == 0.2.1 +aiohttp >= 3.9.1 +aiohttp-swagger[performance] == 1.0.16 +argon2-cffi == 23.1.0 +barkshark-sql @ https://git.barkshark.xyz/barkshark/bsql/archive/0.1.2.tar.gz +click >= 8.1.2 +hamlish-jinja @ https://git.barkshark.xyz/barkshark/hamlish-jinja/archive/0.3.5.tar.gz +hiredis == 2.3.2 +markdown == 3.5.2 +platformdirs == 4.2.0 +pyyaml >= 6.0 +redis == 5.0.1 -importlib_resources==6.1.1;python_version<'3.9' +importlib_resources == 6.1.1; python_version < '3.9' diff --git a/setup.cfg b/setup.cfg index 41c2a30..b7d4fdc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,8 @@ console_scripts = [flake8] -select = F401 +extend-ignore = E128,E251,E261,E303,W191 +max-line-length = 100 +indent-size = 4 per-file-ignores = __init__.py: F401