diff --git a/dev.py b/dev.py index 38499d8..114073f 100755 --- a/dev.py +++ b/dev.py @@ -1,25 +1,38 @@ #!/usr/bin/env python3 -import click import platform import shutil import subprocess import sys import time -import tomllib from datetime import datetime, timedelta +from importlib.util import find_spec from pathlib import Path -from relay import __version__, logger as logging from tempfile import TemporaryDirectory from typing import Any, Sequence try: - from watchdog.observers import Observer - from watchdog.events import FileSystemEvent, PatternMatchingEventHandler + import tomllib except ImportError: - class PatternMatchingEventHandler: # type: ignore - pass + if find_spec("toml") is None: + subprocess.run([sys.executable, "-m", "pip", "install", "toml"]) + + import toml as tomllib # type: ignore[no-redef] + +if None in [find_spec("click"), find_spec("watchdog")]: + CMD = [sys.executable, "-m", "pip", "install", "click >= 8.1.0", "watchdog >= 4.0.0"] + PROC = subprocess.run(CMD, check = False) + + if PROC.returncode != 0: + sys.exit() + + print("Successfully installed dependencies") + +import click + +from watchdog.observers import Observer +from watchdog.events import FileSystemEvent, PatternMatchingEventHandler REPO = Path(__file__).parent @@ -37,13 +50,11 @@ def cli() -> None: @cli.command('install') @click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies') def cli_install(no_dev: bool) -> None: - with open('pyproject.toml', 'rb') as fd: - data = tomllib.load(fd) + with open('pyproject.toml', 'r', encoding = 'utf-8') as fd: + data = tomllib.loads(fd.read()) deps = data['project']['dependencies'] - - if not no_dev: - deps.extend(data['project']['optional-dependencies']['dev']) + deps.extend(data['project']['optional-dependencies']['dev']) subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False) @@ -60,7 +71,7 @@ def cli_lint(path: Path, watch: bool) -> None: return flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] - mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)] + mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)] click.echo('----- flake8 -----') subprocess.run(flake8) @@ -89,6 +100,8 @@ def cli_clean() -> None: @cli.command('build') def cli_build() -> None: + from relay import __version__ + with TemporaryDirectory() as tmp: arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' cmd = [ @@ -171,7 +184,7 @@ class WatchHandler(PatternMatchingEventHandler): if proc.poll() is not None: continue - logging.info(f'Terminating process {proc.pid}') + print(f'Terminating process {proc.pid}') proc.terminate() sec = 0.0 @@ -180,11 +193,11 @@ class WatchHandler(PatternMatchingEventHandler): sec += 0.1 if sec >= 5: - logging.error('Failed to terminate. Killing process...') + print('Failed to terminate. Killing process...') proc.kill() break - logging.info('Process terminated') + print('Process terminated') def run_procs(self, restart: bool = False) -> None: @@ -200,13 +213,13 @@ class WatchHandler(PatternMatchingEventHandler): self.procs = [] for cmd in self.commands: - logging.info('Running command: %s', ' '.join(cmd)) + print('Running command:', ' '.join(cmd)) subprocess.run(cmd) else: self.procs = list(subprocess.Popen(cmd) for cmd in self.commands) pids = (str(proc.pid) for proc in self.procs) - logging.info('Started processes with PIDs: %s', ', '.join(pids)) + print('Started processes with PIDs:', ', '.join(pids)) def on_any_event(self, event: FileSystemEvent) -> None: diff --git a/pyproject.toml b/pyproject.toml index 6f06de0..3982ec1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,30 +9,27 @@ license = {text = "AGPLv3"} classifiers = [ "Environment :: Console", "License :: OSI Approved :: GNU Affero General Public License v3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.12" ] dependencies = [ - "activitypub-utils == 0.3.1", + "activitypub-utils >= 0.3.2, < 0.4", "aiohttp >= 3.9.5", "aiohttp-swagger[performance] == 1.0.16", "argon2-cffi == 23.1.0", - "barkshark-lib >= 0.1.3-1", - "barkshark-sql == 0.1.4-1", - "click >= 8.1.2", + "barkshark-lib >= 0.2.3, < 0.3.0", + "barkshark-sql >= 0.2.0, < 0.3.0", + "click == 8.1.2", "hiredis == 2.3.2", "idna == 3.4", "jinja2-haml == 0.3.5", "markdown == 3.6", "platformdirs == 4.2.2", - "pyyaml >= 6.0", - "redis == 5.0.5", - "importlib-resources == 6.4.0; python_version < '3.9'" + "pyyaml == 6.0.1", + "redis == 5.0.7" ] -requires-python = ">=3.8" +requires-python = ">=3.10" dynamic = ["version"] [project.readme] @@ -49,11 +46,10 @@ activityrelay = "relay.manage:main" [project.optional-dependencies] dev = [ - "flake8 == 7.0.0", - "mypy == 1.10.0", - "pyinstaller == 6.8.0", - "watchdog == 4.0.1", - "typing-extensions >= 4.12.2; python_version < '3.11.0'" + "flake8 == 7.1.0", + "mypy == 1.11.1", + "pyinstaller == 6.10.0", + "watchdog == 4.0.2" ] [tool.setuptools] @@ -104,7 +100,3 @@ implicit_reexport = true [[tool.mypy.overrides]] module = "blib" implicit_reexport = true - -[[tool.mypy.overrides]] -module = "bsql" -implicit_reexport = true diff --git a/relay/__init__.py b/relay/__init__.py index 73e3bb4..80eb7f9 100644 --- a/relay/__init__.py +++ b/relay/__init__.py @@ -1 +1 @@ -__version__ = '0.3.2' +__version__ = '0.3.3' diff --git a/relay/application.py b/relay/application.py index b12c64f..5bdc04f 100644 --- a/relay/application.py +++ b/relay/application.py @@ -6,29 +6,33 @@ import signal import time import traceback +from Crypto.Random import get_random_bytes from aiohttp import web -from aiohttp.web import StaticResource +from aiohttp.web import HTTPException, StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer -from bsql import Database, Row +from base64 import b64encode +from blib import HttpError +from bsql import Database from collections.abc import Awaitable, Callable from datetime import datetime, timedelta from mimetypes import guess_type from pathlib import Path -from queue import Empty from threading import Event, Thread -from typing import Any +from typing import Any, cast from . import logger as logging from .cache import Cache, get_cache from .config import Config from .database import Connection, get_database +from .database.schema import Instance from .http_client import HttpClient -from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource +from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response, check_open_port, get_resource from .template import Template from .views import VIEWS from .views.api import handle_api_path from .views.frontend import handle_frontend_path +from .workers import PushWorkers def get_csp(request: web.Request) -> str: @@ -54,9 +58,9 @@ class Application(web.Application): def __init__(self, cfgpath: Path | None, dev: bool = False): web.Application.__init__(self, middlewares = [ - handle_api_path, # type: ignore[list-item] + handle_response_headers, # type: ignore[list-item] handle_frontend_path, # type: ignore[list-item] - handle_response_headers # type: ignore[list-item] + handle_api_path # type: ignore[list-item] ] ) @@ -75,7 +79,7 @@ class Application(web.Application): self['cache'].setup() self['template'] = Template(self) self['push_queue'] = multiprocessing.Queue() - self['workers'] = [] + self['workers'] = PushWorkers(self.config.workers) self.cache.setup() self.on_cleanup.append(handle_cleanup) # type: ignore @@ -92,27 +96,27 @@ class Application(web.Application): @property def cache(self) -> Cache: - return self['cache'] # type: ignore[no-any-return] + return cast(Cache, self['cache']) @property def client(self) -> HttpClient: - return self['client'] # type: ignore[no-any-return] + return cast(HttpClient, self['client']) @property def config(self) -> Config: - return self['config'] # type: ignore[no-any-return] + return cast(Config, self['config']) @property def database(self) -> Database[Connection]: - return self['database'] # type: ignore[no-any-return] + return cast(Database[Connection], self['database']) @property def signer(self) -> Signer: - return self['signer'] # type: ignore[no-any-return] + return cast(Signer, self['signer']) @signer.setter @@ -126,7 +130,7 @@ class Application(web.Application): @property def template(self) -> Template: - return self['template'] # type: ignore[no-any-return] + return cast(Template, self['template']) @property @@ -139,8 +143,13 @@ class Application(web.Application): return timedelta(seconds=uptime.seconds) - def push_message(self, inbox: str, message: Message, instance: Row) -> None: - self['push_queue'].put((inbox, message, instance)) + @property + def workers(self) -> PushWorkers: + return cast(PushWorkers, self['workers']) + + + def push_message(self, inbox: str, message: Message, instance: Instance) -> None: + self['workers'].push_message(inbox, message, instance) def register_static_routes(self) -> None: @@ -195,12 +204,7 @@ class Application(web.Application): self['cache'].setup() self['cleanup_thread'] = CacheCleanupThread(self) self['cleanup_thread'].start() - - for _ in range(self.config.workers): - worker = PushWorker(self['push_queue']) - worker.start() - - self['workers'].append(worker) + self['workers'].start() runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() @@ -220,15 +224,13 @@ class Application(web.Application): await site.stop() - for worker in self['workers']: - worker.stop() + self['workers'].stop() self.set_signal_handler(False) self['starttime'] = None self['running'] = False self['cleanup_thread'].stop() - self['workers'].clear() self['database'].disconnect() self['cache'].close() @@ -290,56 +292,15 @@ class CacheCleanupThread(Thread): self.running.clear() -class PushWorker(multiprocessing.Process): - def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None: - if Application.DEFAULT is None: - raise RuntimeError('Application not setup yet') +def format_error(request: web.Request, error: HttpError) -> Response: + app: Application = request.app # type: ignore[assignment] - multiprocessing.Process.__init__(self) + if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''): + return Response.new({'error': error.message}, error.status, ctype = 'json') - self.queue = queue - self.shutdown = multiprocessing.Event() - self.path = Application.DEFAULT.config.path - - - def stop(self) -> None: - self.shutdown.set() - - - def run(self) -> None: - asyncio.run(self.handle_queue()) - - - async def handle_queue(self) -> None: - if IS_WINDOWS: - app = Application(self.path) - client = app.client - - client.open() - app.database.connect() - app.cache.setup() - - else: - client = HttpClient() - client.open() - - while not self.shutdown.is_set(): - try: - inbox, message, instance = self.queue.get(block=True, timeout=0.1) - asyncio.create_task(client.post(inbox, message, instance)) - - except Empty: - await asyncio.sleep(0) - - # make sure an exception doesn't bring down the worker - except Exception: - traceback.print_exc() - - if IS_WINDOWS: - app.database.disconnect() - app.cache.close() - - await client.close() + else: + body = app.template.render('page/error.haml', request, e = error) + return Response.new(body, error.status, ctype = 'html') @web.middleware @@ -347,14 +308,60 @@ async def handle_response_headers( request: web.Request, handler: Callable[[web.Request], Awaitable[Response]]) -> Response: - resp = await handler(request) + request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') + request['token'] = None + request['user'] = None + + app: Application = request.app # type: ignore[assignment] + + if request.path == "/" or request.path.startswith(TOKEN_PATHS): + with app.database.session() as conn: + tokens = ( + request.headers.get('Authorization', '').replace('Bearer', '').strip(), + request.cookies.get('user-token') + ) + + for token in tokens: + if not token: + continue + + request['token'] = conn.get_app_by_token(token) + + if request['token'] is not None: + request['user'] = conn.get_user(request['token'].user) + + break + + try: + resp = await handler(request) + + except HttpError as e: + resp = format_error(request, e) + + except HTTPException as e: + if e.status == 404: + try: + text = (e.text or "").split(":")[1].strip() + + except IndexError: + text = e.text or "" + + resp = format_error(request, HttpError(e.status, text)) + + else: + raise + + except Exception: + resp = format_error(request, HttpError(500, 'Internal server error')) + traceback.print_exc() + resp.headers['Server'] = 'ActivityRelay' # Still have to figure out how csp headers work if resp.content_type == 'text/html' and not request.path.startswith("/api"): resp.headers['Content-Security-Policy'] = get_csp(request) - if not request.app['dev'] and request.path.endswith(('.css', '.js')): + if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')): # cache for 2 weeks resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' diff --git a/relay/cache.py b/relay/cache.py index e9f261b..0c76b8e 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -4,12 +4,13 @@ import json import os from abc import ABC, abstractmethod -from bsql import Database +from blib import Date +from bsql import Database, Row from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass -from datetime import datetime, timedelta, timezone +from datetime import timedelta, timezone from redis import Redis -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict from .database import Connection, get_database from .misc import Message, boolean @@ -31,6 +32,14 @@ CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = { } +class RedisConnectType(TypedDict): + client_name: str + decode_responses: bool + username: str | None + password: str | None + db: int + + def get_cache(app: Application) -> Cache: return BACKENDS[app.config.ca_type](app) @@ -57,12 +66,14 @@ class Item: key: str value: Any value_type: str - updated: datetime + updated: Date def __post_init__(self) -> None: - if isinstance(self.updated, str): # type: ignore[unreachable] - self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable] + self.updated = Date.parse(self.updated) + + if self.updated.tzinfo is None: + self.updated = self.updated.replace(tzinfo = timezone.utc) @classmethod @@ -70,15 +81,11 @@ class Item: data = cls(*args) data.value = deserialize_value(data.value, data.value_type) - if not isinstance(data.updated, datetime): - data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore - return data def older_than(self, hours: int) -> bool: - delta = datetime.now(tz = timezone.utc) - self.updated - return (delta.total_seconds()) > hours * 3600 + return self.updated + timedelta(hours = hours) < Date.new_utc() def to_dict(self) -> dict[str, Any]: @@ -172,7 +179,7 @@ class SqlCache(Cache): with self._db.session(False) as conn: with conn.run('get-cache-item', params) as cur: - if not (row := cur.one()): + if not (row := cur.one(Row)): raise KeyError(f'{namespace}:{key}') row.pop('id', None) @@ -206,14 +213,16 @@ class SqlCache(Cache): 'key': key, 'value': serialize_value(value, value_type), 'type': value_type, - 'date': datetime.now(tz = timezone.utc) + 'date': Date.new_utc() } with self._db.session(True) as conn: with conn.run('set-cache-item', params) as cur: - row = cur.one() - row.pop('id', None) # type: ignore[union-attr] - return Item.from_data(*tuple(row.values())) # type: ignore[union-attr] + if (row := cur.one(Row)) is None: + raise RuntimeError("Cache item not set") + + row.pop('id', None) + return Item.from_data(*tuple(row.values())) def delete(self, namespace: str, key: str) -> None: @@ -234,7 +243,7 @@ class SqlCache(Cache): if self._db is None: raise RuntimeError("Database has not been setup") - limit = datetime.now(tz = timezone.utc) - timedelta(days = days) + limit = Date.new_utc() - timedelta(days = days) params = {"limit": limit.timestamp()} with self._db.session(True) as conn: @@ -278,7 +287,7 @@ class RedisCache(Cache): def __init__(self, app: Application): Cache.__init__(self, app) - self._rd: Redis = None # type: ignore + self._rd: Redis | None = None @property @@ -291,28 +300,38 @@ class RedisCache(Cache): def get(self, namespace: str, key: str) -> Item: + if self._rd is None: + raise ConnectionError("Not connected") + key_name = self.get_key_name(namespace, key) if not (raw_value := self._rd.get(key_name)): raise KeyError(f'{namespace}:{key}') - value_type, updated, value = raw_value.split(':', 2) # type: ignore + value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr] + return Item.from_data( namespace, key, value, value_type, - datetime.fromtimestamp(float(updated), tz = timezone.utc) + Date.parse(float(updated)) ) def get_keys(self, namespace: str) -> Iterator[str]: + if self._rd is None: + raise ConnectionError("Not connected") + for key in self._rd.scan_iter(self.get_key_name(namespace, '*')): *_, key_name = key.split(':', 2) yield key_name def get_namespaces(self) -> Iterator[str]: + if self._rd is None: + raise ConnectionError("Not connected") + namespaces = [] for key in self._rd.scan_iter(f'{self.prefix}:*'): @@ -324,7 +343,10 @@ class RedisCache(Cache): def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: - date = datetime.now(tz = timezone.utc).timestamp() + if self._rd is None: + raise ConnectionError("Not connected") + + date = Date.new_utc().timestamp() value = serialize_value(value, value_type) self._rd.set( @@ -336,11 +358,17 @@ class RedisCache(Cache): def delete(self, namespace: str, key: str) -> None: + if self._rd is None: + raise ConnectionError("Not connected") + self._rd.delete(self.get_key_name(namespace, key)) def delete_old(self, days: int = 14) -> None: - limit = datetime.now(tz = timezone.utc) - timedelta(days = days) + if self._rd is None: + raise ConnectionError("Not connected") + + limit = Date.new_utc() - timedelta(days = days) for full_key in self._rd.scan_iter(f'{self.prefix}:*'): _, namespace, key = full_key.split(':', 2) @@ -351,14 +379,17 @@ class RedisCache(Cache): def clear(self) -> None: + if self._rd is None: + raise ConnectionError("Not connected") + self._rd.delete(f"{self.prefix}:*") def setup(self) -> None: - if self._rd: + if self._rd is not None: return - options = { + options: RedisConnectType = { 'client_name': f'ActivityRelay_{self.app.config.domain}', 'decode_responses': True, 'username': self.app.config.rd_user, @@ -367,18 +398,22 @@ class RedisCache(Cache): } if os.path.exists(self.app.config.rd_host): - options['unix_socket_path'] = self.app.config.rd_host + self._rd = Redis( + unix_socket_path = self.app.config.rd_host, + **options + ) + return - else: - options['host'] = self.app.config.rd_host - options['port'] = self.app.config.rd_port - - self._rd = Redis(**options) # type: ignore + self._rd = Redis( + host = self.app.config.rd_host, + port = self.app.config.rd_port, + **options + ) def close(self) -> None: if not self._rd: return - self._rd.close() # type: ignore - self._rd = None # type: ignore + self._rd.close() # type: ignore[no-untyped-call] + self._rd = None diff --git a/relay/config.py b/relay/config.py index ac2bbb6..e40cf70 100644 --- a/relay/config.py +++ b/relay/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import getpass import os import platform @@ -6,16 +8,13 @@ import yaml from dataclasses import asdict, dataclass, fields from pathlib import Path from platformdirs import user_config_dir -from typing import Any +from typing import TYPE_CHECKING, Any from .misc import IS_DOCKER -try: +if TYPE_CHECKING: from typing import Self -except ImportError: - from typing_extensions import Self - if platform.system() == 'Windows': import multiprocessing @@ -61,7 +60,7 @@ class Config: def __init__(self, path: Path | None = None, load: bool = False): - self.path = Config.get_config_dir(path) + self.path: Path = Config.get_config_dir(path) self.reset() if load: @@ -81,7 +80,7 @@ class Config: def DEFAULT(cls: type[Self], key: str) -> str | int | None: for field in fields(cls): if field.name == key: - return field.default # type: ignore + return field.default # type: ignore[return-value] raise KeyError(key) @@ -146,7 +145,7 @@ class Config: if not config: raise ValueError('Config is empty') - pgcfg = config.get('postgresql', {}) + pgcfg = config.get('postgres', {}) rdcfg = config.get('redis', {}) for key in type(self).KEYS(): diff --git a/relay/data/statements.sql b/relay/data/statements.sql index f06d4b5..894bb40 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -40,7 +40,7 @@ WHERE domain = :value or inbox = :value or actor = :value; -- name: get-request -SELECT * FROM inboxes WHERE accepted = 0 and domain = :domain; +SELECT * FROM inboxes WHERE accepted = false and domain = :domain; -- name: get-user @@ -51,8 +51,8 @@ WHERE username = :value or handle = :value; -- name: get-user-by-token SELECT * FROM users WHERE username = ( - SELECT user FROM tokens - WHERE code = :code + SELECT user FROM apps + WHERE token = :token ); @@ -64,28 +64,35 @@ RETURNING *; -- name: del-user DELETE FROM users -WHERE username = :value or handle = :value; +WHERE username = :username or handle = :username; --- name: get-token -SELECT * FROM tokens -WHERE code = :code; +-- name: get-app +SELECT * FROM apps +WHERE client_id = :id and client_secret = :secret; --- name: put-token -INSERT INTO tokens (code, user, created) -VALUES (:code, :user, :created) -RETURNING *; +-- name: get-app-with-token +SELECT * FROM apps +WHERE client_id = :id and client_secret = :secret and token = :token; --- name: del-token -DELETE FROM tokens -WHERE code = :code; +-- name: get-app-by-token +SELECT * FROM apps +WHERE token = :token; + +-- name: del-app +DELETE FROM apps +WHERE client_id = :id and client_secret = :secret; + + +-- name: del-app-with-token +DELETE FROM apps +WHERE client_id = :id and client_secret = :secret and token = :token; -- name: del-token-user -DELETE FROM tokens -WHERE user = :username; +DELETE FROM apps WHERE "user" = :username; -- name: get-software-ban diff --git a/relay/data/swagger.yaml b/relay/data/swagger.yaml index a2a51dc..ac7b728 100644 --- a/relay/data/swagger.yaml +++ b/relay/data/swagger.yaml @@ -18,10 +18,12 @@ securityDefinitions: in: cookie name: user-token Bearer: - type: apiKey + type: oauth2 name: Authorization in: header - description: "Enter the token with the `Bearer ` prefix" + flow: accessCode + authorizationUrl: /oauth/authorize + tokenUrl: /oauth/token paths: /: @@ -35,6 +37,161 @@ paths: schema: $ref: "#/definitions/Error" + /oauth/authorize: + get: + tags: + - OAuth + description: Get an authorization code + parameters: + - in: query + name: response-type + required: true + type: string + - in: query + name: client_id + required: true + type: string + - in: query + name: redirect_uri + required: true + type: string + + /oauth/token: + post: + tags: + - OAuth + description: Get a token for an authorized app + parameters: + - in: formData + name: grant_type + required: true + type: string + - in: formData + name: code + required: true + type: string + - in: formData + name: client_id + required: true + type: string + - in: formData + name: client_secret + required: true + type: string + - in: formData + name: redirect_uri + required: true + type: string + consumes: + - application/x-www-form-urlencoded + - application/json + - multipart/form-data + produces: + - application/json + responses: + "200": + description: Application + schema: + $ref: "#/definitions/Application" + + /oauth/revoke: + post: + tags: + - OAuth + description: Get a token for an authorized app + parameters: + - in: formData + name: client_id + required: true + type: string + - in: formData + name: client_secret + required: true + type: string + - in: formData + name: token + required: true + type: string + consumes: + - application/json + - multipart/form-data + - application/x-www-form-urlencoded + produces: + - application/json + responses: + "200": + description: Message confirming application deletion + schema: + $ref: "#/definitions/Message" + + /v1/app: + get: + tags: + - Applications + description: Verify the token is valid + produces: + - application/json + responses: + "200": + description: Application with the associated token + schema: + $ref: "#/definitions/Application" + + post: + tags: + - Applications + description: Create a new application + parameters: + - in: query + name: name + required: true + type: string + - in: query + name: redirect_uri + required: true + type: string + - in: query + name: website + required: false + type: string + format: url + consumes: + - application/json + - multipart/form-data + - application/x-www-form-urlencoded + produces: + - application/json + responses: + "200": + description: Newly created application + schema: + $ref: "#/definitions/Application" + + delete: + tags: + - Applications + description: Deletes an application + parameters: + - in: formData + name: client_id + required: true + type: string + - in: formData + name: client_secret + required: true + type: string + consumes: + - application/json + - multipart/form-data + - application/x-www-form-urlencoded + produces: + - application/json + responses: + "200": + description: Confirmation of application deletion + schema: + $ref: "#/definitions/Message" + /v1/relay: get: tags: @@ -48,23 +205,11 @@ paths: schema: $ref: "#/definitions/Info" - /v1/token: - get: - tags: - - Token - description: Verify API token - produces: - - application/json - responses: - "200": - description: Valid token - schema: - $ref: "#/definitions/Message" - + /v1/login: post: tags: - - Token - description: Get a new token + - Login + description: Login with a username and password parameters: - in: formData name: username @@ -74,7 +219,6 @@ paths: name: password required: true type: string - format: password consumes: - application/json - multipart/form-data @@ -83,22 +227,9 @@ paths: - application/json responses: "200": - description: Created token + description: A new Application schema: - $ref: "#/definitions/Token" - - - delete: - tags: - - Token - description: Revoke a token - produces: - - application/json - responses: - "200": - description: Revoked token - schema: - $ref: "#/definitions/Message" + $ref: "#/definitions/Application" /v1/config: get: @@ -731,9 +862,43 @@ definitions: description: Human-readable message text type: string + Application: + type: object + properties: + client_id: + description: Identifier for the application + type: string + client_secret: + description: Secret string for the application + type: string + name: + description: Human-readable name of the application + type: string + website: + description: Website for the application + type: string + format: url + redirect_uri: + description: URL to redirect to when authorizing an app + type: string + token: + description: String to use in the Authorization header for client requests + type: string + created: + description: Date the application was created + type: string + format: date-time + accessed: + description: Date the application was last used + type: string + format: date-time + Config: type: object properties: + approval-required: + description: Require instances to be approved when following + type: bool log-level: description: Maximum level of log messages to print to the console type: string @@ -743,6 +908,9 @@ definitions: note: description: Blurb to display on the home page type: string + theme: + description: Name of the color scheme to use for the frontend + type: string whitelist-enabled: description: Only allow specific instances to join the relay when enabled type: boolean @@ -843,13 +1011,6 @@ definitions: type: string format: date-time - Token: - type: object - properties: - token: - description: Character string used for authenticating with the api - type: string - User: type: object properties: diff --git a/relay/database/__init__.py b/relay/database/__init__.py index becd456..03198ab 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -1,3 +1,6 @@ +import sqlite3 + +from blib import Date from bsql import Database from .config import THEMES, ConfigData @@ -9,6 +12,9 @@ from ..config import Config from ..misc import get_resource +sqlite3.register_adapter(Date, Date.timestamp) + + def get_database(config: Config, migrate: bool = True) -> Database[Connection]: options = { 'connection_class': Connection, @@ -16,6 +22,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]: 'tables': TABLES } + db: Database[Connection] + if config.db_type == 'sqlite': db = Database.sqlite(config.sqlite_path, **options) diff --git a/relay/database/config.py b/relay/database/config.py index 6effbb9..3f3c7e0 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -5,17 +5,14 @@ from __future__ import annotations from bsql import Row from collections.abc import Callable, Sequence from dataclasses import Field, asdict, dataclass, fields -from typing import Any +from typing import TYPE_CHECKING, Any from .. import logger as logging from ..misc import boolean -try: +if TYPE_CHECKING: from typing import Self -except ImportError: - from typing_extensions import Self - THEMES = { 'default': { @@ -76,7 +73,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { @dataclass() class ConfigData: - schema_version: int = 20240310 + schema_version: int = 20240625 private_key: str = '' approval_required: bool = False log_level: logging.LogLevel = logging.LogLevel.INFO @@ -114,11 +111,11 @@ class ConfigData: @classmethod def DEFAULT(cls: type[Self], key: str) -> str | int | bool: - return cls.FIELD(key.replace('-', '_')).default # type: ignore + return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value] @classmethod - def FIELD(cls: type[Self], key: str) -> Field[Any]: + def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]: for field in fields(cls): if field.name == key.replace('-', '_'): return field diff --git a/relay/database/connection.py b/relay/database/connection.py index 614f307..14ff60a 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -1,13 +1,16 @@ from __future__ import annotations +import secrets + from argon2 import PasswordHasher +from blib import Date from bsql import Connection as SqlConnection, Row, Update -from collections.abc import Iterator, Sequence +from collections.abc import Iterator from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from uuid import uuid4 +from . import schema from .config import ( THEMES, ConfigData @@ -37,22 +40,52 @@ class Connection(SqlConnection): return get_app() - def distill_inboxes(self, message: Message) -> Iterator[Row]: + def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]: src_domains = { message.domain, urlparse(message.object_id).netloc } for instance in self.get_inboxes(): - if instance['domain'] not in src_domains: + if instance.domain not in src_domains: yield instance + def fix_timestamps(self) -> None: + for app in self.select('apps').all(schema.App): + data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()} + self.update('apps', data, client_id = app.client_id) + + for item in self.select('cache'): + data = {'updated': Date.parse(item['updated']).timestamp()} + self.update('cache', data, id = item['id']) + + for dban in self.select('domain_bans').all(schema.DomainBan): + data = {'created': dban.created.timestamp()} + self.update('domain_bans', data, domain = dban.domain) + + for instance in self.select('inboxes').all(schema.Instance): + data = {'created': instance.created.timestamp()} + self.update('inboxes', data, domain = instance.domain) + + for sban in self.select('software_bans').all(schema.SoftwareBan): + data = {'created': sban.created.timestamp()} + self.update('software_bans', data, name = sban.name) + + for user in self.select('users').all(schema.User): + data = {'created': user.created.timestamp()} + self.update('users', data, username = user.username) + + for wlist in self.select('whitelist').all(schema.Whitelist): + data = {'created': wlist.created.timestamp()} + self.update('whitelist', data, domain = wlist.domain) + + def get_config(self, key: str) -> Any: key = key.replace('_', '-') with self.run('get-config', {'key': key}) as cur: - if not (row := cur.one()): + if (row := cur.one(Row)) is None: return ConfigData.DEFAULT(key) data = ConfigData() @@ -61,8 +94,8 @@ class Connection(SqlConnection): def get_config_all(self) -> ConfigData: - with self.run('get-config-all', None) as cur: - return ConfigData.from_rows(tuple(cur.all())) + rows = tuple(self.run('get-config-all', None).all(schema.Row)) + return ConfigData.from_rows(rows) def put_config(self, key: str, value: Any) -> Any: @@ -75,6 +108,7 @@ class Connection(SqlConnection): elif key == 'log-level': value = logging.LogLevel.parse(value) logging.set_level(value) + self.app['workers'].set_log_level(value) elif key in {'approval-required', 'whitelist-enabled'}: value = boolean(value) @@ -98,23 +132,23 @@ class Connection(SqlConnection): return data.get(key) - def get_inbox(self, value: str) -> Row: + def get_inbox(self, value: str) -> schema.Instance | None: with self.run('get-inbox', {'value': value}) as cur: - return cur.one() # type: ignore + return cur.one(schema.Instance) - def get_inboxes(self) -> Sequence[Row]: - with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur: - return tuple(cur.all()) + def get_inboxes(self) -> Iterator[schema.Instance]: + return self.execute("SELECT * FROM inboxes WHERE accepted = true").all(schema.Instance) - def put_inbox(self, + # todo: check if software is different than stored row + def put_inbox(self, # noqa: E301 domain: str, inbox: str | None = None, actor: str | None = None, followid: str | None = None, software: str | None = None, - accepted: bool = True) -> Row: + accepted: bool = True) -> schema.Instance: params: dict[str, Any] = { 'inbox': inbox, @@ -124,7 +158,7 @@ class Connection(SqlConnection): 'accepted': accepted } - if not self.get_inbox(domain): + if self.get_inbox(domain) is None: if not inbox: raise ValueError("Missing inbox") @@ -132,14 +166,20 @@ class Connection(SqlConnection): params['created'] = datetime.now(tz = timezone.utc) with self.run('put-inbox', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to insert instance: {domain}") + + return row for key, value in tuple(params.items()): if value is None: del params[key] with self.update('inboxes', params, domain = domain) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to update instance: {domain}") + + return row def del_inbox(self, value: str) -> bool: @@ -150,24 +190,23 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_request(self, domain: str) -> Row: + def get_request(self, domain: str) -> schema.Instance | None: with self.run('get-request', {'domain': domain}) as cur: - if not (row := cur.one()): - raise KeyError(domain) - - return row + return cur.one(schema.Instance) - def get_requests(self) -> Sequence[Row]: - with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur: - return tuple(cur.all()) + def get_requests(self) -> Iterator[schema.Instance]: + return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance) - def put_request_response(self, domain: str, accepted: bool) -> Row: - instance = self.get_request(domain) + def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: + if (instance := self.get_request(domain)) is None: + raise KeyError(domain) if not accepted: - self.del_inbox(domain) + if not self.del_inbox(domain): + raise RuntimeError(f'Failed to delete request: {domain}') + return instance params = { @@ -176,21 +215,28 @@ class Connection(SqlConnection): } with self.run('put-inbox-accept', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to insert response for domain: {domain}") + + return row - def get_user(self, value: str) -> Row: + def get_user(self, value: str) -> schema.User | None: with self.run('get-user', {'value': value}) as cur: - return cur.one() # type: ignore + return cur.one(schema.User) - def get_user_by_token(self, code: str) -> Row: - with self.run('get-user-by-token', {'code': code}) as cur: - return cur.one() # type: ignore + def get_user_by_token(self, token: str) -> schema.User | None: + with self.run('get-user-by-token', {'token': token}) as cur: + return cur.one(schema.User) - def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row: - if self.get_user(username): + def get_users(self) -> Iterator[schema.User]: + return self.execute("SELECT * FROM users").all(schema.User) + + + def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User: + if self.get_user(username) is not None: data: dict[str, str | datetime | None] = {} if password: @@ -203,7 +249,10 @@ class Connection(SqlConnection): stmt.set_where("username", username) with self.query(stmt) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.User)) is None: + raise RuntimeError(f"Failed to update user: {username}") + + return row if password is None: raise ValueError('Password cannot be empty') @@ -216,52 +265,149 @@ class Connection(SqlConnection): } with self.run('put-user', data) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.User)) is None: + raise RuntimeError(f"Failed to insert user: {username}") + + return row def del_user(self, username: str) -> None: - user = self.get_user(username) + if (user := self.get_user(username)) is None: + raise KeyError(username) - with self.run('del-user', {'value': user['username']}): + with self.run('del-token-user', {'username': user.username}): pass - with self.run('del-token-user', {'username': user['username']}): + with self.run('del-user', {'username': user.username}): pass - def get_token(self, code: str) -> Row: - with self.run('get-token', {'code': code}) as cur: - return cur.one() # type: ignore + def get_app(self, + client_id: str, + client_secret: str, + token: str | None = None) -> schema.App | None: - - def put_token(self, username: str) -> Row: - data = { - 'code': uuid4().hex, - 'user': username, - 'created': datetime.now(tz = timezone.utc) + params = { + 'id': client_id, + 'secret': client_secret } - with self.run('put-token', data) as cur: - return cur.one() # type: ignore + if token is not None: + command = 'get-app-with-token' + params['token'] = token + + else: + command = 'get-app' + + with self.run(command, params) as cur: + return cur.one(schema.App) - def del_token(self, code: str) -> None: - with self.run('del-token', {'code': code}): - pass + def get_app_by_token(self, token: str) -> schema.App | None: + with self.run('get-app-by-token', {'token': token}) as cur: + return cur.one(schema.App) - def get_domain_ban(self, domain: str) -> Row: + def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App: + params = { + 'name': name, + 'redirect_uri': redirect_uri, + 'website': website, + 'client_id': secrets.token_hex(20), + 'client_secret': secrets.token_hex(20), + 'created': Date.new_utc(), + 'accessed': Date.new_utc() + } + + with self.insert('apps', params) as cur: + if (row := cur.one(schema.App)) is None: + raise RuntimeError(f'Failed to insert app: {name}') + + return row + + + def put_app_login(self, user: schema.User) -> schema.App: + params = { + 'name': 'Web', + 'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob', + 'website': None, + 'user': user.username, + 'client_id': secrets.token_hex(20), + 'client_secret': secrets.token_hex(20), + 'auth_code': None, + 'token': secrets.token_hex(20), + 'created': Date.new_utc(), + 'accessed': Date.new_utc() + } + + with self.insert('apps', params) as cur: + if (row := cur.one(schema.App)) is None: + raise RuntimeError(f'Failed to create app for "{user.username}"') + + return row + + + def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App: + data: dict[str, str | None] = {} + + if user is not None: + data['user'] = user.username + + if set_auth: + data['auth_code'] = secrets.token_hex(20) + + else: + data['token'] = secrets.token_hex(20) + data['auth_code'] = None + + params = { + 'client_id': app.client_id, + 'client_secret': app.client_secret + } + + with self.update('apps', data, **params) as cur: # type: ignore[arg-type] + if (row := cur.one(schema.App)) is None: + raise RuntimeError('Failed to update row') + + return row + + + def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool: + params = { + 'id': client_id, + 'secret': client_secret + } + + if token is not None: + command = 'del-app-with-token' + params['token'] = token + + else: + command = 'del-app' + + with self.run(command, params) as cur: + if cur.row_count > 1: + raise RuntimeError('More than 1 row was deleted') + + return cur.row_count == 0 + + + def get_domain_ban(self, domain: str) -> schema.DomainBan | None: if domain.startswith('http'): domain = urlparse(domain).netloc with self.run('get-domain-ban', {'domain': domain}) as cur: - return cur.one() # type: ignore + return cur.one(schema.DomainBan) + + + def get_domain_bans(self) -> Iterator[schema.DomainBan]: + return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan) def put_domain_ban(self, domain: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.DomainBan: params = { 'domain': domain, @@ -271,13 +417,16 @@ class Connection(SqlConnection): } with self.run('put-domain-ban', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.DomainBan)) is None: + raise RuntimeError(f"Failed to insert domain ban: {domain}") + + return row def update_domain_ban(self, domain: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.DomainBan: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') @@ -297,7 +446,10 @@ class Connection(SqlConnection): if cur.row_count > 1: raise ValueError('More than one row was modified') - return self.get_domain_ban(domain) + if (row := cur.one(schema.DomainBan)) is None: + raise RuntimeError(f"Failed to update domain ban: {domain}") + + return row def del_domain_ban(self, domain: str) -> bool: @@ -308,15 +460,19 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_software_ban(self, name: str) -> Row: + def get_software_ban(self, name: str) -> schema.SoftwareBan | None: with self.run('get-software-ban', {'name': name}) as cur: - return cur.one() # type: ignore + return cur.one(schema.SoftwareBan) + + + def get_software_bans(self) -> Iterator[schema.SoftwareBan,]: + return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan) def put_software_ban(self, name: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.SoftwareBan: params = { 'name': name, @@ -326,13 +482,16 @@ class Connection(SqlConnection): } with self.run('put-software-ban', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.SoftwareBan)) is None: + raise RuntimeError(f'Failed to insert software ban: {name}') + + return row def update_software_ban(self, name: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.SoftwareBan: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') @@ -352,7 +511,10 @@ class Connection(SqlConnection): if cur.row_count > 1: raise ValueError('More than one row was modified') - return self.get_software_ban(name) + if (row := cur.one(schema.SoftwareBan)) is None: + raise RuntimeError(f'Failed to update software ban: {name}') + + return row def del_software_ban(self, name: str) -> bool: @@ -363,19 +525,26 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_domain_whitelist(self, domain: str) -> Row: + def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None: with self.run('get-domain-whitelist', {'domain': domain}) as cur: - return cur.one() # type: ignore + return cur.one() - def put_domain_whitelist(self, domain: str) -> Row: + def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]: + return self.execute("SELECT * FROM whitelist").all(schema.Whitelist) + + + def put_domain_whitelist(self, domain: str) -> schema.Whitelist: params = { 'domain': domain, 'created': datetime.now(tz = timezone.utc) } with self.run('put-domain-whitelist', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Whitelist)) is None: + raise RuntimeError(f'Failed to insert whitelisted domain: {domain}') + + return row def del_domain_whitelist(self, domain: str) -> bool: diff --git a/relay/database/schema.py b/relay/database/schema.py index 409ee57..55ca608 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -1,61 +1,133 @@ -from bsql import Column, Table, Tables +from __future__ import annotations + +from blib import Date +from bsql import Column, Row, Tables from collections.abc import Callable +from copy import deepcopy +from datetime import timezone +from typing import TYPE_CHECKING, Any from .config import ConfigData -from .connection import Connection + +if TYPE_CHECKING: + from .connection import Connection VERSIONS: dict[int, Callable[[Connection], None]] = {} -TABLES: Tables = Tables( - Table( - 'config', - Column('key', 'text', primary_key = True, unique = True, nullable = False), - Column('value', 'text'), - Column('type', 'text', default = 'str') - ), - Table( - 'inboxes', - Column('domain', 'text', primary_key = True, unique = True, nullable = False), - Column('actor', 'text', unique = True), - Column('inbox', 'text', unique = True, nullable = False), - Column('followid', 'text'), - Column('software', 'text'), - Column('accepted', 'boolean'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'whitelist', - Column('domain', 'text', primary_key = True, unique = True, nullable = True), - Column('created', 'timestamp') - ), - Table( - 'domain_bans', - Column('domain', 'text', primary_key = True, unique = True, nullable = True), - Column('reason', 'text'), - Column('note', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'software_bans', - Column('name', 'text', primary_key = True, unique = True, nullable = True), - Column('reason', 'text'), - Column('note', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'users', - Column('username', 'text', primary_key = True, unique = True, nullable = False), - Column('hash', 'text', nullable = False), - Column('handle', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'tokens', - Column('code', 'text', primary_key = True, unique = True, nullable = False), - Column('user', 'text', nullable = False), - Column('created', 'timestmap', nullable = False) - ) -) +TABLES = Tables() + + +def deserialize_timestamp(value: Any) -> Date: + try: + date = Date.parse(value) + + except ValueError: + date = Date.fromisoformat(value) + + if date.tzinfo is None: + date = date.replace(tzinfo = timezone.utc) + + return date + + +@TABLES.add_row +class Config(Row): + key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) + value: Column[str] = Column('value', 'text') + type: Column[str] = Column('type', 'text', default = 'str') + + +@TABLES.add_row +class Instance(Row): + table_name: str = 'inboxes' + + + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = False) + actor: Column[str] = Column('actor', 'text', unique = True) + inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) + followid: Column[str] = Column('followid', 'text') + software: Column[str] = Column('software', 'text') + accepted: Column[Date] = Column('accepted', 'boolean') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + + +@TABLES.add_row +class Whitelist(Row): + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = True) + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + + +@TABLES.add_row +class DomainBan(Row): + table_name: str = 'domain_bans' + + + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = True) + reason: Column[str] = Column('reason', 'text') + note: Column[str] = Column('note', 'text') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + + +@TABLES.add_row +class SoftwareBan(Row): + table_name: str = 'software_bans' + + + name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) + reason: Column[str] = Column('reason', 'text') + note: Column[str] = Column('note', 'text') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + + +@TABLES.add_row +class User(Row): + table_name: str = 'users' + + + username: Column[str] = Column( + 'username', 'text', primary_key = True, unique = True, nullable = False) + hash: Column[str] = Column('hash', 'text', nullable = False) + handle: Column[str] = Column('handle', 'text') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + + +@TABLES.add_row +class App(Row): + table_name: str = 'apps' + + + client_id: Column[str] = Column( + 'client_id', 'text', primary_key = True, unique = True, nullable = False) + client_secret: Column[str] = Column('client_secret', 'text', nullable = False) + name: Column[str] = Column('name', 'text') + website: Column[str] = Column('website', 'text') + redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False) + token: Column[str | None] = Column('token', 'text') + auth_code: Column[str | None] = Column('auth_code', 'text') + user: Column[str | None] = Column('user', 'text') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + accessed: Column[Date] = Column( + 'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + + + def get_api_data(self, include_token: bool = False) -> dict[str, Any]: + data = deepcopy(self) + data.pop('user') + data.pop('auth_code') + + if not include_token: + data.pop('token') + + return data def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: @@ -76,5 +148,11 @@ def migrate_20240206(conn: Connection) -> None: @migration def migrate_20240310(conn: Connection) -> None: - conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN") - conn.execute("UPDATE inboxes SET accepted = 1") + conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close() + conn.execute('UPDATE "inboxes" SET "accepted" = true').close() + + +@migration +def migrate_20240625(conn: Connection) -> None: + conn.create_tables() + conn.execute('DROP TABLE "tokens"').close() diff --git a/relay/errors.py b/relay/errors.py new file mode 100644 index 0000000..8074929 --- /dev/null +++ b/relay/errors.py @@ -0,0 +1,2 @@ +class EmptyBodyError(Exception): + pass diff --git a/relay/frontend/base.haml b/relay/frontend/base.haml index 7a14b72..33b8a85 100644 --- a/relay/frontend/base.haml +++ b/relay/frontend/base.haml @@ -1,5 +1,5 @@ -macro menu_item(name, path) - -if view.request.path == path or (path != "/" and view.request.path.startswith(path)) + -if request.path == path or (path != "/" and request.path.startswith(path)) %a.button(href="{{path}}" active="true") -> =name -else @@ -11,11 +11,11 @@ %title << {{config.name}}: {{page}} %meta(charset="UTF-8") %meta(name="viewport" content="width=device-width, initial-scale=1") - %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme") - %link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}") - %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css" nonce="{{view.request['hash']}}") - %link(rel="manifest" href="/manifest.json") - %script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer) + %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme") + %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}") + %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}") + %link(rel="manifest" href="/manifest.json?{{version}}") + %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer) -block head %body @@ -26,7 +26,7 @@ {{menu_item("Home", "/")}} - -if view.request["user"] + -if request["user"] {{menu_item("Instances", "/admin/instances")}} {{menu_item("Whitelist", "/admin/whitelist")}} {{menu_item("Domain Bans", "/admin/domain_bans")}} @@ -61,11 +61,11 @@ #footer.section .col1 - -if not view.request["user"] + -if not request["user"] %a(href="/login") << Login -else - =view.request["user"]["username"] + =request["user"]["username"] ( %a(href="/logout") << Logout ) diff --git a/relay/frontend/page/admin-config.haml b/relay/frontend/page/admin-config.haml index e5df986..226c052 100644 --- a/relay/frontend/page/admin-config.haml +++ b/relay/frontend/page/admin-config.haml @@ -1,29 +1,32 @@ -extends "base.haml" -set page="Config" - --block head - %script(type="application/javascript" src="/static/config.js" nonce="{{view.request['hash']}}" defer) - -import "functions.haml" as func + -block content %fieldset.section %legend << Config .grid-2col %label(for="name") << Name + %i(class="bi bi-question-circle-fill" title="{{desc.name}}") %input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}") %label(for="note") << Description + %i(class="bi bi-question-circle-fill" title="{{desc.note}}") %textarea(id="note" value="{{config.note or ''}}") << {{config.note}} %label(for="theme") << Color Theme + %i(class="bi bi-question-circle-fill" title="{{desc.theme}}") =func.new_select("theme", config.theme, themes) %label(for="log-level") << Log Level + %i(class="bi bi-question-circle-fill" title="{{desc.log_level}}") =func.new_select("log-level", config.log_level.name, levels) %label(for="whitelist-enabled") << Whitelist + %i(class="bi bi-question-circle-fill" title="{{desc.whitelist_enabled}}") =func.new_checkbox("whitelist-enabled", config.whitelist_enabled) %label(for="approval-required") << Approval Required + %i(class="bi bi-question-circle-fill" title="{{desc.approval_required}}") =func.new_checkbox("approval-required", config.approval_required) diff --git a/relay/frontend/page/admin-domain_bans.haml b/relay/frontend/page/admin-domain_bans.haml index b1f7f57..8aa6728 100644 --- a/relay/frontend/page/admin-domain_bans.haml +++ b/relay/frontend/page/admin-domain_bans.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Domain Bans" --block head - %script(type="application/javascript" src="/static/domain_ban.js" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Ban Domain diff --git a/relay/frontend/page/admin-instances.haml b/relay/frontend/page/admin-instances.haml index c317e30..61e08a0 100644 --- a/relay/frontend/page/admin-instances.haml +++ b/relay/frontend/page/admin-instances.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Instances" --block head - %script(type="application/javascript" src="/static/instance.js" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Add Instance diff --git a/relay/frontend/page/admin-software_bans.haml b/relay/frontend/page/admin-software_bans.haml index 9bda3be..faaa57e 100644 --- a/relay/frontend/page/admin-software_bans.haml +++ b/relay/frontend/page/admin-software_bans.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Software Bans" --block head - %script(type="application/javascript" src="/static/software_ban.js" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Ban Software diff --git a/relay/frontend/page/admin-users.haml b/relay/frontend/page/admin-users.haml index 50058d7..d6715c9 100644 --- a/relay/frontend/page/admin-users.haml +++ b/relay/frontend/page/admin-users.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Users" --block head - %script(type="application/javascript" src="/static/user.js" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Add User diff --git a/relay/frontend/page/admin-whitelist.haml b/relay/frontend/page/admin-whitelist.haml index c8111e5..2fa3b99 100644 --- a/relay/frontend/page/admin-whitelist.haml +++ b/relay/frontend/page/admin-whitelist.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Whitelist" --block head - %script(type="application/javascript" src="/static/whitelist.js" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Add Domain diff --git a/relay/frontend/page/authorize_new.haml b/relay/frontend/page/authorize_new.haml new file mode 100644 index 0000000..4f07df3 --- /dev/null +++ b/relay/frontend/page/authorize_new.haml @@ -0,0 +1,31 @@ +-extends "base.haml" +-set page="App Authorization" + +-block content + %fieldset.section + %legend << App Authorization + + -if application.website + #title << Application "{{application.name}}" wants full API access + + -else + #title << Application "{{application.name}}" wants full API access + + #buttons + .spacer + + %form(action="/oauth/authorize" method="POST") + %input(type="hidden" name="client_id" value="{{application.client_id}}") + %input(type="hidden" name="client_secret" value="{{application.client_secret}}") + %input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}") + %input(type="hidden" name="response" value="true") + %input.button(type="submit" value="Allow") + + %form(action="/oauth/authorize" method="POST") + %input(type="hidden" name="client_id" value="{{application.client_id}}") + %input(type="hidden" name="client_secret" value="{{application.client_secret}}") + %input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}") + %input(type="hidden" name="response" value="false") + %input.button(type="submit" value="Deny") + + .spacer diff --git a/relay/frontend/page/authorize_show.haml b/relay/frontend/page/authorize_show.haml new file mode 100644 index 0000000..19cde40 --- /dev/null +++ b/relay/frontend/page/authorize_show.haml @@ -0,0 +1,18 @@ +-extends "base.haml" +-set page="App Authorization" + +-block content + %fieldset.section + %legend << App Authorization Code + + -if application.website + %p + Copy the following code into + %a(href="{{application.website}}" target="_main") -> %code -> =application.name + + -else + %p + Copy the following code info + %code -> =application.name + + %pre#code -> =application.auth_code diff --git a/relay/frontend/page/error.haml b/relay/frontend/page/error.haml new file mode 100644 index 0000000..23a8935 --- /dev/null +++ b/relay/frontend/page/error.haml @@ -0,0 +1,7 @@ +-extends "base.haml" +-set page="Error" + +-block content + .section.error + .title << HTTP Error {{e.status}} + .body -> =e.message diff --git a/relay/frontend/page/home.haml b/relay/frontend/page/home.haml index fa883d6..1de9b14 100644 --- a/relay/frontend/page/home.haml +++ b/relay/frontend/page/home.haml @@ -1,5 +1,6 @@ -extends "base.haml" -set page = "Home" + -block content -if config.note .section @@ -14,9 +15,7 @@ %a(href="https://{{domain}}/actor") << https://{{domain}}/actor -if config.approval_required - %fieldset.section.message - %legend << Require Approval - + %div.section.message Follow requests require approval. You will need to wait for an admin to accept or deny your request. diff --git a/relay/frontend/page/login.haml b/relay/frontend/page/login.haml index bf1ab1c..4f29746 100644 --- a/relay/frontend/page/login.haml +++ b/relay/frontend/page/login.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Login" --block head - %script(type="application/javascript" src="/static/login.js" nonce="{{view.request['hash']}}" defer) - -block content %fieldset.section %legend << Login @@ -15,4 +12,6 @@ %label(for="password") << Password %input(id="password" name="password" placeholder="Password" type="password") + + %input#redir(type="hidden" name="redir" value="{{redir}}") %input.submit(type="button" value="Login") diff --git a/relay/frontend/static/api.js b/relay/frontend/static/api.js deleted file mode 100644 index e7f376a..0000000 --- a/relay/frontend/static/api.js +++ /dev/null @@ -1,132 +0,0 @@ -// toast notifications - -const notifications = document.querySelector("#notifications") - - -function remove_toast(toast) { - toast.classList.add("hide"); - - if (toast.timeoutId) { - clearTimeout(toast.timeoutId); - } - - setTimeout(() => toast.remove(), 300); -} - -function toast(text, type="error", timeout=5) { - const toast = document.createElement("li"); - toast.className = `section ${type}` - toast.innerHTML = `${text}✖` - - toast.querySelector("a").addEventListener("click", async (event) => { - event.preventDefault(); - await remove_toast(toast); - }); - - notifications.appendChild(toast); - toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000); -} - - -// menu - -const body = document.getElementById("container") -const menu = document.getElementById("menu"); -const menu_open = document.querySelector("#menu-open i"); -const menu_close = document.getElementById("menu-close"); - - -function toggle_menu() { - let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true"; - menu.attributes.visible.nodeValue = new_value; -} - - -menu_open.addEventListener("click", toggle_menu); -menu_close.addEventListener("click", toggle_menu); - -body.addEventListener("click", (event) => { - if (event.target === menu_open) { - return; - } - - menu.attributes.visible.nodeValue = "false"; -}); - -for (const elem of document.querySelectorAll("#menu-open div")) { - elem.addEventListener("click", toggle_menu); -} - - -// misc - -function get_date_string(date) { - var year = date.getUTCFullYear().toString(); - var month = (date.getUTCMonth() + 1).toString().padStart(2, "0"); - var day = date.getUTCDate().toString().padStart(2, "0"); - - return `${year}-${month}-${day}`; -} - - -function append_table_row(table, row_name, row) { - var table_row = table.insertRow(-1); - table_row.id = row_name; - - index = 0; - - for (var prop in row) { - if (Object.prototype.hasOwnProperty.call(row, prop)) { - var cell = table_row.insertCell(index); - cell.className = prop; - cell.innerHTML = row[prop]; - - index += 1; - } - } - - return table_row; -} - - -async function request(method, path, body = null) { - var headers = { - "Accept": "application/json" - } - - if (body !== null) { - headers["Content-Type"] = "application/json" - body = JSON.stringify(body) - } - - const response = await fetch("/api/" + path, { - method: method, - mode: "cors", - cache: "no-store", - redirect: "follow", - body: body, - headers: headers - }); - - const message = await response.json(); - - if (Object.hasOwn(message, "error")) { - throw new Error(message.error); - } - - if (Array.isArray(message)) { - message.forEach((msg) => { - if (Object.hasOwn(msg, "created")) { - msg.created = new Date(msg.created); - } - }); - - } else { - if (Object.hasOwn(message, "created")) { - console.log(message.created) - message.created = new Date(message.created); - } - } - - return message; -} diff --git a/relay/frontend/static/config.js b/relay/frontend/static/config.js deleted file mode 100644 index 417c48a..0000000 --- a/relay/frontend/static/config.js +++ /dev/null @@ -1,40 +0,0 @@ -const elems = [ - document.querySelector("#name"), - document.querySelector("#note"), - document.querySelector("#theme"), - document.querySelector("#log-level"), - document.querySelector("#whitelist-enabled"), - document.querySelector("#approval-required") -] - - -async function handle_config_change(event) { - params = { - key: event.target.id, - value: event.target.type === "checkbox" ? event.target.checked : event.target.value - } - - try { - await request("POST", "v1/config", params); - - } catch (error) { - toast(error); - return; - } - - if (params.key === "name") { - document.querySelector("#header .title").innerHTML = params.value; - document.querySelector("title").innerHTML = params.value; - } - - if (params.key === "theme") { - document.querySelector("link.theme").href = `/theme/${params.value}.css`; - } - - toast("Updated config", "message"); -} - - -for (const elem of elems) { - elem.addEventListener("change", handle_config_change); -} diff --git a/relay/frontend/static/domain_ban.js b/relay/frontend/static/domain_ban.js deleted file mode 100644 index 4de2ebf..0000000 --- a/relay/frontend/static/domain_ban.js +++ /dev/null @@ -1,123 +0,0 @@ -function create_ban_object(domain, reason, note) { - var text = '
\n'; - text += `${domain}\n`; - text += '
\n'; - text += `\n`; - text += `\n`; - text += `\n`; - text += `\n`; - text += ``; - text += '
'; - - return text; -} - - -function add_row_listeners(row) { - row.querySelector(".update-ban").addEventListener("click", async (event) => { - await update_ban(row.id); - }); - - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await unban(row.id); - }); -} - - -async function ban() { - var table = document.querySelector("table"); - var elems = { - domain: document.getElementById("new-domain"), - reason: document.getElementById("new-reason"), - note: document.getElementById("new-note") - } - - var values = { - domain: elems.domain.value.trim(), - reason: elems.reason.value.trim(), - note: elems.note.value.trim() - } - - if (values.domain === "") { - toast("Domain is required"); - return; - } - - try { - var ban = await request("POST", "v1/domain_ban", values); - - } catch (err) { - toast(err); - return - } - - var row = append_table_row(document.querySelector("table"), ban.domain, { - domain: create_ban_object(ban.domain, ban.reason, ban.note), - date: get_date_string(ban.created), - remove: `
` - }); - - add_row_listeners(row); - - elems.domain.value = null; - elems.reason.value = null; - elems.note.value = null; - - document.querySelector("details.section").open = false; - toast("Banned domain", "message"); -} - - -async function update_ban(domain) { - var row = document.getElementById(domain); - - var elems = { - "reason": row.querySelector("textarea.reason"), - "note": row.querySelector("textarea.note") - } - - var values = { - "domain": domain, - "reason": elems.reason.value, - "note": elems.note.value - } - - try { - await request("PATCH", "v1/domain_ban", values) - - } catch (error) { - toast(error); - return; - } - - row.querySelector("details").open = false; - toast("Updated baned domain", "message"); -} - - -async function unban(domain) { - try { - await request("DELETE", "v1/domain_ban", {"domain": domain}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(domain).remove(); - toast("Unbanned domain", "message"); -} - - -document.querySelector("#new-ban").addEventListener("click", async (event) => { - await ban(); -}); - -for (var row of document.querySelector("fieldset.section table").rows) { - if (!row.querySelector(".update-ban")) { - continue; - } - - add_row_listeners(row); -} diff --git a/relay/frontend/static/functions.js b/relay/frontend/static/functions.js new file mode 100644 index 0000000..3063223 --- /dev/null +++ b/relay/frontend/static/functions.js @@ -0,0 +1,864 @@ +// toast notifications + +const notifications = document.querySelector("#notifications") + + +function remove_toast(toast) { + toast.classList.add("hide"); + + if (toast.timeoutId) { + clearTimeout(toast.timeoutId); + } + + setTimeout(() => toast.remove(), 300); +} + +function toast(text, type="error", timeout=5) { + const toast = document.createElement("li"); + toast.className = `section ${type}` + toast.innerHTML = `${text}✖` + + toast.querySelector("a").addEventListener("click", async (event) => { + event.preventDefault(); + await remove_toast(toast); + }); + + notifications.appendChild(toast); + toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000); +} + + +// menu + +const body = document.getElementById("container") +const menu = document.getElementById("menu"); +const menu_open = document.querySelector("#menu-open i"); +const menu_close = document.getElementById("menu-close"); + + +function toggle_menu() { + let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true"; + menu.attributes.visible.nodeValue = new_value; +} + + +menu_open.addEventListener("click", toggle_menu); +menu_close.addEventListener("click", toggle_menu); + +body.addEventListener("click", (event) => { + if (event.target === menu_open) { + return; + } + + menu.attributes.visible.nodeValue = "false"; +}); + +for (const elem of document.querySelectorAll("#menu-open div")) { + elem.addEventListener("click", toggle_menu); +} + + +// misc + +function get_date_string(date) { + var year = date.getUTCFullYear().toString(); + var month = (date.getUTCMonth() + 1).toString().padStart(2, "0"); + var day = date.getUTCDate().toString().padStart(2, "0"); + + return `${year}-${month}-${day}`; +} + + +function append_table_row(table, row_name, row) { + var table_row = table.insertRow(-1); + table_row.id = row_name; + + index = 0; + + for (var prop in row) { + if (Object.prototype.hasOwnProperty.call(row, prop)) { + var cell = table_row.insertCell(index); + cell.className = prop; + cell.innerHTML = row[prop]; + + index += 1; + } + } + + return table_row; +} + + +async function request(method, path, body = null) { + var headers = { + "Accept": "application/json" + } + + if (body !== null) { + headers["Content-Type"] = "application/json" + body = JSON.stringify(body) + } + + const response = await fetch("/api/" + path, { + method: method, + mode: "cors", + cache: "no-store", + redirect: "follow", + body: body, + headers: headers + }); + + const message = await response.json(); + + if (Object.hasOwn(message, "error")) { + throw new Error(message.error); + } + + if (Array.isArray(message)) { + message.forEach((msg) => { + if (Object.hasOwn(msg, "created")) { + msg.created = new Date(msg.created); + } + }); + + } else { + if (Object.hasOwn(message, "created")) { + message.created = new Date(message.created); + } + } + + return message; +} + +// page functions + +function page_config() { + const elems = [ + document.querySelector("#name"), + document.querySelector("#note"), + document.querySelector("#theme"), + document.querySelector("#log-level"), + document.querySelector("#whitelist-enabled"), + document.querySelector("#approval-required") + ] + + + async function handle_config_change(event) { + params = { + key: event.target.id, + value: event.target.type === "checkbox" ? event.target.checked : event.target.value + } + + try { + await request("POST", "v1/config", params); + + } catch (error) { + toast(error); + return; + } + + if (params.key === "name") { + document.querySelector("#header .title").innerHTML = params.value; + document.querySelector("title").innerHTML = params.value; + } + + if (params.key === "theme") { + document.querySelector("link.theme").href = `/theme/${params.value}.css`; + } + + toast("Updated config", "message"); + } + + + document.querySelector("#name").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await handle_config_change(event); + } + }); + + document.querySelector("#note").addEventListener("keydown", async (event) => { + if (event.which === 13 && event.ctrlKey) { + await handle_config_change(event); + } + }); + + for (const elem of elems) { + elem.addEventListener("change", handle_config_change); + } +} + + +function page_domain_ban() { + function create_ban_object(domain, reason, note) { + var text = '
\n'; + text += `${domain}\n`; + text += '
\n'; + text += `\n`; + text += `\n`; + text += `\n`; + text += `\n`; + text += ``; + text += '
'; + + return text; + } + + + function add_row_listeners(row) { + row.querySelector(".update-ban").addEventListener("click", async (event) => { + await update_ban(row.id); + }); + + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await unban(row.id); + }); + } + + + async function ban() { + var table = document.querySelector("table"); + var elems = { + domain: document.getElementById("new-domain"), + reason: document.getElementById("new-reason"), + note: document.getElementById("new-note") + } + + var values = { + domain: elems.domain.value.trim(), + reason: elems.reason.value.trim(), + note: elems.note.value.trim() + } + + if (values.domain === "") { + toast("Domain is required"); + return; + } + + try { + var ban = await request("POST", "v1/domain_ban", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.querySelector("table"), ban.domain, { + domain: create_ban_object(ban.domain, ban.reason, ban.note), + date: get_date_string(ban.created), + remove: `
` + }); + + add_row_listeners(row); + + elems.domain.value = null; + elems.reason.value = null; + elems.note.value = null; + + document.querySelector("details.section").open = false; + toast("Banned domain", "message"); + } + + + async function update_ban(domain) { + var row = document.getElementById(domain); + + var elems = { + "reason": row.querySelector("textarea.reason"), + "note": row.querySelector("textarea.note") + } + + var values = { + "domain": domain, + "reason": elems.reason.value, + "note": elems.note.value + } + + try { + await request("PATCH", "v1/domain_ban", values) + + } catch (error) { + toast(error); + return; + } + + row.querySelector("details").open = false; + toast("Updated baned domain", "message"); + } + + + async function unban(domain) { + try { + await request("DELETE", "v1/domain_ban", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + toast("Unbanned domain", "message"); + } + + + document.querySelector("#new-ban").addEventListener("click", async (event) => { + await ban(); + }); + + for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await ban(); + } + }); + } + + for (var row of document.querySelector("fieldset.section table").rows) { + if (!row.querySelector(".update-ban")) { + continue; + } + + add_row_listeners(row); + } +} + + +function page_instance() { + function add_instance_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_instance(row.id); + }); + } + + + function add_request_listeners(row) { + row.querySelector(".approve a").addEventListener("click", async (event) => { + event.preventDefault(); + await req_response(row.id, true); + }); + + row.querySelector(".deny a").addEventListener("click", async (event) => { + event.preventDefault(); + await req_response(row.id, false); + }); + } + + + async function add_instance() { + var elems = { + actor: document.getElementById("new-actor"), + inbox: document.getElementById("new-inbox"), + followid: document.getElementById("new-followid"), + software: document.getElementById("new-software") + } + + var values = { + actor: elems.actor.value.trim(), + inbox: elems.inbox.value.trim(), + followid: elems.followid.value.trim(), + software: elems.software.value.trim() + } + + if (values.actor === "") { + toast("Actor is required"); + return; + } + + try { + var instance = await request("POST", "v1/instance", values); + + } catch (err) { + toast(err); + return + } + + row = append_table_row(document.getElementById("instances"), instance.domain, { + domain: `${instance.domain}`, + software: instance.software, + date: get_date_string(instance.created), + remove: `` + }); + + add_instance_listeners(row); + + elems.actor.value = null; + elems.inbox.value = null; + elems.followid.value = null; + elems.software.value = null; + + document.querySelector("details.section").open = false; + toast("Added instance", "message"); + } + + + async function del_instance(domain) { + try { + await request("DELETE", "v1/instance", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + } + + + async function req_response(domain, accept) { + params = { + "domain": domain, + "accept": accept + } + + try { + await request("POST", "v1/request", params); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + + if (document.getElementById("requests").rows.length < 2) { + document.querySelector("fieldset.requests").remove() + } + + if (!accept) { + toast("Denied instance request", "message"); + return; + } + + instances = await request("GET", `v1/instance`, null); + instances.forEach((instance) => { + if (instance.domain === domain) { + row = append_table_row(document.getElementById("instances"), instance.domain, { + domain: `${instance.domain}`, + software: instance.software, + date: get_date_string(instance.created), + remove: `` + }); + + add_instance_listeners(row); + } + }); + + toast("Accepted instance request", "message"); + } + + + document.querySelector("#add-instance").addEventListener("click", async (event) => { + await add_instance(); + }) + + for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_instance(); + } + }); + } + + for (var row of document.querySelector("#instances").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_instance_listeners(row); + } + + if (document.querySelector("#requests")) { + for (var row of document.querySelector("#requests").rows) { + if (!row.querySelector(".approve a")) { + continue; + } + + add_request_listeners(row); + } + } +} + + +function page_login() { + const fields = { + username: document.querySelector("#username"), + password: document.querySelector("#password"), + redir: document.querySelector("#redir") + }; + + async function login(event) { + const values = { + username: fields.username.value.trim(), + password: fields.password.value.trim(), + redir: fields.redir.value.trim() + } + + if (values.username === "" | values.password === "") { + toast("Username and/or password field is blank"); + return; + } + + try { + await request("POST", "v1/login", values); + + } catch (error) { + toast(error); + return; + } + + document.location = values.redir; + } + + + document.querySelector("#username").addEventListener("keydown", async (event) => { + if (event.which === 13) { + fields.password.focus(); + fields.password.select(); + } + }); + + document.querySelector("#password").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await login(event); + } + }); + + document.querySelector(".submit").addEventListener("click", login); +} + + +function page_software_ban() { + function create_ban_object(name, reason, note) { + var text = '
\n'; + text += `${name}\n`; + text += '
\n'; + text += `\n`; + text += `\n`; + text += `\n`; + text += `\n`; + text += ``; + text += '
'; + + return text; + } + + + function add_row_listeners(row) { + row.querySelector(".update-ban").addEventListener("click", async (event) => { + await update_ban(row.id); + }); + + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await unban(row.id); + }); + } + + + async function ban() { + var elems = { + name: document.getElementById("new-name"), + reason: document.getElementById("new-reason"), + note: document.getElementById("new-note") + } + + var values = { + name: elems.name.value.trim(), + reason: elems.reason.value, + note: elems.note.value + } + + if (values.name === "") { + toast("Domain is required"); + return; + } + + try { + var ban = await request("POST", "v1/software_ban", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.getElementById("bans"), ban.name, { + name: create_ban_object(ban.name, ban.reason, ban.note), + date: get_date_string(ban.created), + remove: `` + }); + + add_row_listeners(row); + + elems.name.value = null; + elems.reason.value = null; + elems.note.value = null; + + document.querySelector("details.section").open = false; + toast("Banned software", "message"); + } + + + async function update_ban(name) { + var row = document.getElementById(name); + + var elems = { + "reason": row.querySelector("textarea.reason"), + "note": row.querySelector("textarea.note") + } + + var values = { + "name": name, + "reason": elems.reason.value, + "note": elems.note.value + } + + try { + await request("PATCH", "v1/software_ban", values) + + } catch (error) { + toast(error); + return; + } + + row.querySelector("details").open = false; + toast("Updated software ban", "message"); + } + + + async function unban(name) { + try { + await request("DELETE", "v1/software_ban", {"name": name}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(name).remove(); + toast("Unbanned software", "message"); + } + + + document.querySelector("#new-ban").addEventListener("click", async (event) => { + await ban(); + }); + + for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await ban(); + } + }); + } + + for (var elem of document.querySelectorAll("#add-item textarea")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13 && event.ctrlKey) { + await ban(); + } + }); + } + + for (var row of document.querySelector("#bans").rows) { + if (!row.querySelector(".update-ban")) { + continue; + } + + add_row_listeners(row); + } +} + + +function page_user() { + function add_row_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_user(row.id); + }); + } + + + async function add_user() { + var elems = { + username: document.getElementById("new-username"), + password: document.getElementById("new-password"), + password2: document.getElementById("new-password2"), + handle: document.getElementById("new-handle") + } + + var values = { + username: elems.username.value.trim(), + password: elems.password.value.trim(), + password2: elems.password2.value.trim(), + handle: elems.handle.value.trim() + } + + if (values.username === "" | values.password === "" | values.password2 === "") { + toast("Username, password, and password2 are required"); + return; + } + + if (values.password !== values.password2) { + toast("Passwords do not match"); + return; + } + + try { + var user = await request("POST", "v1/user", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.querySelector("fieldset.section table"), user.username, { + domain: user.username, + handle: user.handle ? self.handle : "n/a", + date: get_date_string(user.created), + remove: `` + }); + + add_row_listeners(row); + + elems.username.value = null; + elems.password.value = null; + elems.password2.value = null; + elems.handle.value = null; + + document.querySelector("details.section").open = false; + toast("Created user", "message"); + } + + + async function del_user(username) { + try { + await request("DELETE", "v1/user", {"username": username}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(username).remove(); + toast("Deleted user", "message"); + } + + + document.querySelector("#new-user").addEventListener("click", async (event) => { + await add_user(); + }); + + for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_user(); + } + }); + } + + for (var row of document.querySelector("#users").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_row_listeners(row); + } +} + + +function page_whitelist() { + function add_row_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_whitelist(row.id); + }); + } + + + async function add_whitelist() { + var domain_elem = document.getElementById("new-domain"); + var domain = domain_elem.value.trim(); + + if (domain === "") { + toast("Domain is required"); + return; + } + + try { + var item = await request("POST", "v1/whitelist", {"domain": domain}); + + } catch (err) { + toast(err); + return; + } + + var row = append_table_row(document.getElementById("whitelist"), item.domain, { + domain: item.domain, + date: get_date_string(item.created), + remove: `` + }); + + add_row_listeners(row); + + domain_elem.value = null; + document.querySelector("details.section").open = false; + toast("Added domain", "message"); + } + + + async function del_whitelist(domain) { + try { + await request("DELETE", "v1/whitelist", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + toast("Removed domain", "message"); + } + + + document.querySelector("#new-item").addEventListener("click", async (event) => { + await add_whitelist(); + }); + + document.querySelector("#add-item").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_whitelist(); + } + }); + + for (var row of document.querySelector("fieldset.section table").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_row_listeners(row); + } +} + + +if (location.pathname.startsWith("/admin/config")) { + page_config(); + +} else if (location.pathname.startsWith("/admin/domain_bans")) { + page_domain_ban(); + +} else if (location.pathname.startsWith("/admin/instances")) { + page_instance(); + +} else if (location.pathname.startsWith("/admin/software_bans")) { + page_software_ban(); + +} else if (location.pathname.startsWith("/admin/users")) { + page_user(); + +} else if (location.pathname.startsWith("/admin/whitelist")) { + page_whitelist(); + +} else if (location.pathname.startsWith("/login")) { + page_login(); +} diff --git a/relay/frontend/static/instance.js b/relay/frontend/static/instance.js deleted file mode 100644 index a07b647..0000000 --- a/relay/frontend/static/instance.js +++ /dev/null @@ -1,145 +0,0 @@ -function add_instance_listeners(row) { - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await del_instance(row.id); - }); -} - - -function add_request_listeners(row) { - row.querySelector(".approve a").addEventListener("click", async (event) => { - event.preventDefault(); - await req_response(row.id, true); - }); - - row.querySelector(".deny a").addEventListener("click", async (event) => { - event.preventDefault(); - await req_response(row.id, false); - }); -} - - -async function add_instance() { - var elems = { - actor: document.getElementById("new-actor"), - inbox: document.getElementById("new-inbox"), - followid: document.getElementById("new-followid"), - software: document.getElementById("new-software") - } - - var values = { - actor: elems.actor.value.trim(), - inbox: elems.inbox.value.trim(), - followid: elems.followid.value.trim(), - software: elems.software.value.trim() - } - - if (values.actor === "") { - toast("Actor is required"); - return; - } - - try { - var instance = await request("POST", "v1/instance", values); - - } catch (err) { - toast(err); - return - } - - row = append_table_row(document.getElementById("instances"), instance.domain, { - domain: `${instance.domain}`, - software: instance.software, - date: get_date_string(instance.created), - remove: `` - }); - - add_instance_listeners(row); - - elems.actor.value = null; - elems.inbox.value = null; - elems.followid.value = null; - elems.software.value = null; - - document.querySelector("details.section").open = false; - toast("Added instance", "message"); -} - - -async function del_instance(domain) { - try { - await request("DELETE", "v1/instance", {"domain": domain}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(domain).remove(); -} - - -async function req_response(domain, accept) { - params = { - "domain": domain, - "accept": accept - } - - try { - await request("POST", "v1/request", params); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(domain).remove(); - - if (document.getElementById("requests").rows.length < 2) { - document.querySelector("fieldset.requests").remove() - } - - if (!accept) { - toast("Denied instance request", "message"); - return; - } - - instances = await request("GET", `v1/instance`, null); - instances.forEach((instance) => { - if (instance.domain === domain) { - row = append_table_row(document.getElementById("instances"), instance.domain, { - domain: `${instance.domain}`, - software: instance.software, - date: get_date_string(instance.created), - remove: `` - }); - - add_instance_listeners(row); - } - }); - - toast("Accepted instance request", "message"); -} - - -document.querySelector("#add-instance").addEventListener("click", async (event) => { - await add_instance(); -}) - -for (var row of document.querySelector("#instances").rows) { - if (!row.querySelector(".remove a")) { - continue; - } - - add_instance_listeners(row); -} - -if (document.querySelector("#requests")) { - for (var row of document.querySelector("#requests").rows) { - if (!row.querySelector(".approve a")) { - continue; - } - - add_request_listeners(row); - } -} diff --git a/relay/frontend/static/login.js b/relay/frontend/static/login.js deleted file mode 100644 index 9c68f17..0000000 --- a/relay/frontend/static/login.js +++ /dev/null @@ -1,29 +0,0 @@ -async function login(event) { - fields = { - username: document.querySelector("#username"), - password: document.querySelector("#password") - } - - values = { - username: fields.username.value.trim(), - password: fields.password.value.trim() - } - - if (values.username === "" | values.password === "") { - toast("Username and/or password field is blank"); - return; - } - - try { - await request("POST", "v1/token", values); - - } catch (error) { - toast(error); - return; - } - - document.location = "/"; -} - - -document.querySelector(".submit").addEventListener("click", login); diff --git a/relay/frontend/static/software_ban.js b/relay/frontend/static/software_ban.js deleted file mode 100644 index 663929a..0000000 --- a/relay/frontend/static/software_ban.js +++ /dev/null @@ -1,122 +0,0 @@ -function create_ban_object(name, reason, note) { - var text = '
\n'; - text += `${name}\n`; - text += '
\n'; - text += `\n`; - text += `\n`; - text += `\n`; - text += `\n`; - text += ``; - text += '
'; - - return text; -} - - -function add_row_listeners(row) { - row.querySelector(".update-ban").addEventListener("click", async (event) => { - await update_ban(row.id); - }); - - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await unban(row.id); - }); -} - - -async function ban() { - var elems = { - name: document.getElementById("new-name"), - reason: document.getElementById("new-reason"), - note: document.getElementById("new-note") - } - - var values = { - name: elems.name.value.trim(), - reason: elems.reason.value, - note: elems.note.value - } - - if (values.name === "") { - toast("Domain is required"); - return; - } - - try { - var ban = await request("POST", "v1/software_ban", values); - - } catch (err) { - toast(err); - return - } - - var row = append_table_row(document.getElementById("bans"), ban.name, { - name: create_ban_object(ban.name, ban.reason, ban.note), - date: get_date_string(ban.created), - remove: `` - }); - - add_row_listeners(row); - - elems.name.value = null; - elems.reason.value = null; - elems.note.value = null; - - document.querySelector("details.section").open = false; - toast("Banned software", "message"); -} - - -async function update_ban(name) { - var row = document.getElementById(name); - - var elems = { - "reason": row.querySelector("textarea.reason"), - "note": row.querySelector("textarea.note") - } - - var values = { - "name": name, - "reason": elems.reason.value, - "note": elems.note.value - } - - try { - await request("PATCH", "v1/software_ban", values) - - } catch (error) { - toast(error); - return; - } - - row.querySelector("details").open = false; - toast("Updated software ban", "message"); -} - - -async function unban(name) { - try { - await request("DELETE", "v1/software_ban", {"name": name}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(name).remove(); - toast("Unbanned software", "message"); -} - - -document.querySelector("#new-ban").addEventListener("click", async (event) => { - await ban(); -}); - -for (var row of document.querySelector("#bans").rows) { - if (!row.querySelector(".update-ban")) { - continue; - } - - add_row_listeners(row); -} diff --git a/relay/frontend/static/style.css b/relay/frontend/static/style.css index 6ec9316..ac4eaf5 100644 --- a/relay/frontend/static/style.css +++ b/relay/frontend/static/style.css @@ -12,7 +12,7 @@ body { color: var(--text); background-color: #222; margin: var(--spacing); - font-family: sans serif; + font-family: sans-serif; } details *:nth-child(2) { @@ -88,6 +88,7 @@ tbody tr:last-child td:last-child { table td { padding: 5px; + white-space: nowrap; } table thead td { @@ -282,8 +283,11 @@ textarea { width: 100%; } -.data-table .date { +.data-table td:not(:first-child) { width: max-content; +} + +.data-table .date { text-align: right; } @@ -297,13 +301,13 @@ textarea { border: 1px solid var(--error-border) !important; } +/* create .grid base class and .2col and 3col classes */ .grid-2col { display: grid; grid-template-columns: max-content auto; grid-gap: var(--spacing); margin-bottom: var(--spacing); align-items: center; - } .message { @@ -333,6 +337,48 @@ textarea { justify-self: left; } +#content.page-config .grid-2col { + grid-template-columns: max-content max-content auto; +} + + +/* error */ +#content.page-error { + text-align: center; +} + +#content.page-error .title { + font-size: 24px; + font-weight: bold; +} + + +/* auth */ +#content.page-app_authorization { + text-align: center; +} + +#content.page-app_authorization #code { + background: var(--background); + border: 1px solid var(--border); + font-size: 18px; + margin: 0 auto; + width: max-content; + padding: 5px; +} + +#content.page-app_authorization #title { + font-size: 24px; +} + +#content.page-app_authorization #buttons { + display: grid; + grid-template-columns: auto max-content max-content auto; + grid-gap: var(--spacing); + justify-items: center; + margin: var(--spacing) 0; +} + @keyframes show_toast { 0% { diff --git a/relay/frontend/static/user.js b/relay/frontend/static/user.js deleted file mode 100644 index 9c74359..0000000 --- a/relay/frontend/static/user.js +++ /dev/null @@ -1,85 +0,0 @@ -function add_row_listeners(row) { - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await del_user(row.id); - }); -} - - -async function add_user() { - var elems = { - username: document.getElementById("new-username"), - password: document.getElementById("new-password"), - password2: document.getElementById("new-password2"), - handle: document.getElementById("new-handle") - } - - var values = { - username: elems.username.value.trim(), - password: elems.password.value.trim(), - password2: elems.password2.value.trim(), - handle: elems.handle.value.trim() - } - - if (values.username === "" | values.password === "" | values.password2 === "") { - toast("Username, password, and password2 are required"); - return; - } - - if (values.password !== values.password2) { - toast("Passwords do not match"); - return; - } - - try { - var user = await request("POST", "v1/user", values); - - } catch (err) { - toast(err); - return - } - - var row = append_table_row(document.querySelector("fieldset.section table"), user.username, { - domain: user.username, - handle: user.handle ? self.handle : "n/a", - date: get_date_string(user.created), - remove: `` - }); - - add_row_listeners(row); - - elems.username.value = null; - elems.password.value = null; - elems.password2.value = null; - elems.handle.value = null; - - document.querySelector("details.section").open = false; - toast("Created user", "message"); -} - - -async function del_user(username) { - try { - await request("DELETE", "v1/user", {"username": username}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(username).remove(); - toast("Deleted user", "message"); -} - - -document.querySelector("#new-user").addEventListener("click", async (event) => { - await add_user(); -}); - -for (var row of document.querySelector("#users").rows) { - if (!row.querySelector(".remove a")) { - continue; - } - - add_row_listeners(row); -} diff --git a/relay/frontend/static/whitelist.js b/relay/frontend/static/whitelist.js deleted file mode 100644 index 70d4db1..0000000 --- a/relay/frontend/static/whitelist.js +++ /dev/null @@ -1,64 +0,0 @@ -function add_row_listeners(row) { - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await del_whitelist(row.id); - }); -} - - -async function add_whitelist() { - var domain_elem = document.getElementById("new-domain"); - var domain = domain_elem.value.trim(); - - if (domain === "") { - toast("Domain is required"); - return; - } - - try { - var item = await request("POST", "v1/whitelist", {"domain": domain}); - - } catch (err) { - toast(err); - return; - } - - var row = append_table_row(document.getElementById("whitelist"), item.domain, { - domain: item.domain, - date: get_date_string(item.created), - remove: `` - }); - - add_row_listeners(row); - - domain_elem.value = null; - document.querySelector("details.section").open = false; - toast("Added domain", "message"); -} - - -async function del_whitelist(domain) { - try { - await request("DELETE", "v1/whitelist", {"domain": domain}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(domain).remove(); - toast("Removed domain", "message"); -} - - -document.querySelector("#new-item").addEventListener("click", async (event) => { - await add_whitelist(); -}); - -for (var row of document.querySelector("fieldset.section table").rows) { - if (!row.querySelector(".remove a")) { - continue; - } - - add_row_listeners(row); -} diff --git a/relay/http_client.py b/relay/http_client.py index 54cea3c..ef25881 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -1,20 +1,16 @@ from __future__ import annotations import json -import traceback from aiohttp import ClientSession, ClientTimeout, TCPConnector -from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo -from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from blib import JsonBase -from bsql import Row -from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING, Any, TypeVar -from urllib.parse import urlparse +from blib import HttpError, JsonBase +from typing import TYPE_CHECKING, Any, TypeVar, overload from . import __version__, logger as logging from .cache import Cache +from .database.schema import Instance +from .errors import EmptyBodyError from .misc import MIMETYPES, Message, get_app if TYPE_CHECKING: @@ -36,7 +32,7 @@ SUPPORTS_HS2019 = { 'sharkey' } -T = TypeVar('T', bound = JsonBase) +T = TypeVar('T', bound = JsonBase[Any]) HEADERS = { 'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9', 'User-Agent': f'ActivityRelay/{__version__}' @@ -107,21 +103,17 @@ class HttpClient: url: str, sign_headers: bool, force: bool, - old_algo: bool) -> dict[str, Any] | None: + old_algo: bool) -> str | None: if not self._session: raise RuntimeError('Client not open') - try: - url, _ = url.split('#', 1) - - except ValueError: - pass + url = url.split("#", 1)[0] if not force: try: if not (item := self.cache.get('request', url)).older_than(48): - return json.loads(item.value) # type: ignore[no-any-return] + return item.value # type: ignore [no-any-return] except KeyError: logging.verbose('No cached data for url: %s', url) @@ -132,67 +124,77 @@ class HttpClient: algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019 headers = self.signer.sign_headers('GET', url, algorithm = algo) - try: - logging.debug('Fetching resource: %s', url) + logging.debug('Fetching resource: %s', url) - async with self._session.get(url, headers = headers) as resp: - # Not expecting a response with 202s, so just return - if resp.status == 202: - return None - - data = await resp.text() - - if resp.status != 200: - logging.verbose('Received error when requesting %s: %i', url, resp.status) - logging.debug(data) + async with self._session.get(url, headers = headers) as resp: + # Not expecting a response with 202s, so just return + if resp.status == 202: return None - self.cache.set('request', url, data, 'str') - logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4)) + data = await resp.text() - return json.loads(data) # type: ignore [no-any-return] - - except JSONDecodeError: - logging.verbose('Failed to parse JSON') + if resp.status not in (200, 202): + logging.verbose('Received error when requesting %s: %i', url, resp.status) logging.debug(data) - return None - except ClientSSLError as e: - logging.verbose('SSL error when connecting to %s', urlparse(url).netloc) - logging.warning(str(e)) + try: + error = json.loads(data)["error"] - except (AsyncTimeoutError, ClientConnectionError) as e: - logging.verbose('Failed to connect to %s', urlparse(url).netloc) - logging.warning(str(e)) + except Exception: + error = data - except Exception: - traceback.print_exc() + raise HttpError(resp.status, error) - return None + self.cache.set('request', url, data, 'str') + return data + + + @overload + async def get(self, + url: str, + sign_headers: bool, + cls: None = None, + force: bool = False, + old_algo: bool = True) -> str | None: ... + + + @overload + async def get(self, + url: str, + sign_headers: bool, + cls: type[T] = JsonBase, # type: ignore[assignment] + force: bool = False, + old_algo: bool = True) -> T: ... async def get(self, url: str, sign_headers: bool, - cls: type[T], + cls: type[T] | None = None, force: bool = False, - old_algo: bool = True) -> T | None: + old_algo: bool = True) -> T | str | None: - if not issubclass(cls, JsonBase): + if cls is not None and not issubclass(cls, JsonBase): raise TypeError('cls must be a sub-class of "blib.JsonBase"') - if (data := (await self._get(url, sign_headers, force, old_algo))) is None: - return None + data = await self._get(url, sign_headers, force, old_algo) - return cls.parse(data) + if cls is not None: + if data is None: + # this shouldn't actually get raised, but keeping just in case + raise EmptyBodyError(f"GET {url}") + + return cls.parse(data) + + return data - async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None: + async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None: if not self._session: raise RuntimeError('Client not open') # akkoma and pleroma do not support HS2019 and other software still needs to be tested - if instance and instance['software'] in SUPPORTS_HS2019: + if instance is not None and instance.software in SUPPORTS_HS2019: algorithm = AlgorithmType.HS2019 else: @@ -218,46 +220,27 @@ class HttpClient: algorithm = algorithm ) - try: - logging.verbose('Sending "%s" to %s', mtype, url) + logging.verbose('Sending "%s" to %s', mtype, url) - async with self._session.post(url, headers = headers, data = body) as resp: - # Not expecting a response, so just return - if resp.status in {200, 202}: - logging.verbose('Successfully sent "%s" to %s', mtype, url) - return - - logging.verbose('Received error when pushing to %s: %i', url, resp.status) - logging.debug(await resp.read()) - logging.debug("message: %s", body.decode("utf-8")) - logging.debug("headers: %s", json.dumps(headers, indent = 4)) + async with self._session.post(url, headers = headers, data = body) as resp: + # Not expecting a response, so just return + if resp.status in {200, 202}: + logging.verbose('Successfully sent "%s" to %s', mtype, url) return - except ClientSSLError as e: - logging.warning('SSL error when pushing to %s', urlparse(url).netloc) - logging.warning(str(e)) - - except (AsyncTimeoutError, ClientConnectionError) as e: - logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) - logging.warning(str(e)) - - # prevent workers from being brought down - except Exception: - traceback.print_exc() + logging.error('Received error when pushing to %s: %i', url, resp.status) + logging.debug(await resp.read()) + logging.debug("message: %s", body.decode("utf-8")) + logging.debug("headers: %s", json.dumps(headers, indent = 4)) + return - async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None: + async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo: nodeinfo_url = None wk_nodeinfo = await self.get( - f'https://{domain}/.well-known/nodeinfo', - False, - WellKnownNodeinfo + f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force ) - if wk_nodeinfo is None: - logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) - return None - for version in ('20', '21'): try: nodeinfo_url = wk_nodeinfo.get_url(version) @@ -266,10 +249,9 @@ class HttpClient: pass if nodeinfo_url is None: - logging.verbose('Failed to fetch nodeinfo url for %s', domain) - return None + raise ValueError(f'Failed to fetch nodeinfo url for {domain}') - return await self.get(nodeinfo_url, False, Nodeinfo) + return await self.get(nodeinfo_url, False, Nodeinfo, force) async def get(*args: Any, **kwargs: Any) -> Any: diff --git a/relay/logger.py b/relay/logger.py index f1a1bd7..7caac9f 100644 --- a/relay/logger.py +++ b/relay/logger.py @@ -1,16 +1,15 @@ +from __future__ import annotations + import logging import os from enum import IntEnum from pathlib import Path -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol -try: +if TYPE_CHECKING: from typing import Self -except ImportError: - from typing_extensions import Self - class LoggingMethod(Protocol): def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ... diff --git a/relay/manage.py b/relay/manage.py index cb2b099..b76443d 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -6,7 +6,6 @@ import click import json import os -from bsql import Row from pathlib import Path from shutil import copyfile from typing import Any @@ -17,7 +16,8 @@ from . import http_client as http from . import logger as logging from .application import Application from .compat import RelayConfig, RelayDatabase -from .database import RELAY_SOFTWARE, get_database +from .config import Config +from .database import RELAY_SOFTWARE, get_database, schema from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message @@ -213,6 +213,24 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None: os._exit(0) +@cli.command('db-maintenance') +@click.option('--fix-timestamps', '-t', is_flag = True, + help = 'Make sure timestamps in the database are float values') +@click.pass_context +def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None: + 'Perform maintenance tasks on the database' + + if fix_timestamps: + with ctx.obj.database.session(True) as conn: + conn.fix_timestamps() + + if ctx.obj.config.db_type == "postgres": + return + + with ctx.obj.database.session(False) as conn: + with conn.execute("VACUUM"): + pass + @cli.command('convert') @click.option('--old-config', '-o', help = 'Path to the config file to convert from') @@ -240,18 +258,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: ctx.obj.config.set('domain', config['host']) ctx.obj.config.save() + # fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7 with get_database(ctx.obj.config) as db: with db.session(True) as conn: conn.put_config('private-key', database['private-key']) conn.put_config('note', config['note']) conn.put_config('whitelist-enabled', config['whitelist_enabled']) - with click.progressbar( # type: ignore + with click.progressbar( database['relay-list'].values(), label = 'Inboxes'.ljust(15), width = 0 ) as inboxes: - for inbox in inboxes: if inbox['software'] in {'akkoma', 'pleroma'}: actor = f'https://{inbox["domain"]}/relay' @@ -270,7 +288,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: software = inbox['software'] ) - with click.progressbar( # type: ignore + with click.progressbar( config['blocked_software'], label = 'Banned software'.ljust(15), width = 0 @@ -282,7 +300,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: reason = 'relay' if software in RELAY_SOFTWARE else None ) - with click.progressbar( # type: ignore + with click.progressbar( config['blocked_instances'], label = 'Banned domains'.ljust(15), width = 0 @@ -291,7 +309,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: for domain in banned_software: conn.put_domain_ban(domain) - with click.progressbar( # type: ignore + with click.progressbar( config['whitelist'], label = 'Whitelist'.ljust(15), width = 0 @@ -315,6 +333,33 @@ def cli_editconfig(ctx: click.Context, editor: str) -> None: ) +@cli.command('switch-backend') +@click.pass_context +def cli_switchbackend(ctx: click.Context) -> None: + """ + Copy the database from one backend to the other + + Be sure to set the database type to the backend you want to convert from. For instance, set + the database type to `sqlite`, fill out the connection details for postgresql, and the + data from the sqlite database will be copied to the postgresql database. This only works if + the database in postgresql already exists. + """ + + config = Config(ctx.obj.config.path, load = True) + config.db_type = "sqlite" if config.db_type == "postgres" else "postgres" + database = get_database(config, migrate = False) + + with database.session(True) as new, ctx.obj.database.session(False) as old: + new.create_tables() + + for table in schema.TABLES.keys(): + for row in old.execute(f"SELECT * FROM {table}"): + new.insert(table, row).close() + + config.save() + click.echo(f"Converted database to {repr(config.db_type)}") + + @cli.group('config') def cli_config() -> None: 'Manage the relay settings stored in the database' @@ -348,10 +393,15 @@ def cli_config_list(ctx: click.Context) -> None: def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: 'Set a config value' - with ctx.obj.database.session() as conn: - new_value = conn.put_config(key, value) + try: + with ctx.obj.database.session() as conn: + new_value = conn.put_config(key, value) - print(f'{key}: {repr(new_value)}') + except Exception: + click.echo(f'Invalid config name: {key}') + return + + click.echo(f'{key}: {repr(new_value)}') @cli.group('user') @@ -367,8 +417,8 @@ def cli_user_list(ctx: click.Context) -> None: click.echo('Users:') with ctx.obj.database.session() as conn: - for user in conn.execute('SELECT * FROM users'): - click.echo(f'- {user["username"]}') + for row in conn.get_users(): + click.echo(f'- {row.username}') @cli_user.command('create') @@ -379,7 +429,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: 'Create a new local user' with ctx.obj.database.session() as conn: - if conn.get_user(username): + if conn.get_user(username) is not None: click.echo(f'User already exists: {username}') return @@ -406,7 +456,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None: 'Delete a local user' with ctx.obj.database.session() as conn: - if not conn.get_user(username): + if conn.get_user(username) is None: click.echo(f'User does not exist: {username}') return @@ -424,8 +474,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None: click.echo(f'Tokens for "{username}":') with ctx.obj.database.session() as conn: - for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): - click.echo(f'- {token["code"]}') + for row in conn.get_tokens(username): + click.echo(f'- {row.code}') @cli_user.command('create-token') @@ -435,13 +485,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None: 'Create a new API token for a user' with ctx.obj.database.session() as conn: - if not (user := conn.get_user(username)): + if (user := conn.get_user(username)) is None: click.echo(f'User does not exist: {username}') return - token = conn.put_token(user['username']) + token = conn.put_token(user.username) - click.echo(f'New token for "{username}": {token["code"]}') + click.echo(f'New token for "{username}": {token.code}') @cli_user.command('delete-token') @@ -451,7 +501,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None: 'Delete an API token' with ctx.obj.database.session() as conn: - if not conn.get_token(code): + if conn.get_token(code) is None: click.echo('Token does not exist') return @@ -473,8 +523,8 @@ def cli_inbox_list(ctx: click.Context) -> None: click.echo('Connected to the following instances or relays:') with ctx.obj.database.session() as conn: - for inbox in conn.get_inboxes(): - click.echo(f'- {inbox["inbox"]}') + for row in conn.get_inboxes(): + click.echo(f'- {row.inbox}') @cli_inbox.command('follow') @@ -483,19 +533,21 @@ def cli_inbox_list(ctx: click.Context) -> None: def cli_inbox_follow(ctx: click.Context, actor: str) -> None: 'Follow an actor (Relay must be running)' + instance: schema.Instance | None = None + with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}') return - if (inbox_data := conn.get_inbox(actor)): - inbox = inbox_data['inbox'] + if (instance := conn.get_inbox(actor)) is not None: + inbox = instance.inbox else: if not actor.startswith('http'): actor = f'https://{actor}/actor' - if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))): + if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None: click.echo(f'Failed to fetch actor: {actor}') return @@ -506,7 +558,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None: actor = actor ) - asyncio.run(http.post(inbox, message, inbox_data)) + asyncio.run(http.post(inbox, message, instance)) click.echo(f'Sent follow message to actor: {actor}') @@ -516,19 +568,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None: def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: 'Unfollow an actor (Relay must be running)' - inbox_data: Row | None = None + instance: schema.Instance | None = None with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}') return - if (inbox_data := conn.get_inbox(actor)): - inbox = inbox_data['inbox'] + if (instance := conn.get_inbox(actor)): + inbox = instance.inbox message = Message.new_unfollow( host = ctx.obj.config.domain, actor = actor, - follow = inbox_data['followid'] + follow = instance.followid ) else: @@ -552,7 +604,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: } ) - asyncio.run(http.post(inbox, message, inbox_data)) + asyncio.run(http.post(inbox, message, instance)) click.echo(f'Sent unfollow message to: {actor}') @@ -632,9 +684,9 @@ def cli_request_list(ctx: click.Context) -> None: click.echo('Follow requests:') with ctx.obj.database.session() as conn: - for instance in conn.get_requests(): - date = instance['created'].strftime('%Y-%m-%d') - click.echo(f'- [{date}] {instance["domain"]}') + for row in conn.get_requests(): + date = row.created.strftime('%Y-%m-%d') + click.echo(f'- [{date}] {row.domain}') @cli_request.command('accept') @@ -653,20 +705,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None: message = Message.new_response( host = ctx.obj.config.domain, - actor = instance['actor'], - followid = instance['followid'], + actor = instance.actor, + followid = instance.followid, accept = True ) - asyncio.run(http.post(instance['inbox'], message, instance)) + asyncio.run(http.post(instance.inbox, message, instance)) - if instance['software'] != 'mastodon': + if instance.software != 'mastodon': message = Message.new_follow( host = ctx.obj.config.domain, - actor = instance['actor'] + actor = instance.actor ) - asyncio.run(http.post(instance['inbox'], message, instance)) + asyncio.run(http.post(instance.inbox, message, instance)) @cli_request.command('deny') @@ -685,12 +737,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None: response = Message.new_response( host = ctx.obj.config.domain, - actor = instance['actor'], - followid = instance['followid'], + actor = instance.actor, + followid = instance.followid, accept = False ) - asyncio.run(http.post(instance['inbox'], response, instance)) + asyncio.run(http.post(instance.inbox, response, instance)) @cli.group('instance') @@ -706,12 +758,12 @@ def cli_instance_list(ctx: click.Context) -> None: click.echo('Banned domains:') with ctx.obj.database.session() as conn: - for instance in conn.execute('SELECT * FROM domain_bans'): - if instance['reason']: - click.echo(f'- {instance["domain"]} ({instance["reason"]})') + for row in conn.get_domain_bans(): + if row.reason is not None: + click.echo(f'- {row.domain} ({row.reason})') else: - click.echo(f'- {instance["domain"]}') + click.echo(f'- {row.domain}') @cli_instance.command('ban') @@ -723,7 +775,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> 'Ban an instance and remove the associated inbox if it exists' with ctx.obj.database.session() as conn: - if conn.get_domain_ban(domain): + if conn.get_domain_ban(domain) is not None: click.echo(f'Domain already banned: {domain}') return @@ -739,7 +791,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None: 'Unban an instance' with ctx.obj.database.session() as conn: - if not conn.del_domain_ban(domain): + if conn.del_domain_ban(domain) is None: click.echo(f'Instance wasn\'t banned: {domain}') return @@ -764,11 +816,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) click.echo(f'Updated domain ban: {domain}') - if row['reason']: - click.echo(f'- {row["domain"]} ({row["reason"]})') + if row.reason: + click.echo(f'- {row.domain} ({row.reason})') else: - click.echo(f'- {row["domain"]}') + click.echo(f'- {row.domain}') @cli.group('software') @@ -784,12 +836,12 @@ def cli_software_list(ctx: click.Context) -> None: click.echo('Banned software:') with ctx.obj.database.session() as conn: - for software in conn.execute('SELECT * FROM software_bans'): - if software['reason']: - click.echo(f'- {software["name"]} ({software["reason"]})') + for row in conn.get_software_bans(): + if row.reason: + click.echo(f'- {row.name} ({row.reason})') else: - click.echo(f'- {software["name"]}') + click.echo(f'- {row.name}') @cli_software.command('ban') @@ -811,12 +863,12 @@ def cli_software_ban(ctx: click.Context, with ctx.obj.database.session() as conn: if name == 'RELAYS': - for software in RELAY_SOFTWARE: - if conn.get_software_ban(software): - click.echo(f'Relay already banned: {software}') + for item in RELAY_SOFTWARE: + if conn.get_software_ban(item): + click.echo(f'Relay already banned: {item}') continue - conn.put_software_ban(software, reason or 'relay', note) + conn.put_software_ban(item, reason or 'relay', note) click.echo('Banned all relay software') return @@ -893,11 +945,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) - click.echo(f'Updated software ban: {name}') - if row['reason']: - click.echo(f'- {row["name"]} ({row["reason"]})') + if row.reason: + click.echo(f'- {row.name} ({row.reason})') else: - click.echo(f'- {row["name"]}') + click.echo(f'- {row.name}') @cli.group('whitelist') @@ -913,8 +965,8 @@ def cli_whitelist_list(ctx: click.Context) -> None: click.echo('Current whitelisted domains:') with ctx.obj.database.session() as conn: - for domain in conn.execute('SELECT * FROM whitelist'): - click.echo(f'- {domain["domain"]}') + for row in conn.get_domain_whitelist(): + click.echo(f'- {row.domain}') @cli_whitelist.command('add') @@ -953,23 +1005,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None: @cli_whitelist.command('import') @click.pass_context def cli_whitelist_import(ctx: click.Context) -> None: - 'Add all current inboxes to the whitelist' + 'Add all current instances to the whitelist' with ctx.obj.database.session() as conn: - for inbox in conn.execute('SELECT * FROM inboxes').all(): - if conn.get_domain_whitelist(inbox['domain']): - click.echo(f'Domain already in whitelist: {inbox["domain"]}') + for row in conn.get_inboxes(): + if conn.get_domain_whitelist(row.domain) is not None: + click.echo(f'Domain already in whitelist: {row.domain}') continue - conn.put_domain_whitelist(inbox['domain']) + conn.put_domain_whitelist(row.domain) click.echo('Imported whitelist from inboxes') def main() -> None: - cli(prog_name='relay') - - -if __name__ == '__main__': - click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.') + cli(prog_name='activityrelay') diff --git a/relay/misc.py b/relay/misc.py index 9e8f035..aa44956 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -9,23 +9,13 @@ import socket from aiohttp.web import Response as AiohttpResponse from collections.abc import Sequence from datetime import datetime +from importlib.resources import files as pkgfiles from pathlib import Path from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from uuid import uuid4 -try: - from importlib.resources import files as pkgfiles - -except ImportError: - from importlib_resources import files as pkgfiles # type: ignore - -try: - from typing import Self - -except ImportError: - from typing_extensions import Self - if TYPE_CHECKING: + from typing import Self from .application import Application @@ -72,6 +62,27 @@ SOFTWARE = ( 'gotosocial' ) +JSON_PATHS: tuple[str, ...] = ( + '/api/v1', + '/actor', + '/inbox', + '/outbox', + '/following', + '/followers', + '/.well-known', + '/nodeinfo', + '/oauth/token', + '/oauth/revoke' +) + +TOKEN_PATHS: tuple[str, ...] = ( + '/logout', + '/admin', + '/api', + '/oauth/authorize', + '/oauth/revoke' +) + def boolean(value: Any) -> bool: if isinstance(value, str): @@ -252,9 +263,9 @@ class Response(AiohttpResponse): @classmethod - def new_redir(cls: type[Self], path: str) -> Self: + def new_redir(cls: type[Self], path: str, status: int = 307) -> Self: body = f'Redirect to {path}' - return cls.new(body, 302, {'Location': path}) + return cls.new(body, status, {'Location': path}, ctype = 'html') @property diff --git a/relay/processors.py b/relay/processors.py index cd742ec..57e9222 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None: logging.debug('>> relay: %s', message) for instance in conn.distill_inboxes(view.message): - view.app.push_message(instance["inbox"], message, instance) + view.app.push_message(instance.inbox, message, instance) view.cache.set('handle-relay', view.message.object_id, message.id, 'str') @@ -52,13 +52,13 @@ async def handle_forward(view: ActorView, conn: Connection) -> None: logging.debug('>> forward: %s', message) for instance in conn.distill_inboxes(view.message): - view.app.push_message(instance["inbox"], view.message, instance) + view.app.push_message(instance.inbox, view.message, instance) view.cache.set('handle-relay', view.message.id, message.id, 'str') async def handle_follow(view: ActorView, conn: Connection) -> None: - nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) + nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain, force = True) software = nodeinfo.sw_name if nodeinfo else None config = conn.get_config_all() @@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None: return # prevent past unfollows from removing an instance - if view.instance['followid'] and view.instance['followid'] != view.message.object_id: + if view.instance.followid and view.instance.followid != view.message.object_id: return with conn.transaction(): @@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None: with view.database.session() as conn: if view.instance: - if not view.instance['software']: - if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): + if not view.instance.software: + if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)): with conn.transaction(): view.instance = conn.put_inbox( - domain = view.instance['domain'], + domain = view.instance.domain, software = nodeinfo.sw_name ) - if not view.instance['actor']: + if not view.instance.actor: with conn.transaction(): view.instance = conn.put_inbox( - domain = view.instance['domain'], + domain = view.instance.domain, actor = view.actor.id ) diff --git a/relay/template.py b/relay/template.py index ef25f92..3ee2855 100644 --- a/relay/template.py +++ b/relay/template.py @@ -2,6 +2,7 @@ from __future__ import annotations import textwrap +from aiohttp.web import Request from collections.abc import Callable from hamlish_jinja import HamlishExtension from jinja2 import Environment, FileSystemLoader @@ -13,13 +14,15 @@ from typing import TYPE_CHECKING, Any from . import __version__ from .misc import get_resource -from .views.base import View if TYPE_CHECKING: from .application import Application class Template(Environment): + _render_markdown: Callable[[str], str] + + def __init__(self, app: Application): Environment.__init__(self, autoescape = True, @@ -40,12 +43,12 @@ class Template(Environment): self.hamlish_mode = 'indented' - def render(self, path: str, view: View | None = None, **context: Any) -> str: + def render(self, path: str, request: Request, **context: Any) -> str: with self.app.database.session(False) as conn: config = conn.get_config_all() new_context = { - 'view': view, + 'request': request, 'domain': self.app.config.domain, 'version': __version__, 'config': config, @@ -56,7 +59,7 @@ class Template(Environment): def render_markdown(self, text: str) -> str: - return self._render_markdown(text) # type: ignore + return self._render_markdown(text) class MarkdownExtension(Extension): diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index b19b7e1..4551c88 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -1,26 +1,23 @@ -from __future__ import annotations - import aputils import traceback -import typing + +from aiohttp.web import Request +from blib import HttpError from .base import View, register_route from .. import logger as logging +from ..database import schema from ..misc import Message, Response from ..processors import run_processor -if typing.TYPE_CHECKING: - from aiohttp.web import Request - from bsql import Row - @register_route('/actor', '/inbox') class ActorView(View): signature: aputils.Signature message: Message actor: Message - instancce: Row + instance: schema.Instance signer: aputils.Signer @@ -43,11 +40,10 @@ class ActorView(View): async def post(self, request: Request) -> Response: - if response := await self.get_post_data(): - return response + await self.get_post_data() with self.database.session() as conn: - self.instance = conn.get_inbox(self.actor.shared_inbox) + self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment] # reject if actor is banned if conn.get_domain_ban(self.actor.domain): @@ -69,13 +65,13 @@ class ActorView(View): return Response.new(status = 202) - async def get_post_data(self) -> Response | None: + async def get_post_data(self) -> None: try: self.signature = aputils.Signature.parse(self.request.headers['signature']) except KeyError: logging.verbose('Missing signature header') - return Response.new_error(400, 'missing signature header', 'json') + raise HttpError(400, 'missing signature header') try: message: Message | None = await self.request.json(loads = Message.parse) @@ -83,46 +79,47 @@ class ActorView(View): except Exception: traceback.print_exc() logging.verbose('Failed to parse inbox message') - return Response.new_error(400, 'failed to parse message', 'json') + raise HttpError(400, 'failed to parse message') if message is None: logging.verbose('empty message') - return Response.new_error(400, 'missing message', 'json') + raise HttpError(400, 'missing message') self.message = message if 'actor' not in self.message: logging.verbose('actor not in message') - return Response.new_error(400, 'no actor in message', 'json') + raise HttpError(400, 'no actor in message') - actor: Message | None = await self.client.get(self.signature.keyid, True, Message) + try: + self.actor = await self.client.get(self.signature.keyid, True, Message) - if actor is None: + except HttpError: # ld signatures aren't handled atm, so just ignore it if self.message.type == 'Delete': logging.verbose('Instance sent a delete which cannot be handled') - return Response.new(status=202) + raise HttpError(202, '') - logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') - return Response.new_error(400, 'failed to fetch actor', 'json') + logging.verbose('Failed to fetch actor: %s', self.signature.keyid) + raise HttpError(400, 'failed to fetch actor') - self.actor = actor + except Exception: + traceback.print_exc() + raise HttpError(500, 'unexpected error when fetching actor') try: self.signer = self.actor.signer except KeyError: logging.verbose('Actor missing public key: %s', self.signature.keyid) - return Response.new_error(400, 'actor missing public key', 'json') + raise HttpError(400, 'actor missing public key') try: await self.signer.validate_request_async(self.request) except aputils.SignatureFailureError as e: logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) - return Response.new_error(401, str(e), 'json') - - return None + raise HttpError(401, str(e)) @register_route('/outbox') diff --git a/relay/views/api.py b/relay/views/api.py index 70a9f0e..e7cb5fb 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -1,17 +1,20 @@ +import traceback + from aiohttp.web import Request, middleware from argon2.exceptions import VerifyMismatchError +from blib import HttpError, convert_to_boolean from collections.abc import Awaitable, Callable, Sequence -from typing import Any from urllib.parse import urlparse from .base import View, register_route from .. import __version__ -from ..database import ConfigData -from ..misc import Message, Response, boolean, get_app +from ..database import ConfigData, schema +from ..misc import Message, Response, boolean -ALLOWED_HEADERS = { +DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' +ALLOWED_HEADERS: set[str] = { 'accept', 'authorization', 'content-type' @@ -19,6 +22,8 @@ ALLOWED_HEADERS = { PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( ('GET', '/api/v1/relay'), + ('POST', '/api/v1/app'), + ('POST', '/api/v1/login'), ('POST', '/api/v1/token') ) @@ -34,64 +39,184 @@ def check_api_path(method: str, path: str) -> bool: async def handle_api_path( request: Request, handler: Callable[[Request], Awaitable[Response]]) -> Response: - try: - if (token := request.cookies.get('user-token')): - request['token'] = token - else: - request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() - - with get_app().database.session() as conn: - request['user'] = conn.get_user_by_token(request['token']) - - except (KeyError, ValueError): - request['token'] = None - request['user'] = None + if not request.path.startswith('/api') or request.path == '/api/doc': + return await handler(request) if request.method != "OPTIONS" and check_api_path(request.method, request.path): - if not request['token']: - return Response.new_error(401, 'Missing token', 'json') + if request['token'] is None: + raise HttpError(401, 'Missing token') - if not request['user']: - return Response.new_error(401, 'Invalid token', 'json') + if request['user'] is None: + raise HttpError(401, 'Invalid token') response = await handler(request) - - if request.path.startswith('/api'): - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) + response.headers['Access-Control-Allow-Origin'] = '*' + response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) return response -@register_route('/api/v1/token') -class Login(View): +@register_route('/oauth/authorize') +@register_route('/api/oauth/authorize') +class OauthAuthorize(View): async def get(self, request: Request) -> Response: - return Response.new({'message': 'Token valid'}, ctype = 'json') + data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], []) + + if data['response_type'] != 'code': + raise HttpError(400, 'Response type is not "code"') + + with self.database.session(True) as conn: + with conn.select('apps', client_id = data['client_id']) as cur: + if (app := cur.one(schema.App)) is None: + raise HttpError(404, 'Could not find app') + + if app.token is not None: + raise HttpError(400, 'Application has already been authorized') + + if app.auth_code is not None: + context = {'application': app} + html = self.template.render( + 'page/authorize_show.haml', self.request, **context + ) + + return Response.new(html, ctype = 'html') + + if data['redirect_uri'] != app.redirect_uri: + raise HttpError(400, 'redirect_uri does not match application') + + context = {'application': app} + html = self.template.render('page/authorize_new.haml', self.request, **context) + return Response.new(html, ctype = 'html') async def post(self, request: Request) -> Response: - data = await self.get_api_data(['username', 'password'], []) + data = await self.get_api_data( + ['client_id', 'client_secret', 'redirect_uri', 'response'], [] + ) - if isinstance(data, Response): - return data + with self.database.session(True) as conn: + if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: + return Response.new_error(404, 'Could not find app', 'json') + + if convert_to_boolean(data['response']): + if app.token is not None: + raise HttpError(400, 'Application has already been authorized') + + if app.auth_code is None: + app = conn.update_app(app, request['user'], True) + + if app.redirect_uri == DEFAULT_REDIRECT: + context = {'application': app} + html = self.template.render( + 'page/authorize_show.haml', self.request, **context + ) + + return Response.new(html, ctype = 'html') + + return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}') + + if not conn.del_app(app.client_id, app.client_secret): + raise HttpError(404, 'App not found') + + return Response.new_redir('/') + + +@register_route('/oauth/token') +@register_route('/api/oauth/token') +class OauthToken(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data( + ['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], [] + ) + + if data['grant_type'] != 'authorization_code': + raise HttpError(400, 'Invalid grant type') + + with self.database.session(True) as conn: + if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: + raise HttpError(404, 'Application not found') + + if app.auth_code != data['code']: + raise HttpError(400, 'Invalid authentication code') + + if app.redirect_uri != data['redirect_uri']: + raise HttpError(400, 'Invalid redirect uri') + + app = conn.update_app(app, request['user'], False) + + return Response.new(app.get_api_data(True), ctype = 'json') + + +@register_route('/oauth/revoke') +@register_route('/api/oauth/revoke') +class OauthRevoke(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['client_id', 'client_secret', 'token'], []) + + with self.database.session(True) as conn: + if (app := conn.get_app(**data)) is None: + raise HttpError(404, 'Could not find token') + + if app.user != request['token'].username: + raise HttpError(403, 'Invalid token') + + if not conn.del_app(**data): + raise HttpError(400, 'Token not removed') + + return Response.new({'msg': 'Token deleted'}, ctype = 'json') + + +@register_route('/api/v1/app') +class App(View): + async def get(self, request: Request) -> Response: + return Response.new(request['token'].get_api_data(), ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['name', 'redirect_uri'], ['website']) + + with self.database.session(True) as conn: + app = conn.put_app( + name = data['name'], + redirect_uri = data['redirect_uri'], + website = data.get('website') + ) + + return Response.new(app.get_api_data(), ctype = 'json') + + + async def delete(self, request: Request) -> Response: + data = await self.get_api_data(['client_id', 'client_secret'], []) + + with self.database.session(True) as conn: + if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code): + raise HttpError(400, 'Token not removed') + + return Response.new({'msg': 'Token deleted'}, ctype = 'json') + + +@register_route('/api/v1/login') +class Login(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['username', 'password'], []) with self.database.session(True) as conn: if not (user := conn.get_user(data['username'])): - return Response.new_error(401, 'User not found', 'json') + raise HttpError(401, 'User not found') try: conn.hasher.verify(user['hash'], data['password']) except VerifyMismatchError: - return Response.new_error(401, 'Invalid password', 'json') + raise HttpError(401, 'Invalid password') - token = conn.put_token(data['username']) + app = conn.put_app_login(user) - resp = Response.new({'token': token['code']}, ctype = 'json') + resp = Response.new(app.get_api_data(True), ctype = 'json') resp.set_cookie( 'user-token', - token['code'], + app.token, # type: ignore[arg-type] max_age = 60 * 60 * 24 * 365, domain = self.config.domain, path = '/', @@ -103,19 +228,12 @@ class Login(View): return resp - async def delete(self, request: Request) -> Response: - with self.database.session() as conn: - conn.del_token(request['token']) - - return Response.new({'message': 'Token revoked'}, ctype = 'json') - - @register_route('/api/v1/relay') class RelayInfo(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: config = conn.get_config_all() - inboxes = [row['domain'] for row in conn.get_inboxes()] + inboxes = [row.domain for row in conn.get_inboxes()] data = { 'domain': self.config.domain, @@ -152,17 +270,16 @@ class Config(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['key', 'value'], []) - - if isinstance(data, Response): - return data - data['key'] = data['key'].replace('-', '_') if data['key'] not in ConfigData.USER_KEYS(): - return Response.new_error(400, 'Invalid key', 'json') + raise HttpError(400, 'Invalid key') with self.database.session() as conn: - conn.put_config(data['key'], data['value']) + value = conn.put_config(data['key'], data['value']) + + if data['key'] == 'log-level': + self.app.workers.set_log_level(value) return Response.new({'message': 'Updated config'}, ctype = 'json') @@ -170,14 +287,14 @@ class Config(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['key'], []) - if isinstance(data, Response): - return data - if data['key'] not in ConfigData.USER_KEYS(): - return Response.new_error(400, 'Invalid key', 'json') + raise HttpError(400, 'Invalid key') with self.database.session() as conn: - conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) + value = conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) + + if data['key'] == 'log-level': + self.app.workers.set_log_level(value) return Response.new({'message': 'Updated config'}, ctype = 'json') @@ -186,40 +303,46 @@ class Config(View): class Inbox(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - data = conn.get_inboxes() + data = tuple(conn.get_inboxes()) return Response.new(data, ctype = 'json') async def post(self, request: Request) -> Response: data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) - - if isinstance(data, Response): - return data - data['domain'] = urlparse(data["actor"]).netloc with self.database.session() as conn: - if conn.get_inbox(data['domain']): - return Response.new_error(404, 'Instance already in database', 'json') + if conn.get_inbox(data['domain']) is not None: + raise HttpError(404, 'Instance already in database') data['domain'] = data['domain'].encode('idna').decode() if not data.get('inbox'): - actor_data: Message | None = await self.client.get(data['actor'], True, Message) + try: + actor_data = await self.client.get(data['actor'], True, Message) - if actor_data is None: - return Response.new_error(500, 'Failed to fetch actor', 'json') + except Exception: + traceback.print_exc() + raise HttpError(500, 'Failed to fetch actor') from None data['inbox'] = actor_data.shared_inbox if not data.get('software'): - nodeinfo = await self.client.fetch_nodeinfo(data['domain']) - - if nodeinfo is not None: + try: + nodeinfo = await self.client.fetch_nodeinfo(data['domain']) data['software'] = nodeinfo.sw_name - row = conn.put_inbox(**data) # type: ignore[arg-type] + except Exception: + pass + + row = conn.put_inbox( + domain = data['domain'], + actor = data['actor'], + inbox = data.get('inbox'), + software = data.get('software'), + followid = data.get('followid') + ) return Response.new(row, ctype = 'json') @@ -227,16 +350,17 @@ class Inbox(View): async def patch(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() - if not (instance := conn.get_inbox(data['domain'])): - return Response.new_error(404, 'Instance with domain not found', 'json') + if (instance := conn.get_inbox(data['domain'])) is None: + raise HttpError(404, 'Instance with domain not found') - instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type] + instance = conn.put_inbox( + instance.domain, + actor = data.get('actor'), + software = data.get('software'), + followid = data.get('followid') + ) return Response.new(instance, ctype = 'json') @@ -244,14 +368,10 @@ class Inbox(View): async def delete(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], []) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() if not conn.get_inbox(data['domain']): - return Response.new_error(404, 'Instance with domain not found', 'json') + raise HttpError(404, 'Instance with domain not found') conn.del_inbox(data['domain']) @@ -262,43 +382,38 @@ class Inbox(View): class RequestView(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - instances = conn.get_requests() + instances = tuple(conn.get_requests()) return Response.new(instances, ctype = 'json') async def post(self, request: Request) -> Response: - data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], []) - - if isinstance(data, Response): - return data - - data['accept'] = boolean(data['accept']) + data = await self.get_api_data(['domain', 'accept'], []) data['domain'] = data['domain'].encode('idna').decode() try: with self.database.session(True) as conn: - instance = conn.put_request_response(data['domain'], data['accept']) + instance = conn.put_request_response(data['domain'], boolean(data['accept'])) except KeyError: - return Response.new_error(404, 'Request not found', 'json') + raise HttpError(404, 'Request not found') from None message = Message.new_response( host = self.config.domain, - actor = instance['actor'], - followid = instance['followid'], - accept = data['accept'] + actor = instance.actor, + followid = instance.followid, + accept = boolean(data['accept']) ) - self.app.push_message(instance['inbox'], message, instance) + self.app.push_message(instance.inbox, message, instance) - if data['accept'] and instance['software'] != 'mastodon': + if data['accept'] and instance.software != 'mastodon': message = Message.new_follow( host = self.config.domain, - actor = instance['actor'] + actor = instance.actor ) - self.app.push_message(instance['inbox'], message, instance) + self.app.push_message(instance.inbox, message, instance) resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'} return Response.new(resp_message, ctype = 'json') @@ -308,24 +423,24 @@ class RequestView(View): class DomainBan(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - bans = tuple(conn.execute('SELECT * FROM domain_bans').all()) + bans = tuple(conn.get_domain_bans()) return Response.new(bans, ctype = 'json') async def post(self, request: Request) -> Response: data = await self.get_api_data(['domain'], ['note', 'reason']) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() with self.database.session() as conn: - if conn.get_domain_ban(data['domain']): - return Response.new_error(400, 'Domain already banned', 'json') + if conn.get_domain_ban(data['domain']) is not None: + raise HttpError(400, 'Domain already banned') - ban = conn.put_domain_ban(**data) + ban = conn.put_domain_ban( + domain = data['domain'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -334,18 +449,19 @@ class DomainBan(View): with self.database.session() as conn: data = await self.get_api_data(['domain'], ['note', 'reason']) - if isinstance(data, Response): - return data + if not any([data.get('note'), data.get('reason')]): + raise HttpError(400, 'Must include note and/or reason parameters') data['domain'] = data['domain'].encode('idna').decode() - if not conn.get_domain_ban(data['domain']): - return Response.new_error(404, 'Domain not banned', 'json') + if conn.get_domain_ban(data['domain']) is None: + raise HttpError(404, 'Domain not banned') - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - - ban = conn.update_domain_ban(**data) + ban = conn.update_domain_ban( + domain = data['domain'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -353,14 +469,10 @@ class DomainBan(View): async def delete(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], []) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() - if not conn.get_domain_ban(data['domain']): - return Response.new_error(404, 'Domain not banned', 'json') + if conn.get_domain_ban(data['domain']) is None: + raise HttpError(404, 'Domain not banned') conn.del_domain_ban(data['domain']) @@ -371,7 +483,7 @@ class DomainBan(View): class SoftwareBan(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - bans = tuple(conn.execute('SELECT * FROM software_bans').all()) + bans = tuple(conn.get_software_bans()) return Response.new(bans, ctype = 'json') @@ -379,14 +491,15 @@ class SoftwareBan(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['name'], ['note', 'reason']) - if isinstance(data, Response): - return data - with self.database.session() as conn: - if conn.get_software_ban(data['name']): - return Response.new_error(400, 'Domain already banned', 'json') + if conn.get_software_ban(data['name']) is not None: + raise HttpError(400, 'Domain already banned') - ban = conn.put_software_ban(**data) + ban = conn.put_software_ban( + name = data['name'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -394,17 +507,18 @@ class SoftwareBan(View): async def patch(self, request: Request) -> Response: data = await self.get_api_data(['name'], ['note', 'reason']) - if isinstance(data, Response): - return data + if not any([data.get('note'), data.get('reason')]): + raise HttpError(400, 'Must include note and/or reason parameters') with self.database.session() as conn: - if not conn.get_software_ban(data['name']): - return Response.new_error(404, 'Software not banned', 'json') + if conn.get_software_ban(data['name']) is None: + raise HttpError(404, 'Software not banned') - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - - ban = conn.update_software_ban(**data) + ban = conn.update_software_ban( + name = data['name'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -412,12 +526,9 @@ class SoftwareBan(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['name'], []) - if isinstance(data, Response): - return data - with self.database.session() as conn: - if not conn.get_software_ban(data['name']): - return Response.new_error(404, 'Software not banned', 'json') + if conn.get_software_ban(data['name']) is None: + raise HttpError(404, 'Software not banned') conn.del_software_ban(data['name']) @@ -430,7 +541,7 @@ class User(View): with self.database.session() as conn: items = [] - for row in conn.execute('SELECT * FROM users'): + for row in conn.get_users(): del row['hash'] items.append(row) @@ -440,41 +551,40 @@ class User(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['username', 'password'], ['handle']) - if isinstance(data, Response): - return data - with self.database.session() as conn: - if conn.get_user(data['username']): - return Response.new_error(404, 'User already exists', 'json') + if conn.get_user(data['username']) is not None: + raise HttpError(404, 'User already exists') - user = conn.put_user(**data) - del user['hash'] + user = conn.put_user( + username = data['username'], + password = data['password'], + handle = data.get('handle') + ) + del user['hash'] return Response.new(user, ctype = 'json') async def patch(self, request: Request) -> Response: data = await self.get_api_data(['username'], ['password', 'handle']) - if isinstance(data, Response): - return data - with self.database.session(True) as conn: - user = conn.put_user(**data) - del user['hash'] + user = conn.put_user( + username = data['username'], + password = data['password'], + handle = data.get('handle') + ) + del user['hash'] return Response.new(user, ctype = 'json') async def delete(self, request: Request) -> Response: data = await self.get_api_data(['username'], []) - if isinstance(data, Response): - return data - with self.database.session(True) as conn: - if not conn.get_user(data['username']): - return Response.new_error(404, 'User does not exist', 'json') + if conn.get_user(data['username']) is None: + raise HttpError(404, 'User does not exist') conn.del_user(data['username']) @@ -485,7 +595,7 @@ class User(View): class Whitelist(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - items = tuple(conn.execute('SELECT * FROM whitelist').all()) + items = tuple(conn.get_domains_whitelist()) return Response.new(items, ctype = 'json') @@ -493,16 +603,13 @@ class Whitelist(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['domain'], []) - if isinstance(data, Response): - return data - - data['domain'] = data['domain'].encode('idna').decode() + domain = data['domain'].encode('idna').decode() with self.database.session() as conn: - if conn.get_domain_whitelist(data['domain']): - return Response.new_error(400, 'Domain already added to whitelist', 'json') + if conn.get_domain_whitelist(domain) is not None: + raise HttpError(400, 'Domain already added to whitelist') - item = conn.put_domain_whitelist(**data) + item = conn.put_domain_whitelist(domain) return Response.new(item, ctype = 'json') @@ -510,15 +617,12 @@ class Whitelist(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['domain'], []) - if isinstance(data, Response): - return data - - data['domain'] = data['domain'].encode('idna').decode() + domain = data['domain'].encode('idna').decode() with self.database.session() as conn: - if not conn.get_domain_whitelist(data['domain']): - return Response.new_error(404, 'Domain not in whitelist', 'json') + if conn.get_domain_whitelist(domain) is None: + raise HttpError(404, 'Domain not in whitelist') - conn.del_domain_whitelist(data['domain']) + conn.del_domain_whitelist(domain) return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py index 350016c..624ed9d 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -1,10 +1,9 @@ from __future__ import annotations -from Crypto.Random import get_random_bytes from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS -from aiohttp.web import HTTPMethodNotAllowed, Request -from base64 import b64encode +from aiohttp.web import Request +from blib import HttpError from bsql import Database from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from functools import cached_property @@ -18,18 +17,12 @@ from ..http_client import HttpClient from ..misc import Response, get_app if TYPE_CHECKING: + from typing import Self from ..application import Application from ..template import Template -try: - from typing import Self - -except ImportError: - from typing_extensions import Self HandlerCallback = Callable[[Request], Awaitable[Response]] - - VIEWS: list[tuple[str, type[View]]] = [] @@ -49,10 +42,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]: class View(AbstractView): def __await__(self) -> Generator[Any, None, Response]: if self.request.method not in METHODS: - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + raise HttpError(405, f'"{self.request.method}" method not allowed') if not (handler := self.handlers.get(self.request.method)): - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + raise HttpError(405, f'"{self.request.method}" method not allowed') return self._run_handler(handler).__await__() @@ -64,7 +57,6 @@ class View(AbstractView): async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: - self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') return await handler(self.request, **self.request.match_info, **kwargs) @@ -123,17 +115,18 @@ class View(AbstractView): async def get_api_data(self, required: list[str], - optional: list[str]) -> dict[str, str] | Response: + optional: list[str]) -> dict[str, str]: - if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: + if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}: post_data = convert_data(await self.request.post()) + # post_data = {key: value for key, value in parse_qsl(await self.request.text())} elif self.request.content_type == 'application/json': try: post_data = convert_data(await self.request.json()) except JSONDecodeError: - return Response.new_error(400, 'Invalid JSON data', 'json') + raise HttpError(400, 'Invalid JSON data') else: post_data = convert_data(self.request.query) @@ -145,9 +138,9 @@ class View(AbstractView): data[key] = post_data[key] except KeyError as e: - return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') + raise HttpError(400, f'Missing {str(e)} pararmeter') from None for key in optional: - data[key] = post_data.get(key, '') + data[key] = post_data.get(key) # type: ignore[assignment] return data diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 5dfb43a..b6dba7b 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -1,18 +1,13 @@ from aiohttp import web from collections.abc import Awaitable, Callable from typing import Any +from urllib.parse import unquote from .base import View, register_route from ..database import THEMES from ..logger import LogLevel -from ..misc import Response, get_app - - -UNAUTH_ROUTES = { - '/', - '/login' -} +from ..misc import TOKEN_PATHS, Response @web.middleware @@ -20,28 +15,25 @@ async def handle_frontend_path( request: web.Request, handler: Callable[[web.Request], Awaitable[Response]]) -> Response: - app = get_app() + if request['user'] is not None and request.path == '/login': + return Response.new_redir('/') - if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): - request['token'] = request.cookies.get('user-token') - request['user'] = None + if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None: + if request.path == '/logout': + return Response.new_redir('/') - if request['token']: - with app.database.session(False) as conn: - request['user'] = conn.get_user_by_token(request['token']) + response = Response.new_redir(f'/login?redir={request.path}') - if request['user'] and request.path == '/login': - return Response.new('', 302, {'Location': '/'}) - - if not request['user'] and request.path.startswith('/admin'): - response = Response.new('', 302, {'Location': f'/login?redir={request.path}'}) + if request['token'] is not None: response.del_cookie('user-token') - return response + + return response response = await handler(request) - if not request.path.startswith('/api') and not request['user'] and request['token']: - response.del_cookie('user-token') + if not request.path.startswith('/api'): + if request['user'] is None and request['token'] is not None: + response.del_cookie('user-token') return response @@ -54,14 +46,15 @@ class HomeView(View): 'instances': tuple(conn.get_inboxes()) } - data = self.template.render('page/home.haml', self, **context) + data = self.template.render('page/home.haml', self.request, **context) return Response.new(data, ctype='html') @register_route('/login') class Login(View): async def get(self, request: web.Request) -> Response: - data = self.template.render('page/login.haml', self) + redir = unquote(request.query.get('redir', '/')) + data = self.template.render('page/login.haml', self.request, redir = redir) return Response.new(data, ctype = 'html') @@ -69,7 +62,7 @@ class Login(View): class Logout(View): async def get(self, request: web.Request) -> Response: with self.database.session(True) as conn: - conn.del_token(request['token']) + conn.del_app(request['token'].client_id, request['token'].client_secret) resp = Response.new_redir('/') resp.del_cookie('user-token', domain = self.config.domain, path = '/') @@ -79,7 +72,7 @@ class Logout(View): @register_route('/admin') class Admin(View): async def get(self, request: web.Request) -> Response: - return Response.new('', 302, {'Location': '/admin/instances'}) + return Response.new_redir(f'/login?redir={request.path}', 301) @register_route('/admin/instances') @@ -101,7 +94,7 @@ class AdminInstances(View): if message: context['message'] = message - data = self.template.render('page/admin-instances.haml', self, **context) + data = self.template.render('page/admin-instances.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -123,7 +116,7 @@ class AdminWhitelist(View): if message: context['message'] = message - data = self.template.render('page/admin-whitelist.haml', self, **context) + data = self.template.render('page/admin-whitelist.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -145,7 +138,7 @@ class AdminDomainBans(View): if message: context['message'] = message - data = self.template.render('page/admin-domain_bans.haml', self, **context) + data = self.template.render('page/admin-domain_bans.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -167,7 +160,7 @@ class AdminSoftwareBans(View): if message: context['message'] = message - data = self.template.render('page/admin-software_bans.haml', self, **context) + data = self.template.render('page/admin-software_bans.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -189,7 +182,7 @@ class AdminUsers(View): if message: context['message'] = message - data = self.template.render('page/admin-users.haml', self, **context) + data = self.template.render('page/admin-users.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -199,10 +192,21 @@ class AdminConfig(View): context: dict[str, Any] = { 'themes': tuple(THEMES.keys()), 'levels': tuple(level.name for level in LogLevel), - 'message': message + 'message': message, + 'desc': { + "name": "Name of the relay to be displayed in the header of the pages and in " + + "the actor endpoint.", # noqa: E131 + "note": "Description of the relay to be displayed on the front page and as the " + + "bio in the actor endpoint.", + "theme": "Color theme to use on the web pages.", + "log_level": "Minimum level of logging messages to print to the console.", + "whitelist_enabled": "Only allow instances in the whitelist to be able to follow.", + "approval_required": "Require instances not on the whitelist to be approved by " + + "and admin. The `whitelist-enabled` setting is ignored when this is enabled." + } } - data = self.template.render('page/admin-config.haml', self, **context) + data = self.template.render('page/admin-config.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -240,5 +244,5 @@ class ThemeCss(View): except KeyError: return Response.new('Invalid theme', 404) - data = self.template.render('variables.css', self, **context) + data = self.template.render('variables.css', self.request, **context) return Response.new(data, ctype = 'css') diff --git a/relay/workers.py b/relay/workers.py new file mode 100644 index 0000000..31cf4c3 --- /dev/null +++ b/relay/workers.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import asyncio +import traceback + +from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError +from asyncio.exceptions import TimeoutError as AsyncTimeoutError +from dataclasses import dataclass +from multiprocessing import Event, Process, Queue, Value +from multiprocessing.queues import Queue as QueueType +from multiprocessing.sharedctypes import Synchronized +from multiprocessing.synchronize import Event as EventType +from pathlib import Path +from queue import Empty +from urllib.parse import urlparse + +from . import application, logger as logging +from .database.schema import Instance +from .http_client import HttpClient +from .misc import IS_WINDOWS, Message, get_app + + +@dataclass +class PostItem: + inbox: str + message: Message + instance: Instance | None + + @property + def domain(self) -> str: + return urlparse(self.inbox).netloc + + +class PushWorker(Process): + client: HttpClient + + + def __init__(self, queue: QueueType[PostItem], log_level: Synchronized[int]) -> None: + Process.__init__(self) + + self.queue: QueueType[PostItem] = queue + self.shutdown: EventType = Event() + self.path: Path = get_app().config.path + self.log_level: Synchronized[int] = log_level + self._log_level_changed: EventType = Event() + + + def stop(self) -> None: + self.shutdown.set() + + + def run(self) -> None: + asyncio.run(self.handle_queue()) + + + async def handle_queue(self) -> None: + if IS_WINDOWS: + app = application.Application(self.path) + self.client = app.client + + self.client.open() + app.database.connect() + app.cache.setup() + + else: + self.client = HttpClient() + self.client.open() + + logging.verbose("[%i] Starting worker", self.pid) + + while not self.shutdown.is_set(): + try: + if self._log_level_changed.is_set(): + logging.set_level(logging.LogLevel.parse(self.log_level.value)) + self._log_level_changed.clear() + + item = self.queue.get(block=True, timeout=0.1) + asyncio.create_task(self.handle_post(item)) + + except Empty: + await asyncio.sleep(0) + + except Exception: + traceback.print_exc() + + if IS_WINDOWS: + app.database.disconnect() + app.cache.close() + + await self.client.close() + + + async def handle_post(self, item: PostItem) -> None: + try: + await self.client.post(item.inbox, item.message, item.instance) + + except AsyncTimeoutError: + logging.error('Timeout when pushing to %s', item.domain) + + except ClientConnectionError as e: + logging.error('Failed to connect to %s for message push: %s', item.domain, str(e)) + + except ClientSSLError as e: + logging.error('SSL error when pushing to %s: %s', item.domain, str(e)) + + +class PushWorkers(list[PushWorker]): + def __init__(self, count: int) -> None: + self.queue: QueueType[PostItem] = Queue() + self._log_level: Synchronized[int] = Value("i", logging.get_level()) + self._count: int = count + + + def push_message(self, inbox: str, message: Message, instance: Instance) -> None: + self.queue.put(PostItem(inbox, message, instance)) + + + def set_log_level(self, value: logging.LogLevel) -> None: + self._log_level.value = value + + for worker in self: + worker._log_level_changed.set() + + + def start(self) -> None: + if len(self) > 0: + return + + for _ in range(self._count): + worker = PushWorker(self.queue, self._log_level) + worker.start() + self.append(worker) + + + def stop(self) -> None: + for worker in self: + worker.stop() + + self.clear()