From c508257981fa562be40e276f5f6f57f72588b75b Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 18 Jun 2024 23:14:21 -0400 Subject: [PATCH 01/21] raise exceptions instead of returning None from HttpClient methods --- relay/application.py | 12 ++++ relay/http_client.py | 119 +++++++++++++++---------------------- relay/views/activitypub.py | 7 +-- relay/views/api.py | 16 +++-- 4 files changed, 75 insertions(+), 79 deletions(-) diff --git a/relay/application.py b/relay/application.py index b12c64f..6c8c1e7 100644 --- a/relay/application.py +++ b/relay/application.py @@ -7,9 +7,11 @@ import time import traceback from aiohttp import web +from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.web import StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer +from asyncio.exceptions import TimeoutError as AsyncTimeoutError from bsql import Database, Row from collections.abc import Awaitable, Callable from datetime import datetime, timedelta @@ -18,6 +20,7 @@ from pathlib import Path from queue import Empty from threading import Event, Thread from typing import Any +from urllib.parse import urlparse from . import logger as logging from .cache import Cache, get_cache @@ -331,6 +334,15 @@ class PushWorker(multiprocessing.Process): except Empty: await asyncio.sleep(0) + except ClientSSLError as e: + logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e)) + + except (AsyncTimeoutError, ClientConnectionError) as e: + logging.error( + 'Failed to connect to %s for message push: %s', + urlparse(inbox).netloc, str(e) + ) + # make sure an exception doesn't bring down the worker except Exception: traceback.print_exc() diff --git a/relay/http_client.py b/relay/http_client.py index 54cea3c..610b8a9 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -1,17 +1,12 @@ from __future__ import annotations import json -import traceback from aiohttp import ClientSession, ClientTimeout, TCPConnector -from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo -from asyncio.exceptions import TimeoutError as AsyncTimeoutError from blib import JsonBase from bsql import Row -from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING, Any, TypeVar -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Any, TypeVar, overload from . import __version__, logger as logging from .cache import Cache @@ -107,7 +102,7 @@ class HttpClient: url: str, sign_headers: bool, force: bool, - old_algo: bool) -> dict[str, Any] | None: + old_algo: bool) -> str | None: if not self._session: raise RuntimeError('Client not open') @@ -121,7 +116,7 @@ class HttpClient: if not force: try: if not (item := self.cache.get('request', url)).older_than(48): - return json.loads(item.value) # type: ignore[no-any-return] + return item.value # type: ignore [no-any-return] except KeyError: logging.verbose('No cached data for url: %s', url) @@ -132,59 +127,61 @@ class HttpClient: algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019 headers = self.signer.sign_headers('GET', url, algorithm = algo) - try: - logging.debug('Fetching resource: %s', url) + logging.debug('Fetching resource: %s', url) - async with self._session.get(url, headers = headers) as resp: - # Not expecting a response with 202s, so just return - if resp.status == 202: - return None - - data = await resp.text() - - if resp.status != 200: - logging.verbose('Received error when requesting %s: %i', url, resp.status) - logging.debug(data) + async with self._session.get(url, headers = headers) as resp: + # Not expecting a response with 202s, so just return + if resp.status == 202: return None - self.cache.set('request', url, data, 'str') - logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4)) + data = await resp.text() - return json.loads(data) # type: ignore [no-any-return] - - except JSONDecodeError: - logging.verbose('Failed to parse JSON') + if resp.status != 200: + logging.verbose('Received error when requesting %s: %i', url, resp.status) logging.debug(data) return None - except ClientSSLError as e: - logging.verbose('SSL error when connecting to %s', urlparse(url).netloc) - logging.warning(str(e)) + self.cache.set('request', url, data, 'str') + return data - except (AsyncTimeoutError, ClientConnectionError) as e: - logging.verbose('Failed to connect to %s', urlparse(url).netloc) - logging.warning(str(e)) - except Exception: - traceback.print_exc() + @overload + async def get(self, # type: ignore[overload-overlap] + url: str, + sign_headers: bool, + cls: None = None, + force: bool = False, + old_algo: bool = True) -> None: ... - return None + + @overload + async def get(self, + url: str, + sign_headers: bool, + cls: type[T] = JsonBase, # type: ignore[assignment] + force: bool = False, + old_algo: bool = True) -> T: ... async def get(self, url: str, sign_headers: bool, - cls: type[T], + cls: type[T] | None = None, force: bool = False, old_algo: bool = True) -> T | None: - if not issubclass(cls, JsonBase): + if cls is not None and not issubclass(cls, JsonBase): raise TypeError('cls must be a sub-class of "blib.JsonBase"') - if (data := (await self._get(url, sign_headers, force, old_algo))) is None: - return None + data = await self._get(url, sign_headers, force, old_algo) - return cls.parse(data) + if cls is not None: + if data is None: + raise ValueError("Empty response") + + return cls.parse(data) + + return None async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None: @@ -218,35 +215,22 @@ class HttpClient: algorithm = algorithm ) - try: - logging.verbose('Sending "%s" to %s', mtype, url) + logging.verbose('Sending "%s" to %s', mtype, url) - async with self._session.post(url, headers = headers, data = body) as resp: - # Not expecting a response, so just return - if resp.status in {200, 202}: - logging.verbose('Successfully sent "%s" to %s', mtype, url) - return - - logging.verbose('Received error when pushing to %s: %i', url, resp.status) - logging.debug(await resp.read()) - logging.debug("message: %s", body.decode("utf-8")) - logging.debug("headers: %s", json.dumps(headers, indent = 4)) + async with self._session.post(url, headers = headers, data = body) as resp: + # Not expecting a response, so just return + if resp.status in {200, 202}: + logging.verbose('Successfully sent "%s" to %s', mtype, url) return - except ClientSSLError as e: - logging.warning('SSL error when pushing to %s', urlparse(url).netloc) - logging.warning(str(e)) - - except (AsyncTimeoutError, ClientConnectionError) as e: - logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) - logging.warning(str(e)) - - # prevent workers from being brought down - except Exception: - traceback.print_exc() + logging.error('Received error when pushing to %s: %i', url, resp.status) + logging.debug(await resp.read()) + logging.debug("message: %s", body.decode("utf-8")) + logging.debug("headers: %s", json.dumps(headers, indent = 4)) + return - async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None: + async def fetch_nodeinfo(self, domain: str) -> Nodeinfo: nodeinfo_url = None wk_nodeinfo = await self.get( f'https://{domain}/.well-known/nodeinfo', @@ -254,10 +238,6 @@ class HttpClient: WellKnownNodeinfo ) - if wk_nodeinfo is None: - logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) - return None - for version in ('20', '21'): try: nodeinfo_url = wk_nodeinfo.get_url(version) @@ -266,8 +246,7 @@ class HttpClient: pass if nodeinfo_url is None: - logging.verbose('Failed to fetch nodeinfo url for %s', domain) - return None + raise ValueError(f'Failed to fetch nodeinfo url for {domain}') return await self.get(nodeinfo_url, False, Nodeinfo) diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index b19b7e1..f568d17 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -95,9 +95,10 @@ class ActorView(View): logging.verbose('actor not in message') return Response.new_error(400, 'no actor in message', 'json') - actor: Message | None = await self.client.get(self.signature.keyid, True, Message) + try: + self.actor = await self.client.get(self.signature.keyid, True, Message) - if actor is None: + except Exception: # ld signatures aren't handled atm, so just ignore it if self.message.type == 'Delete': logging.verbose('Instance sent a delete which cannot be handled') @@ -106,8 +107,6 @@ class ActorView(View): logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') return Response.new_error(400, 'failed to fetch actor', 'json') - self.actor = actor - try: self.signer = self.actor.signer diff --git a/relay/views/api.py b/relay/views/api.py index 70a9f0e..074dc04 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -1,3 +1,5 @@ +import traceback + from aiohttp.web import Request, middleware from argon2.exceptions import VerifyMismatchError from collections.abc import Awaitable, Callable, Sequence @@ -206,19 +208,23 @@ class Inbox(View): data['domain'] = data['domain'].encode('idna').decode() if not data.get('inbox'): - actor_data: Message | None = await self.client.get(data['actor'], True, Message) + try: + actor_data = await self.client.get(data['actor'], True, Message) - if actor_data is None: + except Exception: + traceback.print_exc() return Response.new_error(500, 'Failed to fetch actor', 'json') data['inbox'] = actor_data.shared_inbox if not data.get('software'): - nodeinfo = await self.client.fetch_nodeinfo(data['domain']) - - if nodeinfo is not None: + try: + nodeinfo = await self.client.fetch_nodeinfo(data['domain']) data['software'] = nodeinfo.sw_name + except Exception: + pass + row = conn.put_inbox(**data) # type: ignore[arg-type] return Response.new(row, ctype = 'json') From 9a3e3768e75ac8e756bb7157a4429982bcf18087 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 19 Jun 2024 12:59:08 -0400 Subject: [PATCH 02/21] modify workers * move all worker-related classes and functions to workers.py * change the log level in worker processes * create QueueItem and PostItem classes --- relay/application.py | 17 ++-- relay/config.py | 2 +- relay/database/connection.py | 1 + relay/workers.py | 150 +++++++++++++++++++++++++++++++++++ 4 files changed, 157 insertions(+), 13 deletions(-) create mode 100644 relay/workers.py diff --git a/relay/application.py b/relay/application.py index 6c8c1e7..a3d9925 100644 --- a/relay/application.py +++ b/relay/application.py @@ -22,7 +22,7 @@ from threading import Event, Thread from typing import Any from urllib.parse import urlparse -from . import logger as logging +from . import logger as logging, workers from .cache import Cache, get_cache from .config import Config from .database import Connection, get_database @@ -78,7 +78,7 @@ class Application(web.Application): self['cache'].setup() self['template'] = Template(self) self['push_queue'] = multiprocessing.Queue() - self['workers'] = [] + self['workers'] = workers.PushWorkers(self.config.workers) self.cache.setup() self.on_cleanup.append(handle_cleanup) # type: ignore @@ -143,7 +143,7 @@ class Application(web.Application): def push_message(self, inbox: str, message: Message, instance: Row) -> None: - self['push_queue'].put((inbox, message, instance)) + self['workers'].push_message(inbox, message, instance) def register_static_routes(self) -> None: @@ -198,12 +198,7 @@ class Application(web.Application): self['cache'].setup() self['cleanup_thread'] = CacheCleanupThread(self) self['cleanup_thread'].start() - - for _ in range(self.config.workers): - worker = PushWorker(self['push_queue']) - worker.start() - - self['workers'].append(worker) + self['workers'].start() runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() @@ -223,15 +218,13 @@ class Application(web.Application): await site.stop() - for worker in self['workers']: - worker.stop() + self['workers'].stop() self.set_signal_handler(False) self['starttime'] = None self['running'] = False self['cleanup_thread'].stop() - self['workers'].clear() self['database'].disconnect() self['cache'].close() diff --git a/relay/config.py b/relay/config.py index ac2bbb6..7e86ef7 100644 --- a/relay/config.py +++ b/relay/config.py @@ -61,7 +61,7 @@ class Config: def __init__(self, path: Path | None = None, load: bool = False): - self.path = Config.get_config_dir(path) + self.path: Path = Config.get_config_dir(path) self.reset() if load: diff --git a/relay/database/connection.py b/relay/database/connection.py index 614f307..864ad27 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -75,6 +75,7 @@ class Connection(SqlConnection): elif key == 'log-level': value = logging.LogLevel.parse(value) logging.set_level(value) + self.app['workers'].set_log_level(value) elif key in {'approval-required', 'whitelist-enabled'}: value = boolean(value) diff --git a/relay/workers.py b/relay/workers.py new file mode 100644 index 0000000..8d88ad7 --- /dev/null +++ b/relay/workers.py @@ -0,0 +1,150 @@ +import asyncio +import traceback +import typing + +from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError +from asyncio.exceptions import TimeoutError as AsyncTimeoutError +from bsql import Row +from dataclasses import dataclass +from multiprocessing import Event, Process, Queue, Value +from multiprocessing.synchronize import Event as EventType +from pathlib import Path +from queue import Empty, Queue as QueueType +from urllib.parse import urlparse + +from . import application, logger as logging +from .http_client import HttpClient +from .misc import IS_WINDOWS, Message, get_app + +if typing.TYPE_CHECKING: + from .multiprocessing.synchronize import Syncronized + + +@dataclass +class QueueItem: + pass + + +@dataclass +class PostItem(QueueItem): + inbox: str + message: Message + instance: Row | None + + @property + def domain(self) -> str: + return urlparse(self.inbox).netloc + + +class PushWorker(Process): + client: HttpClient + + + def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None: + Process.__init__(self) + + self.queue: QueueType[QueueItem] = queue + self.shutdown: EventType = Event() + self.path: Path = get_app().config.path + self.log_level: "Syncronized[str]" = log_level + self._log_level_changed: EventType = Event() + + + def stop(self) -> None: + self.shutdown.set() + + + def run(self) -> None: + asyncio.run(self.handle_queue()) + + + async def handle_queue(self) -> None: + if IS_WINDOWS: + app = application.Application(self.path) + self.client = app.client + + self.client.open() + app.database.connect() + app.cache.setup() + + else: + self.client = HttpClient() + self.client.open() + + logging.verbose("[%i] Starting worker", self.pid) + + while not self.shutdown.is_set(): + try: + if self._log_level_changed.is_set(): + logging.set_level(logging.LogLevel.parse(self.log_level.value)) + self._log_level_changed.clear() + + item = self.queue.get(block=True, timeout=0.1) + + if isinstance(item, PostItem): + asyncio.create_task(self.handle_post(item)) + + except Empty: + await asyncio.sleep(0) + + except Exception: + traceback.print_exc() + + if IS_WINDOWS: + app.database.disconnect() + app.cache.close() + + await self.client.close() + + + async def handle_post(self, item: PostItem) -> None: + try: + await self.client.post(item.inbox, item.message, item.instance) + + except AsyncTimeoutError: + logging.error('Timeout when pushing to %s', item.domain) + + except ClientConnectionError as e: + logging.error('Failed to connect to %s for message push: %s', item.domain, str(e)) + + except ClientSSLError as e: + logging.error('SSL error when pushing to %s: %s', item.domain, str(e)) + + +class PushWorkers(list[PushWorker]): + def __init__(self, count: int) -> None: + self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment] + self._log_level: "Syncronized[str]" = Value("i", logging.get_level()) + self._count: int = count + + + def push_item(self, item: QueueItem) -> None: + self.queue.put(item) + + + def push_message(self, inbox: str, message: Message, instance: Row) -> None: + self.queue.put(PostItem(inbox, message, instance)) + + + def set_log_level(self, value: logging.LogLevel) -> None: + self._log_level.value = value + + for worker in self: + worker._log_level_changed.set() + + + def start(self) -> None: + if len(self) > 0: + return + + for _ in range(self._count): + worker = PushWorker(self.queue, self._log_level) + worker.start() + self.append(worker) + + + def stop(self) -> None: + for worker in self: + worker.stop() + + self.clear() From e67ebd75ed4d754dfce0fe0c43ace7fa6c80d6b5 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 19 Jun 2024 17:51:13 -0400 Subject: [PATCH 03/21] force frontend resources to refresh on new version --- relay/frontend/base.haml | 10 +++++----- relay/frontend/page/admin-config.haml | 2 +- relay/frontend/page/admin-domain_bans.haml | 2 +- relay/frontend/page/admin-instances.haml | 2 +- relay/frontend/page/admin-software_bans.haml | 2 +- relay/frontend/page/admin-users.haml | 2 +- relay/frontend/page/admin-whitelist.haml | 2 +- relay/frontend/page/login.haml | 2 +- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/relay/frontend/base.haml b/relay/frontend/base.haml index 7a14b72..d7551b8 100644 --- a/relay/frontend/base.haml +++ b/relay/frontend/base.haml @@ -11,11 +11,11 @@ %title << {{config.name}}: {{page}} %meta(charset="UTF-8") %meta(name="viewport" content="width=device-width, initial-scale=1") - %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme") - %link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}") - %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css" nonce="{{view.request['hash']}}") - %link(rel="manifest" href="/manifest.json") - %script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer) + %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme") + %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}") + %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}") + %link(rel="manifest" href="/manifest.json?{{version}}") + %script(type="application/javascript" src="/static/api.js?{{version}}" nonce="{{view.request['hash']}}" defer) -block head %body diff --git a/relay/frontend/page/admin-config.haml b/relay/frontend/page/admin-config.haml index e5df986..57fde84 100644 --- a/relay/frontend/page/admin-config.haml +++ b/relay/frontend/page/admin-config.haml @@ -2,7 +2,7 @@ -set page="Config" -block head - %script(type="application/javascript" src="/static/config.js" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/config.js?{{version}}" nonce="{{view.request['hash']}}" defer) -import "functions.haml" as func -block content diff --git a/relay/frontend/page/admin-domain_bans.haml b/relay/frontend/page/admin-domain_bans.haml index b1f7f57..66cec3d 100644 --- a/relay/frontend/page/admin-domain_bans.haml +++ b/relay/frontend/page/admin-domain_bans.haml @@ -2,7 +2,7 @@ -set page="Domain Bans" -block head - %script(type="application/javascript" src="/static/domain_ban.js" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/domain_ban.js?{{version}}" nonce="{{view.request['hash']}}" defer) -block content %details.section diff --git a/relay/frontend/page/admin-instances.haml b/relay/frontend/page/admin-instances.haml index c317e30..1490fdd 100644 --- a/relay/frontend/page/admin-instances.haml +++ b/relay/frontend/page/admin-instances.haml @@ -2,7 +2,7 @@ -set page="Instances" -block head - %script(type="application/javascript" src="/static/instance.js" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/instance.js?{{version}}" nonce="{{view.request['hash']}}" defer) -block content %details.section diff --git a/relay/frontend/page/admin-software_bans.haml b/relay/frontend/page/admin-software_bans.haml index 9bda3be..3bc4648 100644 --- a/relay/frontend/page/admin-software_bans.haml +++ b/relay/frontend/page/admin-software_bans.haml @@ -2,7 +2,7 @@ -set page="Software Bans" -block head - %script(type="application/javascript" src="/static/software_ban.js" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/software_ban.js?{{version}}" nonce="{{view.request['hash']}}" defer) -block content %details.section diff --git a/relay/frontend/page/admin-users.haml b/relay/frontend/page/admin-users.haml index 50058d7..caa9dc2 100644 --- a/relay/frontend/page/admin-users.haml +++ b/relay/frontend/page/admin-users.haml @@ -2,7 +2,7 @@ -set page="Users" -block head - %script(type="application/javascript" src="/static/user.js" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/user.js?{{version}}" nonce="{{view.request['hash']}}" defer) -block content %details.section diff --git a/relay/frontend/page/admin-whitelist.haml b/relay/frontend/page/admin-whitelist.haml index c8111e5..0a300dd 100644 --- a/relay/frontend/page/admin-whitelist.haml +++ b/relay/frontend/page/admin-whitelist.haml @@ -2,7 +2,7 @@ -set page="Whitelist" -block head - %script(type="application/javascript" src="/static/whitelist.js" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/whitelist.js?{{version}}" nonce="{{view.request['hash']}}" defer) -block content %details.section diff --git a/relay/frontend/page/login.haml b/relay/frontend/page/login.haml index bf1ab1c..a56973a 100644 --- a/relay/frontend/page/login.haml +++ b/relay/frontend/page/login.haml @@ -2,7 +2,7 @@ -set page="Login" -block head - %script(type="application/javascript" src="/static/login.js" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/login.js?{{version}}" nonce="{{view.request['hash']}}" defer) -block content %fieldset.section From 7e08e187853c4bed00a73f0a7cb079ee551c60f6 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Fri, 21 Jun 2024 02:37:13 -0400 Subject: [PATCH 04/21] add help tooltips to config page --- relay/frontend/page/admin-config.haml | 6 ++++++ relay/frontend/static/style.css | 6 +++++- relay/views/frontend.py | 13 ++++++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/relay/frontend/page/admin-config.haml b/relay/frontend/page/admin-config.haml index 57fde84..8dc0ad9 100644 --- a/relay/frontend/page/admin-config.haml +++ b/relay/frontend/page/admin-config.haml @@ -11,19 +11,25 @@ .grid-2col %label(for="name") << Name + %i(class="bi bi-question-circle-fill" title="{{desc.name}}") %input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}") %label(for="note") << Description + %i(class="bi bi-question-circle-fill" title="{{desc.note}}") %textarea(id="note" value="{{config.note or ''}}") << {{config.note}} %label(for="theme") << Color Theme + %i(class="bi bi-question-circle-fill" title="{{desc.theme}}") =func.new_select("theme", config.theme, themes) %label(for="log-level") << Log Level + %i(class="bi bi-question-circle-fill" title="{{desc.log_level}}") =func.new_select("log-level", config.log_level.name, levels) %label(for="whitelist-enabled") << Whitelist + %i(class="bi bi-question-circle-fill" title="{{desc.whitelist_enabled}}") =func.new_checkbox("whitelist-enabled", config.whitelist_enabled) %label(for="approval-required") << Approval Required + %i(class="bi bi-question-circle-fill" title="{{desc.approval_required}}") =func.new_checkbox("approval-required", config.approval_required) diff --git a/relay/frontend/static/style.css b/relay/frontend/static/style.css index 6ec9316..f0d72f5 100644 --- a/relay/frontend/static/style.css +++ b/relay/frontend/static/style.css @@ -297,13 +297,13 @@ textarea { border: 1px solid var(--error-border) !important; } +/* create .grid base class and .2col and 3col classes */ .grid-2col { display: grid; grid-template-columns: max-content auto; grid-gap: var(--spacing); margin-bottom: var(--spacing); align-items: center; - } .message { @@ -333,6 +333,10 @@ textarea { justify-self: left; } +#content.page-config .grid-2col { + grid-template-columns: max-content max-content auto; +} + @keyframes show_toast { 0% { diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 5dfb43a..5ec16fc 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -199,7 +199,18 @@ class AdminConfig(View): context: dict[str, Any] = { 'themes': tuple(THEMES.keys()), 'levels': tuple(level.name for level in LogLevel), - 'message': message + 'message': message, + 'desc': { + "name": "Name of the relay to be displayed in the header of the pages and in " + + "the actor endpoint.", + "note": "Description of the relay to be displayed on the front page and as the " + + "bio in the actor endpoint.", + "theme": "Color theme to use on the web pages.", + "log_level": "Minimum level of logging messages to print to the console.", + "whitelist_enabled": "Only allow instances in the whitelist to be able to follow.", + "approval_required": "Require instances not on the whitelist to be approved by " + + "and admin. The `whitelist-enabled` setting is ignored when this is enabled." + } } data = self.template.render('page/admin-config.haml', self, **context) From 45b0de26c78725720f03fa5d4c8b5c5676d0fd1b Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Fri, 21 Jun 2024 04:44:13 -0400 Subject: [PATCH 05/21] add ability to submit forms with enter key --- relay/frontend/static/api.js | 1 - relay/frontend/static/config.js | 12 ++++++++++++ relay/frontend/static/domain_ban.js | 8 ++++++++ relay/frontend/static/instance.js | 8 ++++++++ relay/frontend/static/login.js | 25 +++++++++++++++++++------ relay/frontend/static/software_ban.js | 16 ++++++++++++++++ relay/frontend/static/user.js | 8 ++++++++ relay/frontend/static/whitelist.js | 6 ++++++ 8 files changed, 77 insertions(+), 7 deletions(-) diff --git a/relay/frontend/static/api.js b/relay/frontend/static/api.js index e7f376a..6aaefd7 100644 --- a/relay/frontend/static/api.js +++ b/relay/frontend/static/api.js @@ -123,7 +123,6 @@ async function request(method, path, body = null) { } else { if (Object.hasOwn(message, "created")) { - console.log(message.created) message.created = new Date(message.created); } } diff --git a/relay/frontend/static/config.js b/relay/frontend/static/config.js index 417c48a..612f4f3 100644 --- a/relay/frontend/static/config.js +++ b/relay/frontend/static/config.js @@ -35,6 +35,18 @@ async function handle_config_change(event) { } +document.querySelector("#name").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await handle_config_change(event); + } +}); + +document.querySelector("#note").addEventListener("keydown", async (event) => { + if (event.which === 13 && event.ctrlKey) { + await handle_config_change(event); + } +}); + for (const elem of elems) { elem.addEventListener("change", handle_config_change); } diff --git a/relay/frontend/static/domain_ban.js b/relay/frontend/static/domain_ban.js index 4de2ebf..cdffbd3 100644 --- a/relay/frontend/static/domain_ban.js +++ b/relay/frontend/static/domain_ban.js @@ -114,6 +114,14 @@ document.querySelector("#new-ban").addEventListener("click", async (event) => { await ban(); }); +for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await ban(); + } + }); +} + for (var row of document.querySelector("fieldset.section table").rows) { if (!row.querySelector(".update-ban")) { continue; diff --git a/relay/frontend/static/instance.js b/relay/frontend/static/instance.js index a07b647..9519ebc 100644 --- a/relay/frontend/static/instance.js +++ b/relay/frontend/static/instance.js @@ -126,6 +126,14 @@ document.querySelector("#add-instance").addEventListener("click", async (event) await add_instance(); }) +for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_instance(); + } + }); +} + for (var row of document.querySelector("#instances").rows) { if (!row.querySelector(".remove a")) { continue; diff --git a/relay/frontend/static/login.js b/relay/frontend/static/login.js index 9c68f17..c61a7f4 100644 --- a/relay/frontend/static/login.js +++ b/relay/frontend/static/login.js @@ -1,10 +1,10 @@ -async function login(event) { - fields = { - username: document.querySelector("#username"), - password: document.querySelector("#password") - } +const fields = { + username: document.querySelector("#username"), + password: document.querySelector("#password") +} - values = { +async function login(event) { + const values = { username: fields.username.value.trim(), password: fields.password.value.trim() } @@ -26,4 +26,17 @@ async function login(event) { } +document.querySelector("#username").addEventListener("keydown", async (event) => { + if (event.which === 13) { + fields.password.focus(); + fields.password.select(); + } +}); + +document.querySelector("#password").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await login(event); + } +}); + document.querySelector(".submit").addEventListener("click", login); diff --git a/relay/frontend/static/software_ban.js b/relay/frontend/static/software_ban.js index 663929a..bb54dbe 100644 --- a/relay/frontend/static/software_ban.js +++ b/relay/frontend/static/software_ban.js @@ -113,6 +113,22 @@ document.querySelector("#new-ban").addEventListener("click", async (event) => { await ban(); }); +for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await ban(); + } + }); +} + +for (var elem of document.querySelectorAll("#add-item textarea")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13 && event.ctrlKey) { + await ban(); + } + }); +} + for (var row of document.querySelector("#bans").rows) { if (!row.querySelector(".update-ban")) { continue; diff --git a/relay/frontend/static/user.js b/relay/frontend/static/user.js index 9c74359..5bc16ec 100644 --- a/relay/frontend/static/user.js +++ b/relay/frontend/static/user.js @@ -76,6 +76,14 @@ document.querySelector("#new-user").addEventListener("click", async (event) => { await add_user(); }); +for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_user(); + } + }); +} + for (var row of document.querySelector("#users").rows) { if (!row.querySelector(".remove a")) { continue; diff --git a/relay/frontend/static/whitelist.js b/relay/frontend/static/whitelist.js index 70d4db1..c2b31e4 100644 --- a/relay/frontend/static/whitelist.js +++ b/relay/frontend/static/whitelist.js @@ -55,6 +55,12 @@ document.querySelector("#new-item").addEventListener("click", async (event) => { await add_whitelist(); }); +document.querySelector("#add-item").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_whitelist(); + } +}); + for (var row of document.querySelector("fieldset.section table").rows) { if (!row.querySelector(".remove a")) { continue; From bdc7d41d7a61001cc7ac1d26dce7240582ea7a7f Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 25 Jun 2024 16:02:49 -0400 Subject: [PATCH 06/21] update barkshark-sql to 0.2.0-rc1 and create row classes --- pyproject.toml | 6 +- relay/__init__.py | 2 +- relay/application.py | 73 +-------------- relay/cache.py | 12 ++- relay/database/__init__.py | 2 + relay/database/connection.py | 172 +++++++++++++++++++++++------------ relay/database/schema.py | 131 +++++++++++++++----------- relay/http_client.py | 6 +- relay/manage.py | 131 +++++++++++++------------- relay/processors.py | 16 ++-- relay/views/activitypub.py | 14 +-- relay/views/api.py | 130 ++++++++++++++++---------- relay/views/frontend.py | 2 +- relay/workers.py | 6 +- 14 files changed, 374 insertions(+), 329 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f06de0..a3c9410 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "aiohttp >= 3.9.5", "aiohttp-swagger[performance] == 1.0.16", "argon2-cffi == 23.1.0", - "barkshark-lib >= 0.1.3-1", + "barkshark-lib >= 0.2.0-rc1", "barkshark-sql == 0.1.4-1", "click >= 8.1.2", "hiredis == 2.3.2", @@ -104,7 +104,3 @@ implicit_reexport = true [[tool.mypy.overrides]] module = "blib" implicit_reexport = true - -[[tool.mypy.overrides]] -module = "bsql" -implicit_reexport = true diff --git a/relay/__init__.py b/relay/__init__.py index 73e3bb4..80eb7f9 100644 --- a/relay/__init__.py +++ b/relay/__init__.py @@ -1 +1 @@ -__version__ = '0.3.2' +__version__ = '0.3.3' diff --git a/relay/application.py b/relay/application.py index a3d9925..d852f29 100644 --- a/relay/application.py +++ b/relay/application.py @@ -4,30 +4,26 @@ import asyncio import multiprocessing import signal import time -import traceback from aiohttp import web -from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.web import StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer -from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from bsql import Database, Row +from bsql import Database from collections.abc import Awaitable, Callable from datetime import datetime, timedelta from mimetypes import guess_type from pathlib import Path -from queue import Empty from threading import Event, Thread from typing import Any -from urllib.parse import urlparse from . import logger as logging, workers from .cache import Cache, get_cache from .config import Config from .database import Connection, get_database +from .database.schema import Instance from .http_client import HttpClient -from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource +from .misc import Message, Response, check_open_port, get_resource from .template import Template from .views import VIEWS from .views.api import handle_api_path @@ -142,7 +138,7 @@ class Application(web.Application): return timedelta(seconds=uptime.seconds) - def push_message(self, inbox: str, message: Message, instance: Row) -> None: + def push_message(self, inbox: str, message: Message, instance: Instance) -> None: self['workers'].push_message(inbox, message, instance) @@ -286,67 +282,6 @@ class CacheCleanupThread(Thread): self.running.clear() -class PushWorker(multiprocessing.Process): - def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None: - if Application.DEFAULT is None: - raise RuntimeError('Application not setup yet') - - multiprocessing.Process.__init__(self) - - self.queue = queue - self.shutdown = multiprocessing.Event() - self.path = Application.DEFAULT.config.path - - - def stop(self) -> None: - self.shutdown.set() - - - def run(self) -> None: - asyncio.run(self.handle_queue()) - - - async def handle_queue(self) -> None: - if IS_WINDOWS: - app = Application(self.path) - client = app.client - - client.open() - app.database.connect() - app.cache.setup() - - else: - client = HttpClient() - client.open() - - while not self.shutdown.is_set(): - try: - inbox, message, instance = self.queue.get(block=True, timeout=0.1) - asyncio.create_task(client.post(inbox, message, instance)) - - except Empty: - await asyncio.sleep(0) - - except ClientSSLError as e: - logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e)) - - except (AsyncTimeoutError, ClientConnectionError) as e: - logging.error( - 'Failed to connect to %s for message push: %s', - urlparse(inbox).netloc, str(e) - ) - - # make sure an exception doesn't bring down the worker - except Exception: - traceback.print_exc() - - if IS_WINDOWS: - app.database.disconnect() - app.cache.close() - - await client.close() - - @web.middleware async def handle_response_headers( request: web.Request, diff --git a/relay/cache.py b/relay/cache.py index e9f261b..da87cc5 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -4,7 +4,7 @@ import json import os from abc import ABC, abstractmethod -from bsql import Database +from bsql import Database, Row from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass from datetime import datetime, timedelta, timezone @@ -172,7 +172,7 @@ class SqlCache(Cache): with self._db.session(False) as conn: with conn.run('get-cache-item', params) as cur: - if not (row := cur.one()): + if not (row := cur.one(Row)): raise KeyError(f'{namespace}:{key}') row.pop('id', None) @@ -211,9 +211,11 @@ class SqlCache(Cache): with self._db.session(True) as conn: with conn.run('set-cache-item', params) as cur: - row = cur.one() - row.pop('id', None) # type: ignore[union-attr] - return Item.from_data(*tuple(row.values())) # type: ignore[union-attr] + if (row := cur.one(Row)) is None: + raise RuntimeError("Cache item not set") + + row.pop('id', None) + return Item.from_data(*tuple(row.values())) def delete(self, namespace: str, key: str) -> None: diff --git a/relay/database/__init__.py b/relay/database/__init__.py index becd456..545f822 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -16,6 +16,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]: 'tables': TABLES } + db: Database[Connection] + if config.db_type == 'sqlite': db = Database.sqlite(config.sqlite_path, **options) diff --git a/relay/database/connection.py b/relay/database/connection.py index 864ad27..006a907 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -2,12 +2,13 @@ from __future__ import annotations from argon2 import PasswordHasher from bsql import Connection as SqlConnection, Row, Update -from collections.abc import Iterator, Sequence +from collections.abc import Iterator from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from urllib.parse import urlparse from uuid import uuid4 +from . import schema from .config import ( THEMES, ConfigData @@ -37,14 +38,14 @@ class Connection(SqlConnection): return get_app() - def distill_inboxes(self, message: Message) -> Iterator[Row]: + def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]: src_domains = { message.domain, urlparse(message.object_id).netloc } for instance in self.get_inboxes(): - if instance['domain'] not in src_domains: + if instance.domain not in src_domains: yield instance @@ -52,7 +53,7 @@ class Connection(SqlConnection): key = key.replace('_', '-') with self.run('get-config', {'key': key}) as cur: - if not (row := cur.one()): + if (row := cur.one(Row)) is None: return ConfigData.DEFAULT(key) data = ConfigData() @@ -61,8 +62,8 @@ class Connection(SqlConnection): def get_config_all(self) -> ConfigData: - with self.run('get-config-all', None) as cur: - return ConfigData.from_rows(tuple(cur.all())) + rows = tuple(self.run('get-config-all', None).all(schema.Row)) + return ConfigData.from_rows(rows) def put_config(self, key: str, value: Any) -> Any: @@ -99,14 +100,13 @@ class Connection(SqlConnection): return data.get(key) - def get_inbox(self, value: str) -> Row: + def get_inbox(self, value: str) -> schema.Instance | None: with self.run('get-inbox', {'value': value}) as cur: - return cur.one() # type: ignore + return cur.one(schema.Instance) - def get_inboxes(self) -> Sequence[Row]: - with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur: - return tuple(cur.all()) + def get_inboxes(self) -> Iterator[schema.Instance]: + return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance) def put_inbox(self, @@ -115,7 +115,7 @@ class Connection(SqlConnection): actor: str | None = None, followid: str | None = None, software: str | None = None, - accepted: bool = True) -> Row: + accepted: bool = True) -> schema.Instance: params: dict[str, Any] = { 'inbox': inbox, @@ -125,7 +125,7 @@ class Connection(SqlConnection): 'accepted': accepted } - if not self.get_inbox(domain): + if self.get_inbox(domain) is None: if not inbox: raise ValueError("Missing inbox") @@ -133,14 +133,20 @@ class Connection(SqlConnection): params['created'] = datetime.now(tz = timezone.utc) with self.run('put-inbox', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to insert instance: {domain}") + + return row for key, value in tuple(params.items()): if value is None: del params[key] with self.update('inboxes', params, domain = domain) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to update instance: {domain}") + + return row def del_inbox(self, value: str) -> bool: @@ -151,24 +157,23 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_request(self, domain: str) -> Row: + def get_request(self, domain: str) -> schema.Instance | None: with self.run('get-request', {'domain': domain}) as cur: - if not (row := cur.one()): - raise KeyError(domain) - - return row + return cur.one(schema.Instance) - def get_requests(self) -> Sequence[Row]: - with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur: - return tuple(cur.all()) + def get_requests(self) -> Iterator[schema.Instance]: + return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance) - def put_request_response(self, domain: str, accepted: bool) -> Row: - instance = self.get_request(domain) + def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: + if (instance := self.get_request(domain)) is None: + raise KeyError(domain) if not accepted: - self.del_inbox(domain) + if not self.del_inbox(domain): + raise RuntimeError(f'Failed to delete request: {domain}') + return instance params = { @@ -177,21 +182,28 @@ class Connection(SqlConnection): } with self.run('put-inbox-accept', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Instance)) is None: + raise RuntimeError(f"Failed to insert response for domain: {domain}") + + return row - def get_user(self, value: str) -> Row: + def get_user(self, value: str) -> schema.User | None: with self.run('get-user', {'value': value}) as cur: - return cur.one() # type: ignore + return cur.one(schema.User) - def get_user_by_token(self, code: str) -> Row: + def get_user_by_token(self, code: str) -> schema.User | None: with self.run('get-user-by-token', {'code': code}) as cur: - return cur.one() # type: ignore + return cur.one(schema.User) - def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row: - if self.get_user(username): + def get_users(self) -> Iterator[schema.User]: + return self.execute("SELECT * FROM users").all(schema.User) + + + def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User: + if self.get_user(username) is not None: data: dict[str, str | datetime | None] = {} if password: @@ -204,7 +216,10 @@ class Connection(SqlConnection): stmt.set_where("username", username) with self.query(stmt) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.User)) is None: + raise RuntimeError(f"Failed to update user: {username}") + + return row if password is None: raise ValueError('Password cannot be empty') @@ -217,25 +232,36 @@ class Connection(SqlConnection): } with self.run('put-user', data) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.User)) is None: + raise RuntimeError(f"Failed to insert user: {username}") + + return row def del_user(self, username: str) -> None: - user = self.get_user(username) + if (user := self.get_user(username)) is None: + raise KeyError(username) - with self.run('del-user', {'value': user['username']}): + with self.run('del-user', {'value': user.username}): pass - with self.run('del-token-user', {'username': user['username']}): + with self.run('del-token-user', {'username': user.username}): pass - def get_token(self, code: str) -> Row: + def get_token(self, code: str) -> schema.Token | None: with self.run('get-token', {'code': code}) as cur: - return cur.one() # type: ignore + return cur.one(schema.Token) - def put_token(self, username: str) -> Row: + def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]: + if username is not None: + return self.select('tokens').all(schema.Token) + + return self.select('tokens', username = username).all(schema.Token) + + + def put_token(self, username: str) -> schema.Token: data = { 'code': uuid4().hex, 'user': username, @@ -243,7 +269,10 @@ class Connection(SqlConnection): } with self.run('put-token', data) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Token)) is None: + raise RuntimeError(f"Failed to insert token for user: {username}") + + return row def del_token(self, code: str) -> None: @@ -251,18 +280,22 @@ class Connection(SqlConnection): pass - def get_domain_ban(self, domain: str) -> Row: + def get_domain_ban(self, domain: str) -> schema.DomainBan | None: if domain.startswith('http'): domain = urlparse(domain).netloc with self.run('get-domain-ban', {'domain': domain}) as cur: - return cur.one() # type: ignore + return cur.one(schema.DomainBan) + + + def get_domain_bans(self) -> Iterator[schema.DomainBan]: + return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan) def put_domain_ban(self, domain: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.DomainBan: params = { 'domain': domain, @@ -272,13 +305,16 @@ class Connection(SqlConnection): } with self.run('put-domain-ban', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.DomainBan)) is None: + raise RuntimeError(f"Failed to insert domain ban: {domain}") + + return row def update_domain_ban(self, domain: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.DomainBan: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') @@ -298,7 +334,10 @@ class Connection(SqlConnection): if cur.row_count > 1: raise ValueError('More than one row was modified') - return self.get_domain_ban(domain) + if (row := cur.one(schema.DomainBan)) is None: + raise RuntimeError(f"Failed to update domain ban: {domain}") + + return row def del_domain_ban(self, domain: str) -> bool: @@ -309,15 +348,19 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_software_ban(self, name: str) -> Row: + def get_software_ban(self, name: str) -> schema.SoftwareBan | None: with self.run('get-software-ban', {'name': name}) as cur: - return cur.one() # type: ignore + return cur.one(schema.SoftwareBan) + + + def get_software_bans(self) -> Iterator[schema.SoftwareBan,]: + return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan) def put_software_ban(self, name: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.SoftwareBan: params = { 'name': name, @@ -327,13 +370,16 @@ class Connection(SqlConnection): } with self.run('put-software-ban', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.SoftwareBan)) is None: + raise RuntimeError(f'Failed to insert software ban: {name}') + + return row def update_software_ban(self, name: str, reason: str | None = None, - note: str | None = None) -> Row: + note: str | None = None) -> schema.SoftwareBan: if not (reason or note): raise ValueError('"reason" and/or "note" must be specified') @@ -353,7 +399,10 @@ class Connection(SqlConnection): if cur.row_count > 1: raise ValueError('More than one row was modified') - return self.get_software_ban(name) + if (row := cur.one(schema.SoftwareBan)) is None: + raise RuntimeError(f'Failed to update software ban: {name}') + + return row def del_software_ban(self, name: str) -> bool: @@ -364,19 +413,26 @@ class Connection(SqlConnection): return cur.row_count == 1 - def get_domain_whitelist(self, domain: str) -> Row: + def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None: with self.run('get-domain-whitelist', {'domain': domain}) as cur: - return cur.one() # type: ignore + return cur.one() - def put_domain_whitelist(self, domain: str) -> Row: + def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]: + return self.execute("SELECT * FROM whitelist").all(schema.Whitelist) + + + def put_domain_whitelist(self, domain: str) -> schema.Whitelist: params = { 'domain': domain, 'created': datetime.now(tz = timezone.utc) } with self.run('put-domain-whitelist', params) as cur: - return cur.one() # type: ignore + if (row := cur.one(schema.Whitelist)) is None: + raise RuntimeError(f'Failed to insert whitelisted domain: {domain}') + + return row def del_domain_whitelist(self, domain: str) -> bool: diff --git a/relay/database/schema.py b/relay/database/schema.py index 409ee57..1fd7003 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -1,61 +1,88 @@ -from bsql import Column, Table, Tables +from __future__ import annotations + +import typing + +from bsql import Column, Row, Tables from collections.abc import Callable +from datetime import datetime from .config import ConfigData -from .connection import Connection + +if typing.TYPE_CHECKING: + from .connection import Connection VERSIONS: dict[int, Callable[[Connection], None]] = {} -TABLES: Tables = Tables( - Table( - 'config', - Column('key', 'text', primary_key = True, unique = True, nullable = False), - Column('value', 'text'), - Column('type', 'text', default = 'str') - ), - Table( - 'inboxes', - Column('domain', 'text', primary_key = True, unique = True, nullable = False), - Column('actor', 'text', unique = True), - Column('inbox', 'text', unique = True, nullable = False), - Column('followid', 'text'), - Column('software', 'text'), - Column('accepted', 'boolean'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'whitelist', - Column('domain', 'text', primary_key = True, unique = True, nullable = True), - Column('created', 'timestamp') - ), - Table( - 'domain_bans', - Column('domain', 'text', primary_key = True, unique = True, nullable = True), - Column('reason', 'text'), - Column('note', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'software_bans', - Column('name', 'text', primary_key = True, unique = True, nullable = True), - Column('reason', 'text'), - Column('note', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'users', - Column('username', 'text', primary_key = True, unique = True, nullable = False), - Column('hash', 'text', nullable = False), - Column('handle', 'text'), - Column('created', 'timestamp', nullable = False) - ), - Table( - 'tokens', - Column('code', 'text', primary_key = True, unique = True, nullable = False), - Column('user', 'text', nullable = False), - Column('created', 'timestmap', nullable = False) - ) -) +TABLES = Tables() + + +@TABLES.add_row +class Config(Row): + key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) + value: Column[str] = Column('value', 'text') + type: Column[str] = Column('type', 'text', default = 'str') + + +@TABLES.add_row +class Instance(Row): + table_name: str = 'inboxes' + + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = False) + actor: Column[str] = Column('actor', 'text', unique = True) + inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) + followid: Column[str] = Column('followid', 'text') + software: Column[str] = Column('software', 'text') + accepted: Column[datetime] = Column('accepted', 'boolean') + created: Column[datetime] = Column('created', 'timestamp', nullable = False) + + +@TABLES.add_row +class Whitelist(Row): + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = True) + created: Column[datetime] = Column('created', 'timestamp') + + +@TABLES.add_row +class DomainBan(Row): + table_name: str = 'domain_bans' + + domain: Column[str] = Column( + 'domain', 'text', primary_key = True, unique = True, nullable = True) + reason: Column[str] = Column('reason', 'text') + note: Column[str] = Column('note', 'text') + created: Column[datetime] = Column('created', 'timestamp') + + +@TABLES.add_row +class SoftwareBan(Row): + table_name: str = 'software_bans' + + name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) + reason: Column[str] = Column('reason', 'text') + note: Column[str] = Column('note', 'text') + created: Column[datetime] = Column('created', 'timestamp') + + +@TABLES.add_row +class User(Row): + table_name: str = 'users' + + username: Column[str] = Column( + 'username', 'text', primary_key = True, unique = True, nullable = False) + hash: Column[str] = Column('hash', 'text', nullable = False) + handle: Column[str] = Column('handle', 'text') + created: Column[datetime] = Column('created', 'timestamp') + + +@TABLES.add_row +class Token(Row): + table_name: str = 'tokens' + + code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False) + user: Column[str] = Column('user', 'text', nullable = False) + created: Column[datetime] = Column('created', 'timestamp') def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: diff --git a/relay/http_client.py b/relay/http_client.py index 610b8a9..05a6565 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -5,11 +5,11 @@ import json from aiohttp import ClientSession, ClientTimeout, TCPConnector from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo from blib import JsonBase -from bsql import Row from typing import TYPE_CHECKING, Any, TypeVar, overload from . import __version__, logger as logging from .cache import Cache +from .database.schema import Instance from .misc import MIMETYPES, Message, get_app if TYPE_CHECKING: @@ -184,12 +184,12 @@ class HttpClient: return None - async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None: + async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None: if not self._session: raise RuntimeError('Client not open') # akkoma and pleroma do not support HS2019 and other software still needs to be tested - if instance and instance['software'] in SUPPORTS_HS2019: + if instance is not None and instance.software in SUPPORTS_HS2019: algorithm = AlgorithmType.HS2019 else: diff --git a/relay/manage.py b/relay/manage.py index cb2b099..81f546e 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -6,7 +6,6 @@ import click import json import os -from bsql import Row from pathlib import Path from shutil import copyfile from typing import Any @@ -17,7 +16,7 @@ from . import http_client as http from . import logger as logging from .application import Application from .compat import RelayConfig, RelayDatabase -from .database import RELAY_SOFTWARE, get_database +from .database import RELAY_SOFTWARE, get_database, schema from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message @@ -367,8 +366,8 @@ def cli_user_list(ctx: click.Context) -> None: click.echo('Users:') with ctx.obj.database.session() as conn: - for user in conn.execute('SELECT * FROM users'): - click.echo(f'- {user["username"]}') + for row in conn.get_users(): + click.echo(f'- {row.username}') @cli_user.command('create') @@ -379,7 +378,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: 'Create a new local user' with ctx.obj.database.session() as conn: - if conn.get_user(username): + if conn.get_user(username) is not None: click.echo(f'User already exists: {username}') return @@ -406,7 +405,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None: 'Delete a local user' with ctx.obj.database.session() as conn: - if not conn.get_user(username): + if conn.get_user(username) is None: click.echo(f'User does not exist: {username}') return @@ -424,8 +423,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None: click.echo(f'Tokens for "{username}":') with ctx.obj.database.session() as conn: - for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): - click.echo(f'- {token["code"]}') + for row in conn.get_tokens(username): + click.echo(f'- {row.code}') @cli_user.command('create-token') @@ -435,13 +434,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None: 'Create a new API token for a user' with ctx.obj.database.session() as conn: - if not (user := conn.get_user(username)): + if (user := conn.get_user(username)) is None: click.echo(f'User does not exist: {username}') return - token = conn.put_token(user['username']) + token = conn.put_token(user.username) - click.echo(f'New token for "{username}": {token["code"]}') + click.echo(f'New token for "{username}": {token.code}') @cli_user.command('delete-token') @@ -451,7 +450,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None: 'Delete an API token' with ctx.obj.database.session() as conn: - if not conn.get_token(code): + if conn.get_token(code) is None: click.echo('Token does not exist') return @@ -473,8 +472,8 @@ def cli_inbox_list(ctx: click.Context) -> None: click.echo('Connected to the following instances or relays:') with ctx.obj.database.session() as conn: - for inbox in conn.get_inboxes(): - click.echo(f'- {inbox["inbox"]}') + for row in conn.get_inboxes(): + click.echo(f'- {row.inbox}') @cli_inbox.command('follow') @@ -483,19 +482,21 @@ def cli_inbox_list(ctx: click.Context) -> None: def cli_inbox_follow(ctx: click.Context, actor: str) -> None: 'Follow an actor (Relay must be running)' + instance: schema.Instance | None = None + with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}') return - if (inbox_data := conn.get_inbox(actor)): - inbox = inbox_data['inbox'] + if (instance := conn.get_inbox(actor)) is not None: + inbox = instance.inbox else: if not actor.startswith('http'): actor = f'https://{actor}/actor' - if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))): + if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None: click.echo(f'Failed to fetch actor: {actor}') return @@ -506,7 +507,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None: actor = actor ) - asyncio.run(http.post(inbox, message, inbox_data)) + asyncio.run(http.post(inbox, message, instance)) click.echo(f'Sent follow message to actor: {actor}') @@ -516,19 +517,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None: def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: 'Unfollow an actor (Relay must be running)' - inbox_data: Row | None = None + instance: schema.Instance | None = None with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}') return - if (inbox_data := conn.get_inbox(actor)): - inbox = inbox_data['inbox'] + if (instance := conn.get_inbox(actor)): + inbox = instance.inbox message = Message.new_unfollow( host = ctx.obj.config.domain, actor = actor, - follow = inbox_data['followid'] + follow = instance.followid ) else: @@ -552,7 +553,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: } ) - asyncio.run(http.post(inbox, message, inbox_data)) + asyncio.run(http.post(inbox, message, instance)) click.echo(f'Sent unfollow message to: {actor}') @@ -632,9 +633,9 @@ def cli_request_list(ctx: click.Context) -> None: click.echo('Follow requests:') with ctx.obj.database.session() as conn: - for instance in conn.get_requests(): - date = instance['created'].strftime('%Y-%m-%d') - click.echo(f'- [{date}] {instance["domain"]}') + for row in conn.get_requests(): + date = row.created.strftime('%Y-%m-%d') + click.echo(f'- [{date}] {row.domain}') @cli_request.command('accept') @@ -653,20 +654,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None: message = Message.new_response( host = ctx.obj.config.domain, - actor = instance['actor'], - followid = instance['followid'], + actor = instance.actor, + followid = instance.followid, accept = True ) - asyncio.run(http.post(instance['inbox'], message, instance)) + asyncio.run(http.post(instance.inbox, message, instance)) - if instance['software'] != 'mastodon': + if instance.software != 'mastodon': message = Message.new_follow( host = ctx.obj.config.domain, - actor = instance['actor'] + actor = instance.actor ) - asyncio.run(http.post(instance['inbox'], message, instance)) + asyncio.run(http.post(instance.inbox, message, instance)) @cli_request.command('deny') @@ -685,12 +686,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None: response = Message.new_response( host = ctx.obj.config.domain, - actor = instance['actor'], - followid = instance['followid'], + actor = instance.actor, + followid = instance.followid, accept = False ) - asyncio.run(http.post(instance['inbox'], response, instance)) + asyncio.run(http.post(instance.inbox, response, instance)) @cli.group('instance') @@ -706,12 +707,12 @@ def cli_instance_list(ctx: click.Context) -> None: click.echo('Banned domains:') with ctx.obj.database.session() as conn: - for instance in conn.execute('SELECT * FROM domain_bans'): - if instance['reason']: - click.echo(f'- {instance["domain"]} ({instance["reason"]})') + for row in conn.get_domain_bans(): + if row.reason is not None: + click.echo(f'- {row.domain} ({row.reason})') else: - click.echo(f'- {instance["domain"]}') + click.echo(f'- {row.domain}') @cli_instance.command('ban') @@ -723,7 +724,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> 'Ban an instance and remove the associated inbox if it exists' with ctx.obj.database.session() as conn: - if conn.get_domain_ban(domain): + if conn.get_domain_ban(domain) is not None: click.echo(f'Domain already banned: {domain}') return @@ -739,7 +740,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None: 'Unban an instance' with ctx.obj.database.session() as conn: - if not conn.del_domain_ban(domain): + if conn.del_domain_ban(domain) is None: click.echo(f'Instance wasn\'t banned: {domain}') return @@ -764,11 +765,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) click.echo(f'Updated domain ban: {domain}') - if row['reason']: - click.echo(f'- {row["domain"]} ({row["reason"]})') + if row.reason: + click.echo(f'- {row.domain} ({row.reason})') else: - click.echo(f'- {row["domain"]}') + click.echo(f'- {row.domain}') @cli.group('software') @@ -784,12 +785,12 @@ def cli_software_list(ctx: click.Context) -> None: click.echo('Banned software:') with ctx.obj.database.session() as conn: - for software in conn.execute('SELECT * FROM software_bans'): - if software['reason']: - click.echo(f'- {software["name"]} ({software["reason"]})') + for row in conn.get_software_bans(): + if row.reason: + click.echo(f'- {row.name} ({row.reason})') else: - click.echo(f'- {software["name"]}') + click.echo(f'- {row.name}') @cli_software.command('ban') @@ -811,12 +812,12 @@ def cli_software_ban(ctx: click.Context, with ctx.obj.database.session() as conn: if name == 'RELAYS': - for software in RELAY_SOFTWARE: - if conn.get_software_ban(software): - click.echo(f'Relay already banned: {software}') + for item in RELAY_SOFTWARE: + if conn.get_software_ban(item): + click.echo(f'Relay already banned: {item}') continue - conn.put_software_ban(software, reason or 'relay', note) + conn.put_software_ban(item, reason or 'relay', note) click.echo('Banned all relay software') return @@ -893,11 +894,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) - click.echo(f'Updated software ban: {name}') - if row['reason']: - click.echo(f'- {row["name"]} ({row["reason"]})') + if row.reason: + click.echo(f'- {row.name} ({row.reason})') else: - click.echo(f'- {row["name"]}') + click.echo(f'- {row.name}') @cli.group('whitelist') @@ -913,8 +914,8 @@ def cli_whitelist_list(ctx: click.Context) -> None: click.echo('Current whitelisted domains:') with ctx.obj.database.session() as conn: - for domain in conn.execute('SELECT * FROM whitelist'): - click.echo(f'- {domain["domain"]}') + for row in conn.get_domain_whitelist(): + click.echo(f'- {row.domain}') @cli_whitelist.command('add') @@ -953,23 +954,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None: @cli_whitelist.command('import') @click.pass_context def cli_whitelist_import(ctx: click.Context) -> None: - 'Add all current inboxes to the whitelist' + 'Add all current instances to the whitelist' with ctx.obj.database.session() as conn: - for inbox in conn.execute('SELECT * FROM inboxes').all(): - if conn.get_domain_whitelist(inbox['domain']): - click.echo(f'Domain already in whitelist: {inbox["domain"]}') + for row in conn.get_inboxes(): + if conn.get_domain_whitelist(row.domain) is not None: + click.echo(f'Domain already in whitelist: {row.domain}') continue - conn.put_domain_whitelist(inbox['domain']) + conn.put_domain_whitelist(row.domain) click.echo('Imported whitelist from inboxes') def main() -> None: - cli(prog_name='relay') - - -if __name__ == '__main__': - click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.') + cli(prog_name='activityrelay') diff --git a/relay/processors.py b/relay/processors.py index cd742ec..4e4d96f 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None: logging.debug('>> relay: %s', message) for instance in conn.distill_inboxes(view.message): - view.app.push_message(instance["inbox"], message, instance) + view.app.push_message(instance.inbox, message, instance) view.cache.set('handle-relay', view.message.object_id, message.id, 'str') @@ -52,7 +52,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None: logging.debug('>> forward: %s', message) for instance in conn.distill_inboxes(view.message): - view.app.push_message(instance["inbox"], view.message, instance) + view.app.push_message(instance.inbox, view.message, instance) view.cache.set('handle-relay', view.message.id, message.id, 'str') @@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None: return # prevent past unfollows from removing an instance - if view.instance['followid'] and view.instance['followid'] != view.message.object_id: + if view.instance.followid and view.instance.followid != view.message.object_id: return with conn.transaction(): @@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None: with view.database.session() as conn: if view.instance: - if not view.instance['software']: - if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): + if not view.instance.software: + if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)): with conn.transaction(): view.instance = conn.put_inbox( - domain = view.instance['domain'], + domain = view.instance.domain, software = nodeinfo.sw_name ) - if not view.instance['actor']: + if not view.instance.actor: with conn.transaction(): view.instance = conn.put_inbox( - domain = view.instance['domain'], + domain = view.instance.domain, actor = view.actor.id ) diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index f568d17..74b01c6 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -1,26 +1,22 @@ -from __future__ import annotations - import aputils import traceback -import typing + +from aiohttp.web import Request from .base import View, register_route from .. import logger as logging +from ..database import schema from ..misc import Message, Response from ..processors import run_processor -if typing.TYPE_CHECKING: - from aiohttp.web import Request - from bsql import Row - @register_route('/actor', '/inbox') class ActorView(View): signature: aputils.Signature message: Message actor: Message - instancce: Row + instance: schema.Instance signer: aputils.Signer @@ -47,7 +43,7 @@ class ActorView(View): return response with self.database.session() as conn: - self.instance = conn.get_inbox(self.actor.shared_inbox) + self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment] # reject if actor is banned if conn.get_domain_ban(self.actor.domain): diff --git a/relay/views/api.py b/relay/views/api.py index 074dc04..3bdc822 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -90,10 +90,10 @@ class Login(View): token = conn.put_token(data['username']) - resp = Response.new({'token': token['code']}, ctype = 'json') + resp = Response.new({'token': token.code}, ctype = 'json') resp.set_cookie( 'user-token', - token['code'], + token.code, max_age = 60 * 60 * 24 * 365, domain = self.config.domain, path = '/', @@ -117,7 +117,7 @@ class RelayInfo(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: config = conn.get_config_all() - inboxes = [row['domain'] for row in conn.get_inboxes()] + inboxes = [row.domain for row in conn.get_inboxes()] data = { 'domain': self.config.domain, @@ -188,7 +188,7 @@ class Config(View): class Inbox(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - data = conn.get_inboxes() + data = tuple(conn.get_inboxes()) return Response.new(data, ctype = 'json') @@ -202,7 +202,7 @@ class Inbox(View): data['domain'] = urlparse(data["actor"]).netloc with self.database.session() as conn: - if conn.get_inbox(data['domain']): + if conn.get_inbox(data['domain']) is not None: return Response.new_error(404, 'Instance already in database', 'json') data['domain'] = data['domain'].encode('idna').decode() @@ -225,7 +225,12 @@ class Inbox(View): except Exception: pass - row = conn.put_inbox(**data) # type: ignore[arg-type] + row = conn.put_inbox( + data['domain'], + actor = data.get('actor'), + software = data.get('software'), + followid = data.get('followid') + ) return Response.new(row, ctype = 'json') @@ -239,10 +244,15 @@ class Inbox(View): data['domain'] = data['domain'].encode('idna').decode() - if not (instance := conn.get_inbox(data['domain'])): + if (instance := conn.get_inbox(data['domain'])) is None: return Response.new_error(404, 'Instance with domain not found', 'json') - instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type] + instance = conn.put_inbox( + instance.domain, + actor = data.get('actor'), + software = data.get('software'), + followid = data.get('followid') + ) return Response.new(instance, ctype = 'json') @@ -268,7 +278,7 @@ class Inbox(View): class RequestView(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - instances = conn.get_requests() + instances = tuple(conn.get_requests()) return Response.new(instances, ctype = 'json') @@ -291,20 +301,20 @@ class RequestView(View): message = Message.new_response( host = self.config.domain, - actor = instance['actor'], - followid = instance['followid'], + actor = instance.actor, + followid = instance.followid, accept = data['accept'] ) - self.app.push_message(instance['inbox'], message, instance) + self.app.push_message(instance.inbox, message, instance) - if data['accept'] and instance['software'] != 'mastodon': + if data['accept'] and instance.software != 'mastodon': message = Message.new_follow( host = self.config.domain, - actor = instance['actor'] + actor = instance.actor ) - self.app.push_message(instance['inbox'], message, instance) + self.app.push_message(instance.inbox, message, instance) resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'} return Response.new(resp_message, ctype = 'json') @@ -314,7 +324,7 @@ class RequestView(View): class DomainBan(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - bans = tuple(conn.execute('SELECT * FROM domain_bans').all()) + bans = tuple(conn.get_domain_bans()) return Response.new(bans, ctype = 'json') @@ -328,10 +338,14 @@ class DomainBan(View): data['domain'] = data['domain'].encode('idna').decode() with self.database.session() as conn: - if conn.get_domain_ban(data['domain']): + if conn.get_domain_ban(data['domain']) is not None: return Response.new_error(400, 'Domain already banned', 'json') - ban = conn.put_domain_ban(**data) + ban = conn.put_domain_ban( + domain = data['domain'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -343,15 +357,19 @@ class DomainBan(View): if isinstance(data, Response): return data - data['domain'] = data['domain'].encode('idna').decode() - - if not conn.get_domain_ban(data['domain']): - return Response.new_error(404, 'Domain not banned', 'json') - if not any([data.get('note'), data.get('reason')]): return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - ban = conn.update_domain_ban(**data) + data['domain'] = data['domain'].encode('idna').decode() + + if conn.get_domain_ban(data['domain']) is None: + return Response.new_error(404, 'Domain not banned', 'json') + + ban = conn.update_domain_ban( + domain = data['domain'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -365,7 +383,7 @@ class DomainBan(View): data['domain'] = data['domain'].encode('idna').decode() - if not conn.get_domain_ban(data['domain']): + if conn.get_domain_ban(data['domain']) is None: return Response.new_error(404, 'Domain not banned', 'json') conn.del_domain_ban(data['domain']) @@ -377,7 +395,7 @@ class DomainBan(View): class SoftwareBan(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - bans = tuple(conn.execute('SELECT * FROM software_bans').all()) + bans = tuple(conn.get_software_bans()) return Response.new(bans, ctype = 'json') @@ -389,10 +407,14 @@ class SoftwareBan(View): return data with self.database.session() as conn: - if conn.get_software_ban(data['name']): + if conn.get_software_ban(data['name']) is not None: return Response.new_error(400, 'Domain already banned', 'json') - ban = conn.put_software_ban(**data) + ban = conn.put_software_ban( + name = data['name'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -403,14 +425,18 @@ class SoftwareBan(View): if isinstance(data, Response): return data + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + with self.database.session() as conn: - if not conn.get_software_ban(data['name']): + if conn.get_software_ban(data['name']) is None: return Response.new_error(404, 'Software not banned', 'json') - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - - ban = conn.update_software_ban(**data) + ban = conn.update_software_ban( + name = data['name'], + reason = data.get('reason'), + note = data.get('note') + ) return Response.new(ban, ctype = 'json') @@ -422,7 +448,7 @@ class SoftwareBan(View): return data with self.database.session() as conn: - if not conn.get_software_ban(data['name']): + if conn.get_software_ban(data['name']) is None: return Response.new_error(404, 'Software not banned', 'json') conn.del_software_ban(data['name']) @@ -436,7 +462,7 @@ class User(View): with self.database.session() as conn: items = [] - for row in conn.execute('SELECT * FROM users'): + for row in conn.get_users(): del row['hash'] items.append(row) @@ -450,12 +476,16 @@ class User(View): return data with self.database.session() as conn: - if conn.get_user(data['username']): + if conn.get_user(data['username']) is not None: return Response.new_error(404, 'User already exists', 'json') - user = conn.put_user(**data) - del user['hash'] + user = conn.put_user( + username = data['username'], + password = data['password'], + handle = data.get('handle') + ) + del user['hash'] return Response.new(user, ctype = 'json') @@ -466,9 +496,13 @@ class User(View): return data with self.database.session(True) as conn: - user = conn.put_user(**data) - del user['hash'] + user = conn.put_user( + username = data['username'], + password = data['password'], + handle = data.get('handle') + ) + del user['hash'] return Response.new(user, ctype = 'json') @@ -479,7 +513,7 @@ class User(View): return data with self.database.session(True) as conn: - if not conn.get_user(data['username']): + if conn.get_user(data['username']) is None: return Response.new_error(404, 'User does not exist', 'json') conn.del_user(data['username']) @@ -491,7 +525,7 @@ class User(View): class Whitelist(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - items = tuple(conn.execute('SELECT * FROM whitelist').all()) + items = tuple(conn.get_domains_whitelist()) return Response.new(items, ctype = 'json') @@ -502,13 +536,13 @@ class Whitelist(View): if isinstance(data, Response): return data - data['domain'] = data['domain'].encode('idna').decode() + domain = data['domain'].encode('idna').decode() with self.database.session() as conn: - if conn.get_domain_whitelist(data['domain']): + if conn.get_domain_whitelist(domain) is not None: return Response.new_error(400, 'Domain already added to whitelist', 'json') - item = conn.put_domain_whitelist(**data) + item = conn.put_domain_whitelist(domain) return Response.new(item, ctype = 'json') @@ -519,12 +553,12 @@ class Whitelist(View): if isinstance(data, Response): return data - data['domain'] = data['domain'].encode('idna').decode() + domain = data['domain'].encode('idna').decode() with self.database.session() as conn: - if not conn.get_domain_whitelist(data['domain']): + if conn.get_domain_whitelist(domain) is None: return Response.new_error(404, 'Domain not in whitelist', 'json') - conn.del_domain_whitelist(data['domain']) + conn.del_domain_whitelist(domain) return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 5ec16fc..cf6b338 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -202,7 +202,7 @@ class AdminConfig(View): 'message': message, 'desc': { "name": "Name of the relay to be displayed in the header of the pages and in " + - "the actor endpoint.", + "the actor endpoint.", # noqa: E131 "note": "Description of the relay to be displayed on the front page and as the " + "bio in the actor endpoint.", "theme": "Color theme to use on the web pages.", diff --git a/relay/workers.py b/relay/workers.py index 8d88ad7..4b57409 100644 --- a/relay/workers.py +++ b/relay/workers.py @@ -4,7 +4,6 @@ import typing from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from bsql import Row from dataclasses import dataclass from multiprocessing import Event, Process, Queue, Value from multiprocessing.synchronize import Event as EventType @@ -13,6 +12,7 @@ from queue import Empty, Queue as QueueType from urllib.parse import urlparse from . import application, logger as logging +from .database.schema import Instance from .http_client import HttpClient from .misc import IS_WINDOWS, Message, get_app @@ -29,7 +29,7 @@ class QueueItem: class PostItem(QueueItem): inbox: str message: Message - instance: Row | None + instance: Instance | None @property def domain(self) -> str: @@ -122,7 +122,7 @@ class PushWorkers(list[PushWorker]): self.queue.put(item) - def push_message(self, inbox: str, message: Message, instance: Row) -> None: + def push_message(self, inbox: str, message: Message, instance: Instance) -> None: self.queue.put(PostItem(inbox, message, instance)) From 5e962be057d3f47a7793eb9903844ffb8ce2034c Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 25 Jun 2024 16:25:33 -0400 Subject: [PATCH 07/21] merge all javascript files --- relay/frontend/base.haml | 2 +- relay/frontend/page/admin-config.haml | 5 +- relay/frontend/page/admin-domain_bans.haml | 3 - relay/frontend/page/admin-instances.haml | 3 - relay/frontend/page/admin-software_bans.haml | 3 - relay/frontend/page/admin-users.haml | 3 - relay/frontend/page/admin-whitelist.haml | 3 - relay/frontend/page/home.haml | 1 + relay/frontend/page/login.haml | 3 - relay/frontend/static/api.js | 131 ---------------- relay/frontend/static/config.js | 52 ------- relay/frontend/static/domain_ban.js | 131 ---------------- relay/frontend/static/instance.js | 153 ------------------- relay/frontend/static/login.js | 42 ----- relay/frontend/static/software_ban.js | 138 ----------------- relay/frontend/static/user.js | 93 ----------- relay/frontend/static/whitelist.js | 70 --------- 17 files changed, 3 insertions(+), 833 deletions(-) delete mode 100644 relay/frontend/static/api.js delete mode 100644 relay/frontend/static/config.js delete mode 100644 relay/frontend/static/domain_ban.js delete mode 100644 relay/frontend/static/instance.js delete mode 100644 relay/frontend/static/login.js delete mode 100644 relay/frontend/static/software_ban.js delete mode 100644 relay/frontend/static/user.js delete mode 100644 relay/frontend/static/whitelist.js diff --git a/relay/frontend/base.haml b/relay/frontend/base.haml index d7551b8..d3d8bb6 100644 --- a/relay/frontend/base.haml +++ b/relay/frontend/base.haml @@ -15,7 +15,7 @@ %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}") %link(rel="manifest" href="/manifest.json?{{version}}") - %script(type="application/javascript" src="/static/api.js?{{version}}" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer) -block head %body diff --git a/relay/frontend/page/admin-config.haml b/relay/frontend/page/admin-config.haml index 8dc0ad9..226c052 100644 --- a/relay/frontend/page/admin-config.haml +++ b/relay/frontend/page/admin-config.haml @@ -1,10 +1,7 @@ -extends "base.haml" -set page="Config" - --block head - %script(type="application/javascript" src="/static/config.js?{{version}}" nonce="{{view.request['hash']}}" defer) - -import "functions.haml" as func + -block content %fieldset.section %legend << Config diff --git a/relay/frontend/page/admin-domain_bans.haml b/relay/frontend/page/admin-domain_bans.haml index 66cec3d..8aa6728 100644 --- a/relay/frontend/page/admin-domain_bans.haml +++ b/relay/frontend/page/admin-domain_bans.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Domain Bans" --block head - %script(type="application/javascript" src="/static/domain_ban.js?{{version}}" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Ban Domain diff --git a/relay/frontend/page/admin-instances.haml b/relay/frontend/page/admin-instances.haml index 1490fdd..61e08a0 100644 --- a/relay/frontend/page/admin-instances.haml +++ b/relay/frontend/page/admin-instances.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Instances" --block head - %script(type="application/javascript" src="/static/instance.js?{{version}}" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Add Instance diff --git a/relay/frontend/page/admin-software_bans.haml b/relay/frontend/page/admin-software_bans.haml index 3bc4648..faaa57e 100644 --- a/relay/frontend/page/admin-software_bans.haml +++ b/relay/frontend/page/admin-software_bans.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Software Bans" --block head - %script(type="application/javascript" src="/static/software_ban.js?{{version}}" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Ban Software diff --git a/relay/frontend/page/admin-users.haml b/relay/frontend/page/admin-users.haml index caa9dc2..d6715c9 100644 --- a/relay/frontend/page/admin-users.haml +++ b/relay/frontend/page/admin-users.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Users" --block head - %script(type="application/javascript" src="/static/user.js?{{version}}" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Add User diff --git a/relay/frontend/page/admin-whitelist.haml b/relay/frontend/page/admin-whitelist.haml index 0a300dd..2fa3b99 100644 --- a/relay/frontend/page/admin-whitelist.haml +++ b/relay/frontend/page/admin-whitelist.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Whitelist" --block head - %script(type="application/javascript" src="/static/whitelist.js?{{version}}" nonce="{{view.request['hash']}}" defer) - -block content %details.section %summary << Add Domain diff --git a/relay/frontend/page/home.haml b/relay/frontend/page/home.haml index fa883d6..7db7551 100644 --- a/relay/frontend/page/home.haml +++ b/relay/frontend/page/home.haml @@ -1,5 +1,6 @@ -extends "base.haml" -set page = "Home" + -block content -if config.note .section diff --git a/relay/frontend/page/login.haml b/relay/frontend/page/login.haml index a56973a..c32160f 100644 --- a/relay/frontend/page/login.haml +++ b/relay/frontend/page/login.haml @@ -1,9 +1,6 @@ -extends "base.haml" -set page="Login" --block head - %script(type="application/javascript" src="/static/login.js?{{version}}" nonce="{{view.request['hash']}}" defer) - -block content %fieldset.section %legend << Login diff --git a/relay/frontend/static/api.js b/relay/frontend/static/api.js deleted file mode 100644 index 6aaefd7..0000000 --- a/relay/frontend/static/api.js +++ /dev/null @@ -1,131 +0,0 @@ -// toast notifications - -const notifications = document.querySelector("#notifications") - - -function remove_toast(toast) { - toast.classList.add("hide"); - - if (toast.timeoutId) { - clearTimeout(toast.timeoutId); - } - - setTimeout(() => toast.remove(), 300); -} - -function toast(text, type="error", timeout=5) { - const toast = document.createElement("li"); - toast.className = `section ${type}` - toast.innerHTML = `${text}✖` - - toast.querySelector("a").addEventListener("click", async (event) => { - event.preventDefault(); - await remove_toast(toast); - }); - - notifications.appendChild(toast); - toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000); -} - - -// menu - -const body = document.getElementById("container") -const menu = document.getElementById("menu"); -const menu_open = document.querySelector("#menu-open i"); -const menu_close = document.getElementById("menu-close"); - - -function toggle_menu() { - let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true"; - menu.attributes.visible.nodeValue = new_value; -} - - -menu_open.addEventListener("click", toggle_menu); -menu_close.addEventListener("click", toggle_menu); - -body.addEventListener("click", (event) => { - if (event.target === menu_open) { - return; - } - - menu.attributes.visible.nodeValue = "false"; -}); - -for (const elem of document.querySelectorAll("#menu-open div")) { - elem.addEventListener("click", toggle_menu); -} - - -// misc - -function get_date_string(date) { - var year = date.getUTCFullYear().toString(); - var month = (date.getUTCMonth() + 1).toString().padStart(2, "0"); - var day = date.getUTCDate().toString().padStart(2, "0"); - - return `${year}-${month}-${day}`; -} - - -function append_table_row(table, row_name, row) { - var table_row = table.insertRow(-1); - table_row.id = row_name; - - index = 0; - - for (var prop in row) { - if (Object.prototype.hasOwnProperty.call(row, prop)) { - var cell = table_row.insertCell(index); - cell.className = prop; - cell.innerHTML = row[prop]; - - index += 1; - } - } - - return table_row; -} - - -async function request(method, path, body = null) { - var headers = { - "Accept": "application/json" - } - - if (body !== null) { - headers["Content-Type"] = "application/json" - body = JSON.stringify(body) - } - - const response = await fetch("/api/" + path, { - method: method, - mode: "cors", - cache: "no-store", - redirect: "follow", - body: body, - headers: headers - }); - - const message = await response.json(); - - if (Object.hasOwn(message, "error")) { - throw new Error(message.error); - } - - if (Array.isArray(message)) { - message.forEach((msg) => { - if (Object.hasOwn(msg, "created")) { - msg.created = new Date(msg.created); - } - }); - - } else { - if (Object.hasOwn(message, "created")) { - message.created = new Date(message.created); - } - } - - return message; -} diff --git a/relay/frontend/static/config.js b/relay/frontend/static/config.js deleted file mode 100644 index 612f4f3..0000000 --- a/relay/frontend/static/config.js +++ /dev/null @@ -1,52 +0,0 @@ -const elems = [ - document.querySelector("#name"), - document.querySelector("#note"), - document.querySelector("#theme"), - document.querySelector("#log-level"), - document.querySelector("#whitelist-enabled"), - document.querySelector("#approval-required") -] - - -async function handle_config_change(event) { - params = { - key: event.target.id, - value: event.target.type === "checkbox" ? event.target.checked : event.target.value - } - - try { - await request("POST", "v1/config", params); - - } catch (error) { - toast(error); - return; - } - - if (params.key === "name") { - document.querySelector("#header .title").innerHTML = params.value; - document.querySelector("title").innerHTML = params.value; - } - - if (params.key === "theme") { - document.querySelector("link.theme").href = `/theme/${params.value}.css`; - } - - toast("Updated config", "message"); -} - - -document.querySelector("#name").addEventListener("keydown", async (event) => { - if (event.which === 13) { - await handle_config_change(event); - } -}); - -document.querySelector("#note").addEventListener("keydown", async (event) => { - if (event.which === 13 && event.ctrlKey) { - await handle_config_change(event); - } -}); - -for (const elem of elems) { - elem.addEventListener("change", handle_config_change); -} diff --git a/relay/frontend/static/domain_ban.js b/relay/frontend/static/domain_ban.js deleted file mode 100644 index cdffbd3..0000000 --- a/relay/frontend/static/domain_ban.js +++ /dev/null @@ -1,131 +0,0 @@ -function create_ban_object(domain, reason, note) { - var text = '
\n'; - text += `${domain}\n`; - text += '
\n'; - text += `\n`; - text += `\n`; - text += `\n`; - text += `\n`; - text += ``; - text += '
'; - - return text; -} - - -function add_row_listeners(row) { - row.querySelector(".update-ban").addEventListener("click", async (event) => { - await update_ban(row.id); - }); - - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await unban(row.id); - }); -} - - -async function ban() { - var table = document.querySelector("table"); - var elems = { - domain: document.getElementById("new-domain"), - reason: document.getElementById("new-reason"), - note: document.getElementById("new-note") - } - - var values = { - domain: elems.domain.value.trim(), - reason: elems.reason.value.trim(), - note: elems.note.value.trim() - } - - if (values.domain === "") { - toast("Domain is required"); - return; - } - - try { - var ban = await request("POST", "v1/domain_ban", values); - - } catch (err) { - toast(err); - return - } - - var row = append_table_row(document.querySelector("table"), ban.domain, { - domain: create_ban_object(ban.domain, ban.reason, ban.note), - date: get_date_string(ban.created), - remove: `
` - }); - - add_row_listeners(row); - - elems.domain.value = null; - elems.reason.value = null; - elems.note.value = null; - - document.querySelector("details.section").open = false; - toast("Banned domain", "message"); -} - - -async function update_ban(domain) { - var row = document.getElementById(domain); - - var elems = { - "reason": row.querySelector("textarea.reason"), - "note": row.querySelector("textarea.note") - } - - var values = { - "domain": domain, - "reason": elems.reason.value, - "note": elems.note.value - } - - try { - await request("PATCH", "v1/domain_ban", values) - - } catch (error) { - toast(error); - return; - } - - row.querySelector("details").open = false; - toast("Updated baned domain", "message"); -} - - -async function unban(domain) { - try { - await request("DELETE", "v1/domain_ban", {"domain": domain}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(domain).remove(); - toast("Unbanned domain", "message"); -} - - -document.querySelector("#new-ban").addEventListener("click", async (event) => { - await ban(); -}); - -for (var elem of document.querySelectorAll("#add-item input")) { - elem.addEventListener("keydown", async (event) => { - if (event.which === 13) { - await ban(); - } - }); -} - -for (var row of document.querySelector("fieldset.section table").rows) { - if (!row.querySelector(".update-ban")) { - continue; - } - - add_row_listeners(row); -} diff --git a/relay/frontend/static/instance.js b/relay/frontend/static/instance.js deleted file mode 100644 index 9519ebc..0000000 --- a/relay/frontend/static/instance.js +++ /dev/null @@ -1,153 +0,0 @@ -function add_instance_listeners(row) { - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await del_instance(row.id); - }); -} - - -function add_request_listeners(row) { - row.querySelector(".approve a").addEventListener("click", async (event) => { - event.preventDefault(); - await req_response(row.id, true); - }); - - row.querySelector(".deny a").addEventListener("click", async (event) => { - event.preventDefault(); - await req_response(row.id, false); - }); -} - - -async function add_instance() { - var elems = { - actor: document.getElementById("new-actor"), - inbox: document.getElementById("new-inbox"), - followid: document.getElementById("new-followid"), - software: document.getElementById("new-software") - } - - var values = { - actor: elems.actor.value.trim(), - inbox: elems.inbox.value.trim(), - followid: elems.followid.value.trim(), - software: elems.software.value.trim() - } - - if (values.actor === "") { - toast("Actor is required"); - return; - } - - try { - var instance = await request("POST", "v1/instance", values); - - } catch (err) { - toast(err); - return - } - - row = append_table_row(document.getElementById("instances"), instance.domain, { - domain: `${instance.domain}`, - software: instance.software, - date: get_date_string(instance.created), - remove: `` - }); - - add_instance_listeners(row); - - elems.actor.value = null; - elems.inbox.value = null; - elems.followid.value = null; - elems.software.value = null; - - document.querySelector("details.section").open = false; - toast("Added instance", "message"); -} - - -async function del_instance(domain) { - try { - await request("DELETE", "v1/instance", {"domain": domain}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(domain).remove(); -} - - -async function req_response(domain, accept) { - params = { - "domain": domain, - "accept": accept - } - - try { - await request("POST", "v1/request", params); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(domain).remove(); - - if (document.getElementById("requests").rows.length < 2) { - document.querySelector("fieldset.requests").remove() - } - - if (!accept) { - toast("Denied instance request", "message"); - return; - } - - instances = await request("GET", `v1/instance`, null); - instances.forEach((instance) => { - if (instance.domain === domain) { - row = append_table_row(document.getElementById("instances"), instance.domain, { - domain: `${instance.domain}`, - software: instance.software, - date: get_date_string(instance.created), - remove: `` - }); - - add_instance_listeners(row); - } - }); - - toast("Accepted instance request", "message"); -} - - -document.querySelector("#add-instance").addEventListener("click", async (event) => { - await add_instance(); -}) - -for (var elem of document.querySelectorAll("#add-item input")) { - elem.addEventListener("keydown", async (event) => { - if (event.which === 13) { - await add_instance(); - } - }); -} - -for (var row of document.querySelector("#instances").rows) { - if (!row.querySelector(".remove a")) { - continue; - } - - add_instance_listeners(row); -} - -if (document.querySelector("#requests")) { - for (var row of document.querySelector("#requests").rows) { - if (!row.querySelector(".approve a")) { - continue; - } - - add_request_listeners(row); - } -} diff --git a/relay/frontend/static/login.js b/relay/frontend/static/login.js deleted file mode 100644 index c61a7f4..0000000 --- a/relay/frontend/static/login.js +++ /dev/null @@ -1,42 +0,0 @@ -const fields = { - username: document.querySelector("#username"), - password: document.querySelector("#password") -} - -async function login(event) { - const values = { - username: fields.username.value.trim(), - password: fields.password.value.trim() - } - - if (values.username === "" | values.password === "") { - toast("Username and/or password field is blank"); - return; - } - - try { - await request("POST", "v1/token", values); - - } catch (error) { - toast(error); - return; - } - - document.location = "/"; -} - - -document.querySelector("#username").addEventListener("keydown", async (event) => { - if (event.which === 13) { - fields.password.focus(); - fields.password.select(); - } -}); - -document.querySelector("#password").addEventListener("keydown", async (event) => { - if (event.which === 13) { - await login(event); - } -}); - -document.querySelector(".submit").addEventListener("click", login); diff --git a/relay/frontend/static/software_ban.js b/relay/frontend/static/software_ban.js deleted file mode 100644 index bb54dbe..0000000 --- a/relay/frontend/static/software_ban.js +++ /dev/null @@ -1,138 +0,0 @@ -function create_ban_object(name, reason, note) { - var text = '
\n'; - text += `${name}\n`; - text += '
\n'; - text += `\n`; - text += `\n`; - text += `\n`; - text += `\n`; - text += ``; - text += '
'; - - return text; -} - - -function add_row_listeners(row) { - row.querySelector(".update-ban").addEventListener("click", async (event) => { - await update_ban(row.id); - }); - - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await unban(row.id); - }); -} - - -async function ban() { - var elems = { - name: document.getElementById("new-name"), - reason: document.getElementById("new-reason"), - note: document.getElementById("new-note") - } - - var values = { - name: elems.name.value.trim(), - reason: elems.reason.value, - note: elems.note.value - } - - if (values.name === "") { - toast("Domain is required"); - return; - } - - try { - var ban = await request("POST", "v1/software_ban", values); - - } catch (err) { - toast(err); - return - } - - var row = append_table_row(document.getElementById("bans"), ban.name, { - name: create_ban_object(ban.name, ban.reason, ban.note), - date: get_date_string(ban.created), - remove: `` - }); - - add_row_listeners(row); - - elems.name.value = null; - elems.reason.value = null; - elems.note.value = null; - - document.querySelector("details.section").open = false; - toast("Banned software", "message"); -} - - -async function update_ban(name) { - var row = document.getElementById(name); - - var elems = { - "reason": row.querySelector("textarea.reason"), - "note": row.querySelector("textarea.note") - } - - var values = { - "name": name, - "reason": elems.reason.value, - "note": elems.note.value - } - - try { - await request("PATCH", "v1/software_ban", values) - - } catch (error) { - toast(error); - return; - } - - row.querySelector("details").open = false; - toast("Updated software ban", "message"); -} - - -async function unban(name) { - try { - await request("DELETE", "v1/software_ban", {"name": name}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(name).remove(); - toast("Unbanned software", "message"); -} - - -document.querySelector("#new-ban").addEventListener("click", async (event) => { - await ban(); -}); - -for (var elem of document.querySelectorAll("#add-item input")) { - elem.addEventListener("keydown", async (event) => { - if (event.which === 13) { - await ban(); - } - }); -} - -for (var elem of document.querySelectorAll("#add-item textarea")) { - elem.addEventListener("keydown", async (event) => { - if (event.which === 13 && event.ctrlKey) { - await ban(); - } - }); -} - -for (var row of document.querySelector("#bans").rows) { - if (!row.querySelector(".update-ban")) { - continue; - } - - add_row_listeners(row); -} diff --git a/relay/frontend/static/user.js b/relay/frontend/static/user.js deleted file mode 100644 index 5bc16ec..0000000 --- a/relay/frontend/static/user.js +++ /dev/null @@ -1,93 +0,0 @@ -function add_row_listeners(row) { - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await del_user(row.id); - }); -} - - -async function add_user() { - var elems = { - username: document.getElementById("new-username"), - password: document.getElementById("new-password"), - password2: document.getElementById("new-password2"), - handle: document.getElementById("new-handle") - } - - var values = { - username: elems.username.value.trim(), - password: elems.password.value.trim(), - password2: elems.password2.value.trim(), - handle: elems.handle.value.trim() - } - - if (values.username === "" | values.password === "" | values.password2 === "") { - toast("Username, password, and password2 are required"); - return; - } - - if (values.password !== values.password2) { - toast("Passwords do not match"); - return; - } - - try { - var user = await request("POST", "v1/user", values); - - } catch (err) { - toast(err); - return - } - - var row = append_table_row(document.querySelector("fieldset.section table"), user.username, { - domain: user.username, - handle: user.handle ? self.handle : "n/a", - date: get_date_string(user.created), - remove: `` - }); - - add_row_listeners(row); - - elems.username.value = null; - elems.password.value = null; - elems.password2.value = null; - elems.handle.value = null; - - document.querySelector("details.section").open = false; - toast("Created user", "message"); -} - - -async function del_user(username) { - try { - await request("DELETE", "v1/user", {"username": username}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(username).remove(); - toast("Deleted user", "message"); -} - - -document.querySelector("#new-user").addEventListener("click", async (event) => { - await add_user(); -}); - -for (var elem of document.querySelectorAll("#add-item input")) { - elem.addEventListener("keydown", async (event) => { - if (event.which === 13) { - await add_user(); - } - }); -} - -for (var row of document.querySelector("#users").rows) { - if (!row.querySelector(".remove a")) { - continue; - } - - add_row_listeners(row); -} diff --git a/relay/frontend/static/whitelist.js b/relay/frontend/static/whitelist.js deleted file mode 100644 index c2b31e4..0000000 --- a/relay/frontend/static/whitelist.js +++ /dev/null @@ -1,70 +0,0 @@ -function add_row_listeners(row) { - row.querySelector(".remove a").addEventListener("click", async (event) => { - event.preventDefault(); - await del_whitelist(row.id); - }); -} - - -async function add_whitelist() { - var domain_elem = document.getElementById("new-domain"); - var domain = domain_elem.value.trim(); - - if (domain === "") { - toast("Domain is required"); - return; - } - - try { - var item = await request("POST", "v1/whitelist", {"domain": domain}); - - } catch (err) { - toast(err); - return; - } - - var row = append_table_row(document.getElementById("whitelist"), item.domain, { - domain: item.domain, - date: get_date_string(item.created), - remove: `` - }); - - add_row_listeners(row); - - domain_elem.value = null; - document.querySelector("details.section").open = false; - toast("Added domain", "message"); -} - - -async function del_whitelist(domain) { - try { - await request("DELETE", "v1/whitelist", {"domain": domain}); - - } catch (error) { - toast(error); - return; - } - - document.getElementById(domain).remove(); - toast("Removed domain", "message"); -} - - -document.querySelector("#new-item").addEventListener("click", async (event) => { - await add_whitelist(); -}); - -document.querySelector("#add-item").addEventListener("keydown", async (event) => { - if (event.which === 13) { - await add_whitelist(); - } -}); - -for (var row of document.querySelector("fieldset.section table").rows) { - if (!row.querySelector(".remove a")) { - continue; - } - - add_row_listeners(row); -} From 1d72f2a25420c9d6bc92cecc0652b2f46af827bd Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 25 Jun 2024 16:29:58 -0400 Subject: [PATCH 08/21] fix ValueError when adding new instance via api --- relay/views/api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/relay/views/api.py b/relay/views/api.py index 3bdc822..7511851 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -226,8 +226,9 @@ class Inbox(View): pass row = conn.put_inbox( - data['domain'], - actor = data.get('actor'), + domain = data['domain'], + actor = data['actor'], + inbox = data.get('inbox'), software = data.get('software'), followid = data.get('followid') ) From 5765753b59b70ac8e1625cd9d771e49059af2adb Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 25 Jun 2024 16:53:38 -0400 Subject: [PATCH 09/21] modify dependencies --- pyproject.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a3c9410..2207bf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,19 +16,19 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "activitypub-utils == 0.3.1", + "activitypub-utils >= 0.3.1, < 0.4.0", "aiohttp >= 3.9.5", "aiohttp-swagger[performance] == 1.0.16", "argon2-cffi == 23.1.0", - "barkshark-lib >= 0.2.0-rc1", - "barkshark-sql == 0.1.4-1", - "click >= 8.1.2", + "barkshark-lib >= 0.1.4, < 0.2.0", + "barkshark-sql >= 0.2.0-rc1, < 0.3.0", + "click == 8.1.2", "hiredis == 2.3.2", "idna == 3.4", "jinja2-haml == 0.3.5", "markdown == 3.6", "platformdirs == 4.2.2", - "pyyaml >= 6.0", + "pyyaml == 6.0", "redis == 5.0.5", "importlib-resources == 6.4.0; python_version < '3.9'" ] @@ -53,7 +53,7 @@ dev = [ "mypy == 1.10.0", "pyinstaller == 6.8.0", "watchdog == 4.0.1", - "typing-extensions >= 4.12.2; python_version < '3.11.0'" + "typing-extensions == 4.12.2; python_version < '3.11.0'" ] [tool.setuptools] From 5217516c8a5c6b00711a4784eea4caf26801de4d Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 25 Jun 2024 16:54:04 -0400 Subject: [PATCH 10/21] move Self imports to typing block --- relay/config.py | 13 ++++++++----- relay/database/config.py | 11 ++++++----- relay/logger.py | 13 ++++++++----- relay/misc.py | 12 ++++++------ relay/views/base.py | 8 ++++---- 5 files changed, 32 insertions(+), 25 deletions(-) diff --git a/relay/config.py b/relay/config.py index 7e86ef7..eccc1ab 100644 --- a/relay/config.py +++ b/relay/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import getpass import os import platform @@ -6,15 +8,16 @@ import yaml from dataclasses import asdict, dataclass, fields from pathlib import Path from platformdirs import user_config_dir -from typing import Any +from typing import TYPE_CHECKING, Any from .misc import IS_DOCKER -try: - from typing import Self +if TYPE_CHECKING: + try: + from typing import Self -except ImportError: - from typing_extensions import Self + except ImportError: + from typing_extensions import Self if platform.system() == 'Windows': diff --git a/relay/database/config.py b/relay/database/config.py index 6effbb9..2be3ecc 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -5,16 +5,17 @@ from __future__ import annotations from bsql import Row from collections.abc import Callable, Sequence from dataclasses import Field, asdict, dataclass, fields -from typing import Any +from typing import TYPE_CHECKING, Any from .. import logger as logging from ..misc import boolean -try: - from typing import Self +if TYPE_CHECKING: + try: + from typing import Self -except ImportError: - from typing_extensions import Self + except ImportError: + from typing_extensions import Self THEMES = { diff --git a/relay/logger.py b/relay/logger.py index f1a1bd7..f4ef1f7 100644 --- a/relay/logger.py +++ b/relay/logger.py @@ -1,15 +1,18 @@ +from __future__ import annotations + import logging import os from enum import IntEnum from pathlib import Path -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol -try: - from typing import Self +if TYPE_CHECKING: + try: + from typing import Self -except ImportError: - from typing_extensions import Self + except ImportError: + from typing_extensions import Self class LoggingMethod(Protocol): diff --git a/relay/misc.py b/relay/misc.py index 9e8f035..37764e7 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -19,15 +19,15 @@ try: except ImportError: from importlib_resources import files as pkgfiles # type: ignore -try: - from typing import Self - -except ImportError: - from typing_extensions import Self - if TYPE_CHECKING: from .application import Application + try: + from typing import Self + + except ImportError: + from typing_extensions import Self + T = TypeVar('T') ResponseType = TypedDict('ResponseType', { diff --git a/relay/views/base.py b/relay/views/base.py index 350016c..64f792e 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -21,11 +21,11 @@ if TYPE_CHECKING: from ..application import Application from ..template import Template -try: - from typing import Self + try: + from typing import Self -except ImportError: - from typing_extensions import Self + except ImportError: + from typing_extensions import Self HandlerCallback = Callable[[Request], Awaitable[Response]] From e8b3a210a9c3447aa22b117c8038fcd12da8a80d Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 2 Jul 2024 23:51:37 -0400 Subject: [PATCH 11/21] forgot to add functions.js --- relay/frontend/static/functions.js | 862 +++++++++++++++++++++++++++++ 1 file changed, 862 insertions(+) create mode 100644 relay/frontend/static/functions.js diff --git a/relay/frontend/static/functions.js b/relay/frontend/static/functions.js new file mode 100644 index 0000000..b0e4db5 --- /dev/null +++ b/relay/frontend/static/functions.js @@ -0,0 +1,862 @@ +// toast notifications + +const notifications = document.querySelector("#notifications") + + +function remove_toast(toast) { + toast.classList.add("hide"); + + if (toast.timeoutId) { + clearTimeout(toast.timeoutId); + } + + setTimeout(() => toast.remove(), 300); +} + +function toast(text, type="error", timeout=5) { + const toast = document.createElement("li"); + toast.className = `section ${type}` + toast.innerHTML = `${text}✖` + + toast.querySelector("a").addEventListener("click", async (event) => { + event.preventDefault(); + await remove_toast(toast); + }); + + notifications.appendChild(toast); + toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000); +} + + +// menu + +const body = document.getElementById("container") +const menu = document.getElementById("menu"); +const menu_open = document.querySelector("#menu-open i"); +const menu_close = document.getElementById("menu-close"); + + +function toggle_menu() { + let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true"; + menu.attributes.visible.nodeValue = new_value; +} + + +menu_open.addEventListener("click", toggle_menu); +menu_close.addEventListener("click", toggle_menu); + +body.addEventListener("click", (event) => { + if (event.target === menu_open) { + return; + } + + menu.attributes.visible.nodeValue = "false"; +}); + +for (const elem of document.querySelectorAll("#menu-open div")) { + elem.addEventListener("click", toggle_menu); +} + + +// misc + +function get_date_string(date) { + var year = date.getUTCFullYear().toString(); + var month = (date.getUTCMonth() + 1).toString().padStart(2, "0"); + var day = date.getUTCDate().toString().padStart(2, "0"); + + return `${year}-${month}-${day}`; +} + + +function append_table_row(table, row_name, row) { + var table_row = table.insertRow(-1); + table_row.id = row_name; + + index = 0; + + for (var prop in row) { + if (Object.prototype.hasOwnProperty.call(row, prop)) { + var cell = table_row.insertCell(index); + cell.className = prop; + cell.innerHTML = row[prop]; + + index += 1; + } + } + + return table_row; +} + + +async function request(method, path, body = null) { + var headers = { + "Accept": "application/json" + } + + if (body !== null) { + headers["Content-Type"] = "application/json" + body = JSON.stringify(body) + } + + const response = await fetch("/api/" + path, { + method: method, + mode: "cors", + cache: "no-store", + redirect: "follow", + body: body, + headers: headers + }); + + const message = await response.json(); + + if (Object.hasOwn(message, "error")) { + throw new Error(message.error); + } + + if (Array.isArray(message)) { + message.forEach((msg) => { + if (Object.hasOwn(msg, "created")) { + msg.created = new Date(msg.created); + } + }); + + } else { + if (Object.hasOwn(message, "created")) { + message.created = new Date(message.created); + } + } + + return message; +} + +// page functions + +function page_config() { + const elems = [ + document.querySelector("#name"), + document.querySelector("#note"), + document.querySelector("#theme"), + document.querySelector("#log-level"), + document.querySelector("#whitelist-enabled"), + document.querySelector("#approval-required") + ] + + + async function handle_config_change(event) { + params = { + key: event.target.id, + value: event.target.type === "checkbox" ? event.target.checked : event.target.value + } + + try { + await request("POST", "v1/config", params); + + } catch (error) { + toast(error); + return; + } + + if (params.key === "name") { + document.querySelector("#header .title").innerHTML = params.value; + document.querySelector("title").innerHTML = params.value; + } + + if (params.key === "theme") { + document.querySelector("link.theme").href = `/theme/${params.value}.css`; + } + + toast("Updated config", "message"); + } + + + document.querySelector("#name").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await handle_config_change(event); + } + }); + + document.querySelector("#note").addEventListener("keydown", async (event) => { + if (event.which === 13 && event.ctrlKey) { + await handle_config_change(event); + } + }); + + for (const elem of elems) { + elem.addEventListener("change", handle_config_change); + } +} + + +function page_domain_ban() { + function create_ban_object(domain, reason, note) { + var text = '
\n'; + text += `${domain}\n`; + text += '
\n'; + text += `\n`; + text += `\n`; + text += `\n`; + text += `\n`; + text += ``; + text += '
'; + + return text; + } + + + function add_row_listeners(row) { + row.querySelector(".update-ban").addEventListener("click", async (event) => { + await update_ban(row.id); + }); + + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await unban(row.id); + }); + } + + + async function ban() { + var table = document.querySelector("table"); + var elems = { + domain: document.getElementById("new-domain"), + reason: document.getElementById("new-reason"), + note: document.getElementById("new-note") + } + + var values = { + domain: elems.domain.value.trim(), + reason: elems.reason.value.trim(), + note: elems.note.value.trim() + } + + if (values.domain === "") { + toast("Domain is required"); + return; + } + + try { + var ban = await request("POST", "v1/domain_ban", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.querySelector("table"), ban.domain, { + domain: create_ban_object(ban.domain, ban.reason, ban.note), + date: get_date_string(ban.created), + remove: `
` + }); + + add_row_listeners(row); + + elems.domain.value = null; + elems.reason.value = null; + elems.note.value = null; + + document.querySelector("details.section").open = false; + toast("Banned domain", "message"); + } + + + async function update_ban(domain) { + var row = document.getElementById(domain); + + var elems = { + "reason": row.querySelector("textarea.reason"), + "note": row.querySelector("textarea.note") + } + + var values = { + "domain": domain, + "reason": elems.reason.value, + "note": elems.note.value + } + + try { + await request("PATCH", "v1/domain_ban", values) + + } catch (error) { + toast(error); + return; + } + + row.querySelector("details").open = false; + toast("Updated baned domain", "message"); + } + + + async function unban(domain) { + try { + await request("DELETE", "v1/domain_ban", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + toast("Unbanned domain", "message"); + } + + + document.querySelector("#new-ban").addEventListener("click", async (event) => { + await ban(); + }); + + for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await ban(); + } + }); + } + + for (var row of document.querySelector("fieldset.section table").rows) { + if (!row.querySelector(".update-ban")) { + continue; + } + + add_row_listeners(row); + } +} + + +function page_instance() { + function add_instance_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_instance(row.id); + }); + } + + + function add_request_listeners(row) { + row.querySelector(".approve a").addEventListener("click", async (event) => { + event.preventDefault(); + await req_response(row.id, true); + }); + + row.querySelector(".deny a").addEventListener("click", async (event) => { + event.preventDefault(); + await req_response(row.id, false); + }); + } + + + async function add_instance() { + var elems = { + actor: document.getElementById("new-actor"), + inbox: document.getElementById("new-inbox"), + followid: document.getElementById("new-followid"), + software: document.getElementById("new-software") + } + + var values = { + actor: elems.actor.value.trim(), + inbox: elems.inbox.value.trim(), + followid: elems.followid.value.trim(), + software: elems.software.value.trim() + } + + if (values.actor === "") { + toast("Actor is required"); + return; + } + + try { + var instance = await request("POST", "v1/instance", values); + + } catch (err) { + toast(err); + return + } + + row = append_table_row(document.getElementById("instances"), instance.domain, { + domain: `${instance.domain}`, + software: instance.software, + date: get_date_string(instance.created), + remove: `` + }); + + add_instance_listeners(row); + + elems.actor.value = null; + elems.inbox.value = null; + elems.followid.value = null; + elems.software.value = null; + + document.querySelector("details.section").open = false; + toast("Added instance", "message"); + } + + + async function del_instance(domain) { + try { + await request("DELETE", "v1/instance", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + } + + + async function req_response(domain, accept) { + params = { + "domain": domain, + "accept": accept + } + + try { + await request("POST", "v1/request", params); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + + if (document.getElementById("requests").rows.length < 2) { + document.querySelector("fieldset.requests").remove() + } + + if (!accept) { + toast("Denied instance request", "message"); + return; + } + + instances = await request("GET", `v1/instance`, null); + instances.forEach((instance) => { + if (instance.domain === domain) { + row = append_table_row(document.getElementById("instances"), instance.domain, { + domain: `${instance.domain}`, + software: instance.software, + date: get_date_string(instance.created), + remove: `` + }); + + add_instance_listeners(row); + } + }); + + toast("Accepted instance request", "message"); + } + + + document.querySelector("#add-instance").addEventListener("click", async (event) => { + await add_instance(); + }) + + for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_instance(); + } + }); + } + + for (var row of document.querySelector("#instances").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_instance_listeners(row); + } + + if (document.querySelector("#requests")) { + for (var row of document.querySelector("#requests").rows) { + if (!row.querySelector(".approve a")) { + continue; + } + + add_request_listeners(row); + } + } +} + + +function page_login() { + const fields = { + username: document.querySelector("#username"), + password: document.querySelector("#password") + } + + async function login(event) { + const values = { + username: fields.username.value.trim(), + password: fields.password.value.trim() + } + + if (values.username === "" | values.password === "") { + toast("Username and/or password field is blank"); + return; + } + + try { + await request("POST", "v1/token", values); + + } catch (error) { + toast(error); + return; + } + + document.location = "/"; + } + + + document.querySelector("#username").addEventListener("keydown", async (event) => { + if (event.which === 13) { + fields.password.focus(); + fields.password.select(); + } + }); + + document.querySelector("#password").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await login(event); + } + }); + + document.querySelector(".submit").addEventListener("click", login); +} + + +function page_software_ban() { + function create_ban_object(name, reason, note) { + var text = '
\n'; + text += `${name}\n`; + text += '
\n'; + text += `\n`; + text += `\n`; + text += `\n`; + text += `\n`; + text += ``; + text += '
'; + + return text; + } + + + function add_row_listeners(row) { + row.querySelector(".update-ban").addEventListener("click", async (event) => { + await update_ban(row.id); + }); + + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await unban(row.id); + }); + } + + + async function ban() { + var elems = { + name: document.getElementById("new-name"), + reason: document.getElementById("new-reason"), + note: document.getElementById("new-note") + } + + var values = { + name: elems.name.value.trim(), + reason: elems.reason.value, + note: elems.note.value + } + + if (values.name === "") { + toast("Domain is required"); + return; + } + + try { + var ban = await request("POST", "v1/software_ban", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.getElementById("bans"), ban.name, { + name: create_ban_object(ban.name, ban.reason, ban.note), + date: get_date_string(ban.created), + remove: `` + }); + + add_row_listeners(row); + + elems.name.value = null; + elems.reason.value = null; + elems.note.value = null; + + document.querySelector("details.section").open = false; + toast("Banned software", "message"); + } + + + async function update_ban(name) { + var row = document.getElementById(name); + + var elems = { + "reason": row.querySelector("textarea.reason"), + "note": row.querySelector("textarea.note") + } + + var values = { + "name": name, + "reason": elems.reason.value, + "note": elems.note.value + } + + try { + await request("PATCH", "v1/software_ban", values) + + } catch (error) { + toast(error); + return; + } + + row.querySelector("details").open = false; + toast("Updated software ban", "message"); + } + + + async function unban(name) { + try { + await request("DELETE", "v1/software_ban", {"name": name}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(name).remove(); + toast("Unbanned software", "message"); + } + + + document.querySelector("#new-ban").addEventListener("click", async (event) => { + await ban(); + }); + + for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await ban(); + } + }); + } + + for (var elem of document.querySelectorAll("#add-item textarea")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13 && event.ctrlKey) { + await ban(); + } + }); + } + + for (var row of document.querySelector("#bans").rows) { + if (!row.querySelector(".update-ban")) { + continue; + } + + add_row_listeners(row); + } +} + + +function page_user() { + function add_row_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_user(row.id); + }); + } + + + async function add_user() { + var elems = { + username: document.getElementById("new-username"), + password: document.getElementById("new-password"), + password2: document.getElementById("new-password2"), + handle: document.getElementById("new-handle") + } + + var values = { + username: elems.username.value.trim(), + password: elems.password.value.trim(), + password2: elems.password2.value.trim(), + handle: elems.handle.value.trim() + } + + if (values.username === "" | values.password === "" | values.password2 === "") { + toast("Username, password, and password2 are required"); + return; + } + + if (values.password !== values.password2) { + toast("Passwords do not match"); + return; + } + + try { + var user = await request("POST", "v1/user", values); + + } catch (err) { + toast(err); + return + } + + var row = append_table_row(document.querySelector("fieldset.section table"), user.username, { + domain: user.username, + handle: user.handle ? self.handle : "n/a", + date: get_date_string(user.created), + remove: `` + }); + + add_row_listeners(row); + + elems.username.value = null; + elems.password.value = null; + elems.password2.value = null; + elems.handle.value = null; + + document.querySelector("details.section").open = false; + toast("Created user", "message"); + } + + + async function del_user(username) { + try { + await request("DELETE", "v1/user", {"username": username}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(username).remove(); + toast("Deleted user", "message"); + } + + + document.querySelector("#new-user").addEventListener("click", async (event) => { + await add_user(); + }); + + for (var elem of document.querySelectorAll("#add-item input")) { + elem.addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_user(); + } + }); + } + + for (var row of document.querySelector("#users").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_row_listeners(row); + } +} + + +function page_whitelist() { + function add_row_listeners(row) { + row.querySelector(".remove a").addEventListener("click", async (event) => { + event.preventDefault(); + await del_whitelist(row.id); + }); + } + + + async function add_whitelist() { + var domain_elem = document.getElementById("new-domain"); + var domain = domain_elem.value.trim(); + + if (domain === "") { + toast("Domain is required"); + return; + } + + try { + var item = await request("POST", "v1/whitelist", {"domain": domain}); + + } catch (err) { + toast(err); + return; + } + + var row = append_table_row(document.getElementById("whitelist"), item.domain, { + domain: item.domain, + date: get_date_string(item.created), + remove: `` + }); + + add_row_listeners(row); + + domain_elem.value = null; + document.querySelector("details.section").open = false; + toast("Added domain", "message"); + } + + + async function del_whitelist(domain) { + try { + await request("DELETE", "v1/whitelist", {"domain": domain}); + + } catch (error) { + toast(error); + return; + } + + document.getElementById(domain).remove(); + toast("Removed domain", "message"); + } + + + document.querySelector("#new-item").addEventListener("click", async (event) => { + await add_whitelist(); + }); + + document.querySelector("#add-item").addEventListener("keydown", async (event) => { + if (event.which === 13) { + await add_whitelist(); + } + }); + + for (var row of document.querySelector("fieldset.section table").rows) { + if (!row.querySelector(".remove a")) { + continue; + } + + add_row_listeners(row); + } +} + + +if (location.pathname.startsWith("/admin/config")) { + page_config(); + +} else if (location.pathname.startsWith("/admin/domain_bans")) { + page_domain_ban(); + +} else if (location.pathname.startsWith("/admin/instances")) { + page_instance(); + +} else if (location.pathname.startsWith("/admin/login")) { + page_login(); + +} else if (location.pathname.startsWith("/admin/software_bans")) { + page_software_ban(); + +} else if (location.pathname.startsWith("/admin/users")) { + page_user(); + +} else if (location.pathname.startsWith("/admin/whitelist")) { + page_whitelist(); +} From b22b5bbefaa1b6cf13deaeb65396b135dc3fb192 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 3 Jul 2024 00:59:59 -0400 Subject: [PATCH 12/21] ensure the relay can run on python >= 3.10 --- dev.py | 49 ++++++++++++++++---------- pyproject.toml | 24 ++++++------- relay/cache.py | 83 +++++++++++++++++++++++++++++++-------------- relay/config.py | 8 ++--- relay/logger.py | 6 +--- relay/misc.py | 14 ++------ relay/template.py | 5 ++- relay/views/api.py | 3 +- relay/views/base.py | 8 +---- relay/workers.py | 18 +++++----- 10 files changed, 119 insertions(+), 99 deletions(-) diff --git a/dev.py b/dev.py index 38499d8..114073f 100755 --- a/dev.py +++ b/dev.py @@ -1,25 +1,38 @@ #!/usr/bin/env python3 -import click import platform import shutil import subprocess import sys import time -import tomllib from datetime import datetime, timedelta +from importlib.util import find_spec from pathlib import Path -from relay import __version__, logger as logging from tempfile import TemporaryDirectory from typing import Any, Sequence try: - from watchdog.observers import Observer - from watchdog.events import FileSystemEvent, PatternMatchingEventHandler + import tomllib except ImportError: - class PatternMatchingEventHandler: # type: ignore - pass + if find_spec("toml") is None: + subprocess.run([sys.executable, "-m", "pip", "install", "toml"]) + + import toml as tomllib # type: ignore[no-redef] + +if None in [find_spec("click"), find_spec("watchdog")]: + CMD = [sys.executable, "-m", "pip", "install", "click >= 8.1.0", "watchdog >= 4.0.0"] + PROC = subprocess.run(CMD, check = False) + + if PROC.returncode != 0: + sys.exit() + + print("Successfully installed dependencies") + +import click + +from watchdog.observers import Observer +from watchdog.events import FileSystemEvent, PatternMatchingEventHandler REPO = Path(__file__).parent @@ -37,13 +50,11 @@ def cli() -> None: @cli.command('install') @click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies') def cli_install(no_dev: bool) -> None: - with open('pyproject.toml', 'rb') as fd: - data = tomllib.load(fd) + with open('pyproject.toml', 'r', encoding = 'utf-8') as fd: + data = tomllib.loads(fd.read()) deps = data['project']['dependencies'] - - if not no_dev: - deps.extend(data['project']['optional-dependencies']['dev']) + deps.extend(data['project']['optional-dependencies']['dev']) subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False) @@ -60,7 +71,7 @@ def cli_lint(path: Path, watch: bool) -> None: return flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] - mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)] + mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)] click.echo('----- flake8 -----') subprocess.run(flake8) @@ -89,6 +100,8 @@ def cli_clean() -> None: @cli.command('build') def cli_build() -> None: + from relay import __version__ + with TemporaryDirectory() as tmp: arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' cmd = [ @@ -171,7 +184,7 @@ class WatchHandler(PatternMatchingEventHandler): if proc.poll() is not None: continue - logging.info(f'Terminating process {proc.pid}') + print(f'Terminating process {proc.pid}') proc.terminate() sec = 0.0 @@ -180,11 +193,11 @@ class WatchHandler(PatternMatchingEventHandler): sec += 0.1 if sec >= 5: - logging.error('Failed to terminate. Killing process...') + print('Failed to terminate. Killing process...') proc.kill() break - logging.info('Process terminated') + print('Process terminated') def run_procs(self, restart: bool = False) -> None: @@ -200,13 +213,13 @@ class WatchHandler(PatternMatchingEventHandler): self.procs = [] for cmd in self.commands: - logging.info('Running command: %s', ' '.join(cmd)) + print('Running command:', ' '.join(cmd)) subprocess.run(cmd) else: self.procs = list(subprocess.Popen(cmd) for cmd in self.commands) pids = (str(proc.pid) for proc in self.procs) - logging.info('Started processes with PIDs: %s', ', '.join(pids)) + print('Started processes with PIDs:', ', '.join(pids)) def on_any_event(self, event: FileSystemEvent) -> None: diff --git a/pyproject.toml b/pyproject.toml index 2207bf7..b1249d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,30 +9,27 @@ license = {text = "AGPLv3"} classifiers = [ "Environment :: Console", "License :: OSI Approved :: GNU Affero General Public License v3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.12" ] dependencies = [ - "activitypub-utils >= 0.3.1, < 0.4.0", + "activitypub-utils >= 0.3.1.post1, < 0.4.0", "aiohttp >= 3.9.5", "aiohttp-swagger[performance] == 1.0.16", "argon2-cffi == 23.1.0", - "barkshark-lib >= 0.1.4, < 0.2.0", - "barkshark-sql >= 0.2.0-rc1, < 0.3.0", + "barkshark-lib >= 0.1.5rc1, < 0.2.0", + "barkshark-sql >= 0.2.0rc2, < 0.3.0", "click == 8.1.2", "hiredis == 2.3.2", "idna == 3.4", "jinja2-haml == 0.3.5", "markdown == 3.6", "platformdirs == 4.2.2", - "pyyaml == 6.0", - "redis == 5.0.5", - "importlib-resources == 6.4.0; python_version < '3.9'" + "pyyaml == 6.0.1", + "redis == 5.0.7" ] -requires-python = ">=3.8" +requires-python = ">=3.10" dynamic = ["version"] [project.readme] @@ -49,11 +46,10 @@ activityrelay = "relay.manage:main" [project.optional-dependencies] dev = [ - "flake8 == 7.0.0", - "mypy == 1.10.0", + "flake8 == 7.1.0", + "mypy == 1.10.1", "pyinstaller == 6.8.0", - "watchdog == 4.0.1", - "typing-extensions == 4.12.2; python_version < '3.11.0'" + "watchdog == 4.0.1" ] [tool.setuptools] diff --git a/relay/cache.py b/relay/cache.py index da87cc5..1bf20d7 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -4,12 +4,13 @@ import json import os from abc import ABC, abstractmethod +from blib import Date from bsql import Database, Row from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass -from datetime import datetime, timedelta, timezone +from datetime import timedelta from redis import Redis -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict from .database import Connection, get_database from .misc import Message, boolean @@ -31,6 +32,14 @@ CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = { } +class RedisConnectType(TypedDict): + client_name: str + decode_responses: bool + username: str | None + password: str | None + db: int + + def get_cache(app: Application) -> Cache: return BACKENDS[app.config.ca_type](app) @@ -57,12 +66,11 @@ class Item: key: str value: Any value_type: str - updated: datetime + updated: Date def __post_init__(self) -> None: - if isinstance(self.updated, str): # type: ignore[unreachable] - self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable] + self.updated = Date.parse(self.updated) @classmethod @@ -70,14 +78,11 @@ class Item: data = cls(*args) data.value = deserialize_value(data.value, data.value_type) - if not isinstance(data.updated, datetime): - data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore - return data def older_than(self, hours: int) -> bool: - delta = datetime.now(tz = timezone.utc) - self.updated + delta = Date.new_utc() - self.updated return (delta.total_seconds()) > hours * 3600 @@ -206,7 +211,7 @@ class SqlCache(Cache): 'key': key, 'value': serialize_value(value, value_type), 'type': value_type, - 'date': datetime.now(tz = timezone.utc) + 'date': Date.new_utc() } with self._db.session(True) as conn: @@ -236,7 +241,7 @@ class SqlCache(Cache): if self._db is None: raise RuntimeError("Database has not been setup") - limit = datetime.now(tz = timezone.utc) - timedelta(days = days) + limit = Date.new_utc() - timedelta(days = days) params = {"limit": limit.timestamp()} with self._db.session(True) as conn: @@ -280,7 +285,7 @@ class RedisCache(Cache): def __init__(self, app: Application): Cache.__init__(self, app) - self._rd: Redis = None # type: ignore + self._rd: Redis | None = None @property @@ -293,28 +298,38 @@ class RedisCache(Cache): def get(self, namespace: str, key: str) -> Item: + if self._rd is None: + raise ConnectionError("Not connected") + key_name = self.get_key_name(namespace, key) if not (raw_value := self._rd.get(key_name)): raise KeyError(f'{namespace}:{key}') - value_type, updated, value = raw_value.split(':', 2) # type: ignore + value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr] + return Item.from_data( namespace, key, value, value_type, - datetime.fromtimestamp(float(updated), tz = timezone.utc) + Date.parse(float(updated)) ) def get_keys(self, namespace: str) -> Iterator[str]: + if self._rd is None: + raise ConnectionError("Not connected") + for key in self._rd.scan_iter(self.get_key_name(namespace, '*')): *_, key_name = key.split(':', 2) yield key_name def get_namespaces(self) -> Iterator[str]: + if self._rd is None: + raise ConnectionError("Not connected") + namespaces = [] for key in self._rd.scan_iter(f'{self.prefix}:*'): @@ -326,7 +341,10 @@ class RedisCache(Cache): def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: - date = datetime.now(tz = timezone.utc).timestamp() + if self._rd is None: + raise ConnectionError("Not connected") + + date = Date.new_utc().timestamp() value = serialize_value(value, value_type) self._rd.set( @@ -338,11 +356,17 @@ class RedisCache(Cache): def delete(self, namespace: str, key: str) -> None: + if self._rd is None: + raise ConnectionError("Not connected") + self._rd.delete(self.get_key_name(namespace, key)) def delete_old(self, days: int = 14) -> None: - limit = datetime.now(tz = timezone.utc) - timedelta(days = days) + if self._rd is None: + raise ConnectionError("Not connected") + + limit = Date.new_utc() - timedelta(days = days) for full_key in self._rd.scan_iter(f'{self.prefix}:*'): _, namespace, key = full_key.split(':', 2) @@ -353,14 +377,17 @@ class RedisCache(Cache): def clear(self) -> None: + if self._rd is None: + raise ConnectionError("Not connected") + self._rd.delete(f"{self.prefix}:*") def setup(self) -> None: - if self._rd: + if self._rd is not None: return - options = { + options: RedisConnectType = { 'client_name': f'ActivityRelay_{self.app.config.domain}', 'decode_responses': True, 'username': self.app.config.rd_user, @@ -369,18 +396,22 @@ class RedisCache(Cache): } if os.path.exists(self.app.config.rd_host): - options['unix_socket_path'] = self.app.config.rd_host + self._rd = Redis( + unix_socket_path = self.app.config.rd_host, + **options + ) + return - else: - options['host'] = self.app.config.rd_host - options['port'] = self.app.config.rd_port - - self._rd = Redis(**options) # type: ignore + self._rd = Redis( + host = self.app.config.rd_host, + port = self.app.config.rd_port, + **options + ) def close(self) -> None: if not self._rd: return - self._rd.close() # type: ignore - self._rd = None # type: ignore + self._rd.close() # type: ignore[no-untyped-call] + self._rd = None diff --git a/relay/config.py b/relay/config.py index eccc1ab..5805f51 100644 --- a/relay/config.py +++ b/relay/config.py @@ -13,11 +13,7 @@ from typing import TYPE_CHECKING, Any from .misc import IS_DOCKER if TYPE_CHECKING: - try: - from typing import Self - - except ImportError: - from typing_extensions import Self + from typing import Self if platform.system() == 'Windows': @@ -84,7 +80,7 @@ class Config: def DEFAULT(cls: type[Self], key: str) -> str | int | None: for field in fields(cls): if field.name == key: - return field.default # type: ignore + return field.default # type: ignore[return-value] raise KeyError(key) diff --git a/relay/logger.py b/relay/logger.py index f4ef1f7..7caac9f 100644 --- a/relay/logger.py +++ b/relay/logger.py @@ -8,11 +8,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Protocol if TYPE_CHECKING: - try: - from typing import Self - - except ImportError: - from typing_extensions import Self + from typing import Self class LoggingMethod(Protocol): diff --git a/relay/misc.py b/relay/misc.py index 37764e7..6995bc4 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -9,25 +9,15 @@ import socket from aiohttp.web import Response as AiohttpResponse from collections.abc import Sequence from datetime import datetime +from importlib.resources import files as pkgfiles from pathlib import Path from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from uuid import uuid4 -try: - from importlib.resources import files as pkgfiles - -except ImportError: - from importlib_resources import files as pkgfiles # type: ignore - if TYPE_CHECKING: + from typing import Self from .application import Application - try: - from typing import Self - - except ImportError: - from typing_extensions import Self - T = TypeVar('T') ResponseType = TypedDict('ResponseType', { diff --git a/relay/template.py b/relay/template.py index ef25f92..7e3f657 100644 --- a/relay/template.py +++ b/relay/template.py @@ -20,6 +20,9 @@ if TYPE_CHECKING: class Template(Environment): + _render_markdown: Callable[[str], str] + + def __init__(self, app: Application): Environment.__init__(self, autoescape = True, @@ -56,7 +59,7 @@ class Template(Environment): def render_markdown(self, text: str) -> str: - return self._render_markdown(text) # type: ignore + return self._render_markdown(text) class MarkdownExtension(Extension): diff --git a/relay/views/api.py b/relay/views/api.py index 7511851..73b6a16 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -13,7 +13,8 @@ from ..database import ConfigData from ..misc import Message, Response, boolean, get_app -ALLOWED_HEADERS = { +DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' +ALLOWED_HEADERS: set[str] = { 'accept', 'authorization', 'content-type' diff --git a/relay/views/base.py b/relay/views/base.py index 64f792e..e102896 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -18,18 +18,12 @@ from ..http_client import HttpClient from ..misc import Response, get_app if TYPE_CHECKING: + from typing import Self from ..application import Application from ..template import Template - try: - from typing import Self - - except ImportError: - from typing_extensions import Self HandlerCallback = Callable[[Request], Awaitable[Response]] - - VIEWS: list[tuple[str, type[View]]] = [] diff --git a/relay/workers.py b/relay/workers.py index 4b57409..9fee7b1 100644 --- a/relay/workers.py +++ b/relay/workers.py @@ -1,14 +1,17 @@ +from __future__ import annotations + import asyncio import traceback -import typing from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from asyncio.exceptions import TimeoutError as AsyncTimeoutError from dataclasses import dataclass from multiprocessing import Event, Process, Queue, Value +from multiprocessing.queues import Queue as QueueType +from multiprocessing.sharedctypes import Synchronized from multiprocessing.synchronize import Event as EventType from pathlib import Path -from queue import Empty, Queue as QueueType +from queue import Empty from urllib.parse import urlparse from . import application, logger as logging @@ -16,9 +19,6 @@ from .database.schema import Instance from .http_client import HttpClient from .misc import IS_WINDOWS, Message, get_app -if typing.TYPE_CHECKING: - from .multiprocessing.synchronize import Syncronized - @dataclass class QueueItem: @@ -40,13 +40,13 @@ class PushWorker(Process): client: HttpClient - def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None: + def __init__(self, queue: QueueType[QueueItem], log_level: Synchronized[int]) -> None: Process.__init__(self) self.queue: QueueType[QueueItem] = queue self.shutdown: EventType = Event() self.path: Path = get_app().config.path - self.log_level: "Syncronized[str]" = log_level + self.log_level: Synchronized[int] = log_level self._log_level_changed: EventType = Event() @@ -113,8 +113,8 @@ class PushWorker(Process): class PushWorkers(list[PushWorker]): def __init__(self, count: int) -> None: - self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment] - self._log_level: "Syncronized[str]" = Value("i", logging.get_level()) + self.queue: QueueType[QueueItem] = Queue() + self._log_level: Synchronized[int] = Value("i", logging.get_level()) self._count: int = count From f98ca54ab7c80e0a5ca0a96ae158f4c1258dd402 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Thu, 4 Jul 2024 20:36:04 -0400 Subject: [PATCH 13/21] various changes * Add oauth login support * Add `HttpError` class * Add custom error handling * Use `blib.Date` class for (de)serializing db timestamp values * Add `db-maintenance` command * Rework middleware route checking * Fix fetching post data in api endpoints --- relay/application.py | 69 ++++- relay/data/statements.sql | 32 +++ relay/database/config.py | 12 +- relay/database/connection.py | 133 +++++++++- relay/database/schema.py | 109 +++++++- relay/frontend/base.haml | 18 +- relay/frontend/page/authorize_new.haml | 31 +++ relay/frontend/page/authorize_show.haml | 18 ++ relay/frontend/page/error.haml | 7 + relay/frontend/page/login.haml | 2 + relay/frontend/static/functions.js | 18 +- relay/frontend/static/style.css | 38 +++ relay/manage.py | 25 +- relay/misc.py | 37 ++- relay/template.py | 6 +- relay/views/activitypub.py | 25 +- relay/views/api.py | 318 +++++++++++++++--------- relay/views/base.py | 22 +- relay/views/frontend.py | 59 ++--- 19 files changed, 748 insertions(+), 231 deletions(-) create mode 100644 relay/frontend/page/authorize_new.haml create mode 100644 relay/frontend/page/authorize_show.haml create mode 100644 relay/frontend/page/error.haml diff --git a/relay/application.py b/relay/application.py index d852f29..6ab481b 100644 --- a/relay/application.py +++ b/relay/application.py @@ -4,11 +4,14 @@ import asyncio import multiprocessing import signal import time +import traceback +from Crypto.Random import get_random_bytes from aiohttp import web -from aiohttp.web import StaticResource +from aiohttp.web import HTTPException, StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer +from base64 import b64encode from bsql import Database from collections.abc import Awaitable, Callable from datetime import datetime, timedelta @@ -23,7 +26,8 @@ from .config import Config from .database import Connection, get_database from .database.schema import Instance from .http_client import HttpClient -from .misc import Message, Response, check_open_port, get_resource +from .misc import HttpError, Message, Response, check_open_port, get_resource +from .misc import JSON_PATHS, TOKEN_PATHS from .template import Template from .views import VIEWS from .views.api import handle_api_path @@ -53,9 +57,9 @@ class Application(web.Application): def __init__(self, cfgpath: Path | None, dev: bool = False): web.Application.__init__(self, middlewares = [ - handle_api_path, # type: ignore[list-item] + handle_response_headers, # type: ignore[list-item] handle_frontend_path, # type: ignore[list-item] - handle_response_headers # type: ignore[list-item] + handle_api_path # type: ignore[list-item] ] ) @@ -282,19 +286,70 @@ class CacheCleanupThread(Thread): self.running.clear() +def format_error(request: web.Request, error: HttpError) -> Response: + app: Application = request.app # type: ignore[assignment] + + if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''): + return Response.new({'error': error.body}, error.status, ctype = 'json') + + else: + body = app.template.render('page/error.haml', request, e = error) + return Response.new(body, error.status, ctype = 'html') + + @web.middleware async def handle_response_headers( request: web.Request, handler: Callable[[web.Request], Awaitable[Response]]) -> Response: - resp = await handler(request) + request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') + request['token'] = None + request['user'] = None + + app: Application = request.app # type: ignore[assignment] + + if request.path == "/" or request.path.startswith(TOKEN_PATHS): + with app.database.session() as conn: + if (token := request.headers.get('Authorization')) is not None: + token = token.replace('Bearer', '').strip() + + request['token'] = conn.get_app_by_token(token) + request['user'] = conn.get_user_by_app_token(token) + + elif (token := request.cookies.get('user-token')) is not None: + request['token'] = conn.get_token(token) + request['user'] = conn.get_user_by_token(token) + + try: + resp = await handler(request) + + except HttpError as e: + resp = format_error(request, e) + + except HTTPException as ae: + if ae.status == 404: + try: + text = (ae.text or "").split(":")[1].strip() + + except IndexError: + text = ae.text or "" + + resp = format_error(request, HttpError(ae.status, text)) + + else: + raise + + except Exception as e: + resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}')) + traceback.print_exc() + resp.headers['Server'] = 'ActivityRelay' # Still have to figure out how csp headers work - if resp.content_type == 'text/html' and not request.path.startswith("/api"): + if resp.content_type == 'text/html': resp.headers['Content-Security-Policy'] = get_csp(request) - if not request.app['dev'] and request.path.endswith(('.css', '.js')): + if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')): # cache for 2 weeks resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' diff --git a/relay/data/statements.sql b/relay/data/statements.sql index f06d4b5..e8694ae 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -56,6 +56,14 @@ WHERE username = ( ); +-- name: get-user-by-app-token +SELECT * FROM users +WHERE username = ( + SELECT user FROM app + WHERE code = :code +); + + -- name: put-user INSERT INTO users (username, hash, handle, created) VALUES (:username, :hash, :handle, :created) @@ -67,6 +75,30 @@ DELETE FROM users WHERE username = :value or handle = :value; +-- name: get-app +SELECT * FROM app +WHERE client_id = :id and client_secret = :secret; + + +-- name: get-app-token +SELECT * FROM app +WHERE client_id = :id and client_secret = :secret and token = :token; + + +-- name: get-app-by-token +SELECT * FROM app +WHERE token = :token; + +-- name: del-app +DELETE FROM users +WHERE client_id = :id and client_secret = :secret; + + +-- name: del-app-token +DELETE FROM users +WHERE client_id = :id and client_secret = :secret and token = :token; + + -- name: get-token SELECT * FROM tokens WHERE code = :code; diff --git a/relay/database/config.py b/relay/database/config.py index 2be3ecc..3f3c7e0 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -11,11 +11,7 @@ from .. import logger as logging from ..misc import boolean if TYPE_CHECKING: - try: - from typing import Self - - except ImportError: - from typing_extensions import Self + from typing import Self THEMES = { @@ -77,7 +73,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { @dataclass() class ConfigData: - schema_version: int = 20240310 + schema_version: int = 20240625 private_key: str = '' approval_required: bool = False log_level: logging.LogLevel = logging.LogLevel.INFO @@ -115,11 +111,11 @@ class ConfigData: @classmethod def DEFAULT(cls: type[Self], key: str) -> str | int | bool: - return cls.FIELD(key.replace('-', '_')).default # type: ignore + return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value] @classmethod - def FIELD(cls: type[Self], key: str) -> Field[Any]: + def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]: for field in fields(cls): if field.name == key.replace('-', '_'): return field diff --git a/relay/database/connection.py b/relay/database/connection.py index 006a907..3c973b8 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -1,6 +1,9 @@ from __future__ import annotations +import secrets + from argon2 import PasswordHasher +from blib import Date from bsql import Connection as SqlConnection, Row, Update from collections.abc import Iterator from datetime import datetime, timezone @@ -49,6 +52,40 @@ class Connection(SqlConnection): yield instance + def fix_timestamps(self) -> None: + for app in self.select('apps').all(schema.App): + data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()} + self.update('apps', data, client_id = app.client_id) + + for item in self.select('cache'): + data = {'updated': Date.parse(item['updated']).timestamp()} + self.update('cache', data, id = item['id']) + + for dban in self.select('domain_bans').all(schema.DomainBan): + data = {'created': dban.created.timestamp()} + self.update('domain_bans', data, domain = dban.domain) + + for instance in self.select('inboxes').all(schema.Instance): + data = {'created': instance.created.timestamp()} + self.update('inboxes', data, domain = instance.domain) + + for sban in self.select('software_bans').all(schema.SoftwareBan): + data = {'created': sban.created.timestamp()} + self.update('software_bans', data, name = sban.name) + + for token in self.select('tokens').all(schema.Token): + data = {'created': token.created.timestamp(), 'accessed': token.accessed.timestamp()} + self.update('tokens', data, code = token.code) + + for user in self.select('users').all(schema.User): + data = {'created': user.created.timestamp()} + self.update('users', data, username = user.username) + + for wlist in self.select('whitelist').all(schema.Whitelist): + data = {'created': wlist.created.timestamp()} + self.update('whitelist', data, domain = wlist.domain) + + def get_config(self, key: str) -> Any: key = key.replace('_', '-') @@ -198,6 +235,11 @@ class Connection(SqlConnection): return cur.one(schema.User) + def get_user_by_app_token(self, code: str) -> schema.User | None: + with self.run('get-user-by-app-token', {'code': code}) as cur: + return cur.one(schema.User) + + def get_users(self) -> Iterator[schema.User]: return self.execute("SELECT * FROM users").all(schema.User) @@ -249,13 +291,102 @@ class Connection(SqlConnection): pass + def get_app(self, + client_id: str, + client_secret: str, + token: str | None = None) -> schema.App | None: + + params = { + 'id': client_id, + 'secret': client_secret + } + + if token is not None: + command = 'get-app-with-token' + params['token'] = token + + else: + command = 'get-app' + + with self.run(command, params) as cur: + return cur.one(schema.App) + + + def get_app_by_token(self, token: str) -> schema.App | None: + with self.run('get-app-by-token', {'token': token}) as cur: + return cur.one(schema.App) + + + def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App: + params = { + 'name': name, + 'redirect_uri': redirect_uri, + 'website': website, + 'client_id': secrets.token_hex(20), + 'client_secret': secrets.token_hex(20), + 'created': Date.new_utc().timestamp(), + 'accessed': Date.new_utc().timestamp() + } + + with self.insert('app', params) as cur: + if (row := cur.one(schema.App)) is None: + raise RuntimeError(f'Failed to insert app: {name}') + + return row + + + def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App: + data: dict[str, str | None] = {} + + if user is not None: + data['user'] = user.username + + if set_auth: + data['auth_code'] = secrets.token_hex(20) + + else: + data['token'] = secrets.token_hex(20) + data['auth_code'] = None + + params = { + 'client_id': app.client_id, + 'client_secret': app.client_secret + } + + with self.update('app', data, **params) as cur: # type: ignore[arg-type] + if (row := cur.one(schema.App)) is None: + raise RuntimeError('Failed to update row') + + return row + + + def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool: + params = { + 'id': client_id, + 'secret': client_secret + } + + if token is not None: + command = 'del-app-token' + params['token'] = token + + else: + command = 'del-app' + + with self.run(command, params) as cur: + if cur.row_count > 1: + raise RuntimeError('More than 1 row was deleted') + + return cur.row_count == 0 + + def get_token(self, code: str) -> schema.Token | None: with self.run('get-token', {'code': code}) as cur: return cur.one(schema.Token) def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]: - if username is not None: + if username is None: return self.select('tokens').all(schema.Token) return self.select('tokens', username = username).all(schema.Token) diff --git a/relay/database/schema.py b/relay/database/schema.py index 1fd7003..660e527 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -1,14 +1,14 @@ from __future__ import annotations -import typing - +from blib import Date from bsql import Column, Row, Tables from collections.abc import Callable -from datetime import datetime +from copy import deepcopy +from typing import TYPE_CHECKING, Any from .config import ConfigData -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from .connection import Connection @@ -16,6 +16,16 @@ VERSIONS: dict[int, Callable[[Connection], None]] = {} TABLES = Tables() +def deserialize_timestamp(value: Any) -> Date: + try: + return Date.parse(value) + + except ValueError: + pass + + return Date.fromisoformat(value) + + @TABLES.add_row class Config(Row): key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) @@ -27,62 +37,125 @@ class Config(Row): class Instance(Row): table_name: str = 'inboxes' + domain: Column[str] = Column( 'domain', 'text', primary_key = True, unique = True, nullable = False) actor: Column[str] = Column('actor', 'text', unique = True) inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) followid: Column[str] = Column('followid', 'text') software: Column[str] = Column('software', 'text') - accepted: Column[datetime] = Column('accepted', 'boolean') - created: Column[datetime] = Column('created', 'timestamp', nullable = False) + accepted: Column[Date] = Column('accepted', 'boolean') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class Whitelist(Row): domain: Column[str] = Column( 'domain', 'text', primary_key = True, unique = True, nullable = True) - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class DomainBan(Row): table_name: str = 'domain_bans' + domain: Column[str] = Column( 'domain', 'text', primary_key = True, unique = True, nullable = True) reason: Column[str] = Column('reason', 'text') note: Column[str] = Column('note', 'text') - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class SoftwareBan(Row): table_name: str = 'software_bans' + name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) reason: Column[str] = Column('reason', 'text') note: Column[str] = Column('note', 'text') - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class User(Row): table_name: str = 'users' + username: Column[str] = Column( 'username', 'text', primary_key = True, unique = True, nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False) handle: Column[str] = Column('handle', 'text') - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class Token(Row): table_name: str = 'tokens' + code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False) user: Column[str] = Column('user', 'text', nullable = False) - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) + accessed: Column[Date] = Column( + 'accessed', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) + + +@TABLES.add_row +class App(Row): + table_name: str = 'apps' + + + client_id: Column[str] = Column( + 'client_id', 'text', primary_key = True, unique = True, nullable = False) + client_secret: Column[str] = Column('client_secret', 'text', nullable = False) + name: Column[str] = Column('name', 'text') + website: Column[str] = Column('website', 'text') + redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False) + token: Column[str | None] = Column('token', 'text') + auth_code: Column[str | None] = Column('auth_code', 'text') + user: Column[str | None] = Column('user', 'text') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) + accessed: Column[Date] = Column( + 'accessed', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) + + + def get_api_data(self, include_token: bool = False) -> dict[str, Any]: + data = deepcopy(self) + data.pop('auth_code') + data.pop('created') + data.pop('accessed') + + if not include_token: + data.pop('token') + + return data def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: @@ -103,5 +176,15 @@ def migrate_20240206(conn: Connection) -> None: @migration def migrate_20240310(conn: Connection) -> None: - conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN") - conn.execute("UPDATE inboxes SET accepted = 1") + conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN') + conn.execute('UPDATE "inboxes" SET accepted = 1') + + +@migration +def migrate_20240625(conn: Connection) -> None: + conn.execute('ALTER TABLE "tokens" ADD "accessed" timestamp') + + for token in conn.get_tokens(): + conn.update('tokens', {'accessed': token.created}, code = token.code).one() + + conn.create_tables() diff --git a/relay/frontend/base.haml b/relay/frontend/base.haml index d3d8bb6..dd1e3e2 100644 --- a/relay/frontend/base.haml +++ b/relay/frontend/base.haml @@ -1,5 +1,5 @@ -macro menu_item(name, path) - -if view.request.path == path or (path != "/" and view.request.path.startswith(path)) + -if request.path == path or (path != "/" and request.path.startswith(path)) %a.button(href="{{path}}" active="true") -> =name -else @@ -10,12 +10,12 @@ %head %title << {{config.name}}: {{page}} %meta(charset="UTF-8") - %meta(name="viewport" content="width=device-width, initial-scale=1") - %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme") - %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}") - %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}") + %meta(name="ort" content="width=device-width, initial-scale=1") + %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme") + %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}") + %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}") %link(rel="manifest" href="/manifest.json?{{version}}") - %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer) -block head %body @@ -26,7 +26,7 @@ {{menu_item("Home", "/")}} - -if view.request["user"] + -if request["user"] {{menu_item("Instances", "/admin/instances")}} {{menu_item("Whitelist", "/admin/whitelist")}} {{menu_item("Domain Bans", "/admin/domain_bans")}} @@ -61,11 +61,11 @@ #footer.section .col1 - -if not view.request["user"] + -if not request["user"] %a(href="/login") << Login -else - =view.request["user"]["username"] + =request["user"]["username"] ( %a(href="/logout") << Logout ) diff --git a/relay/frontend/page/authorize_new.haml b/relay/frontend/page/authorize_new.haml new file mode 100644 index 0000000..4f07df3 --- /dev/null +++ b/relay/frontend/page/authorize_new.haml @@ -0,0 +1,31 @@ +-extends "base.haml" +-set page="App Authorization" + +-block content + %fieldset.section + %legend << App Authorization + + -if application.website + #title << Application "{{application.name}}" wants full API access + + -else + #title << Application "{{application.name}}" wants full API access + + #buttons + .spacer + + %form(action="/oauth/authorize" method="POST") + %input(type="hidden" name="client_id" value="{{application.client_id}}") + %input(type="hidden" name="client_secret" value="{{application.client_secret}}") + %input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}") + %input(type="hidden" name="response" value="true") + %input.button(type="submit" value="Allow") + + %form(action="/oauth/authorize" method="POST") + %input(type="hidden" name="client_id" value="{{application.client_id}}") + %input(type="hidden" name="client_secret" value="{{application.client_secret}}") + %input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}") + %input(type="hidden" name="response" value="false") + %input.button(type="submit" value="Deny") + + .spacer diff --git a/relay/frontend/page/authorize_show.haml b/relay/frontend/page/authorize_show.haml new file mode 100644 index 0000000..19cde40 --- /dev/null +++ b/relay/frontend/page/authorize_show.haml @@ -0,0 +1,18 @@ +-extends "base.haml" +-set page="App Authorization" + +-block content + %fieldset.section + %legend << App Authorization Code + + -if application.website + %p + Copy the following code into + %a(href="{{application.website}}" target="_main") -> %code -> =application.name + + -else + %p + Copy the following code info + %code -> =application.name + + %pre#code -> =application.auth_code diff --git a/relay/frontend/page/error.haml b/relay/frontend/page/error.haml new file mode 100644 index 0000000..4d4bf95 --- /dev/null +++ b/relay/frontend/page/error.haml @@ -0,0 +1,7 @@ +-extends "base.haml" +-set page="Error" + +-block content + .section.error + .title << HTTP Error {{e.status}} + .body -> =e.body diff --git a/relay/frontend/page/login.haml b/relay/frontend/page/login.haml index c32160f..4f29746 100644 --- a/relay/frontend/page/login.haml +++ b/relay/frontend/page/login.haml @@ -12,4 +12,6 @@ %label(for="password") << Password %input(id="password" name="password" placeholder="Password" type="password") + + %input#redir(type="hidden" name="redir" value="{{redir}}") %input.submit(type="button" value="Login") diff --git a/relay/frontend/static/functions.js b/relay/frontend/static/functions.js index b0e4db5..3063223 100644 --- a/relay/frontend/static/functions.js +++ b/relay/frontend/static/functions.js @@ -483,13 +483,15 @@ function page_instance() { function page_login() { const fields = { username: document.querySelector("#username"), - password: document.querySelector("#password") - } + password: document.querySelector("#password"), + redir: document.querySelector("#redir") + }; async function login(event) { const values = { username: fields.username.value.trim(), - password: fields.password.value.trim() + password: fields.password.value.trim(), + redir: fields.redir.value.trim() } if (values.username === "" | values.password === "") { @@ -498,14 +500,14 @@ function page_login() { } try { - await request("POST", "v1/token", values); + await request("POST", "v1/login", values); } catch (error) { toast(error); return; } - document.location = "/"; + document.location = values.redir; } @@ -848,9 +850,6 @@ if (location.pathname.startsWith("/admin/config")) { } else if (location.pathname.startsWith("/admin/instances")) { page_instance(); -} else if (location.pathname.startsWith("/admin/login")) { - page_login(); - } else if (location.pathname.startsWith("/admin/software_bans")) { page_software_ban(); @@ -859,4 +858,7 @@ if (location.pathname.startsWith("/admin/config")) { } else if (location.pathname.startsWith("/admin/whitelist")) { page_whitelist(); + +} else if (location.pathname.startsWith("/login")) { + page_login(); } diff --git a/relay/frontend/static/style.css b/relay/frontend/static/style.css index f0d72f5..c9bcd43 100644 --- a/relay/frontend/static/style.css +++ b/relay/frontend/static/style.css @@ -338,6 +338,44 @@ textarea { } +/* error */ +#content.page-error { + text-align: center; +} + +#content.page-error .title { + font-size: 24px; + font-weight: bold; +} + + +/* auth */ +#content.page-app_authorization { + text-align: center; +} + +#content.page-app_authorization #code { + background: var(--background); + border: 1px solid var(--border); + font-size: 18px; + margin: 0 auto; + width: max-content; + padding: 5px; +} + +#content.page-app_authorization #title { + font-size: 24px; +} + +#content.page-app_authorization #buttons { + display: grid; + grid-template-columns: auto max-content max-content auto; + grid-gap: var(--spacing); + justify-items: center; + margin: var(--spacing) 0; +} + + @keyframes show_toast { 0% { transform: translateX(100%); diff --git a/relay/manage.py b/relay/manage.py index 81f546e..5ae8238 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -212,6 +212,21 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None: os._exit(0) +@cli.command('db-maintenance') +@click.option('--fix-timestamps', '-t', is_flag = True, + help = 'Make sure timestamps in the database are float values') +@click.pass_context +def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None: + 'Perform maintenance tasks on the database' + + if fix_timestamps: + with ctx.obj.database.session(True) as conn: + conn.fix_timestamps() + + with ctx.obj.database.session(False) as conn: + with conn.execute("VACUUM"): + pass + @cli.command('convert') @click.option('--old-config', '-o', help = 'Path to the config file to convert from') @@ -239,18 +254,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: ctx.obj.config.set('domain', config['host']) ctx.obj.config.save() + # fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7 with get_database(ctx.obj.config) as db: with db.session(True) as conn: conn.put_config('private-key', database['private-key']) conn.put_config('note', config['note']) conn.put_config('whitelist-enabled', config['whitelist_enabled']) - with click.progressbar( # type: ignore + with click.progressbar( database['relay-list'].values(), label = 'Inboxes'.ljust(15), width = 0 ) as inboxes: - for inbox in inboxes: if inbox['software'] in {'akkoma', 'pleroma'}: actor = f'https://{inbox["domain"]}/relay' @@ -269,7 +284,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: software = inbox['software'] ) - with click.progressbar( # type: ignore + with click.progressbar( config['blocked_software'], label = 'Banned software'.ljust(15), width = 0 @@ -281,7 +296,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: reason = 'relay' if software in RELAY_SOFTWARE else None ) - with click.progressbar( # type: ignore + with click.progressbar( config['blocked_instances'], label = 'Banned domains'.ljust(15), width = 0 @@ -290,7 +305,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: for domain in banned_software: conn.put_domain_ban(domain) - with click.progressbar( # type: ignore + with click.progressbar( config['whitelist'], label = 'Whitelist'.ljust(15), width = 0 diff --git a/relay/misc.py b/relay/misc.py index 6995bc4..b27c89a 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -62,6 +62,28 @@ SOFTWARE = ( 'gotosocial' ) +JSON_PATHS: tuple[str, ...] = ( + '/api/v1', + '/actor', + '/inbox', + '/outbox', + '/following', + '/followers', + '/.well-known', + '/nodeinfo', + '/oauth/token', + '/oauth/revoke' +) + +TOKEN_PATHS: tuple[str, ...] = ( + '/api', + '/login', + '/logout', + '/oauth/authorize', + '/oauth/revoke', + '/admin' +) + def boolean(value: Any) -> bool: if isinstance(value, str): @@ -113,6 +135,17 @@ def get_resource(path: str) -> Path: return Path(str(pkgfiles('relay'))).joinpath(path) +class HttpError(Exception): + def __init__(self, + status: int, + body: str) -> None: + + self.body: str = body + self.status: int = status + + Exception.__init__(self, f"HTTP Error {status}: {body}") + + class JsonEncoder(json.JSONEncoder): def default(self, o: Any) -> str: if isinstance(o, datetime): @@ -242,9 +275,9 @@ class Response(AiohttpResponse): @classmethod - def new_redir(cls: type[Self], path: str) -> Self: + def new_redir(cls: type[Self], path: str, status: int = 307) -> Self: body = f'Redirect to {path}' - return cls.new(body, 302, {'Location': path}) + return cls.new(body, status, {'Location': path}, ctype = 'html') @property diff --git a/relay/template.py b/relay/template.py index 7e3f657..3ee2855 100644 --- a/relay/template.py +++ b/relay/template.py @@ -2,6 +2,7 @@ from __future__ import annotations import textwrap +from aiohttp.web import Request from collections.abc import Callable from hamlish_jinja import HamlishExtension from jinja2 import Environment, FileSystemLoader @@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any from . import __version__ from .misc import get_resource -from .views.base import View if TYPE_CHECKING: from .application import Application @@ -43,12 +43,12 @@ class Template(Environment): self.hamlish_mode = 'indented' - def render(self, path: str, view: View | None = None, **context: Any) -> str: + def render(self, path: str, request: Request, **context: Any) -> str: with self.app.database.session(False) as conn: config = conn.get_config_all() new_context = { - 'view': view, + 'request': request, 'domain': self.app.config.domain, 'version': __version__, 'config': config, diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index 74b01c6..aa672f2 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -7,7 +7,7 @@ from .base import View, register_route from .. import logger as logging from ..database import schema -from ..misc import Message, Response +from ..misc import HttpError, Message, Response from ..processors import run_processor @@ -39,8 +39,7 @@ class ActorView(View): async def post(self, request: Request) -> Response: - if response := await self.get_post_data(): - return response + await self.get_post_data() with self.database.session() as conn: self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment] @@ -65,13 +64,13 @@ class ActorView(View): return Response.new(status = 202) - async def get_post_data(self) -> Response | None: + async def get_post_data(self) -> None: try: self.signature = aputils.Signature.parse(self.request.headers['signature']) except KeyError: logging.verbose('Missing signature header') - return Response.new_error(400, 'missing signature header', 'json') + raise HttpError(400, 'missing signature header') try: message: Message | None = await self.request.json(loads = Message.parse) @@ -79,17 +78,17 @@ class ActorView(View): except Exception: traceback.print_exc() logging.verbose('Failed to parse inbox message') - return Response.new_error(400, 'failed to parse message', 'json') + raise HttpError(400, 'failed to parse message') if message is None: logging.verbose('empty message') - return Response.new_error(400, 'missing message', 'json') + raise HttpError(400, 'missing message') self.message = message if 'actor' not in self.message: logging.verbose('actor not in message') - return Response.new_error(400, 'no actor in message', 'json') + raise HttpError(400, 'no actor in message') try: self.actor = await self.client.get(self.signature.keyid, True, Message) @@ -98,26 +97,24 @@ class ActorView(View): # ld signatures aren't handled atm, so just ignore it if self.message.type == 'Delete': logging.verbose('Instance sent a delete which cannot be handled') - return Response.new(status=202) + raise HttpError(202, '') logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') - return Response.new_error(400, 'failed to fetch actor', 'json') + raise HttpError(400, 'failed to fetch actor') try: self.signer = self.actor.signer except KeyError: logging.verbose('Actor missing public key: %s', self.signature.keyid) - return Response.new_error(400, 'actor missing public key', 'json') + raise HttpError(400, 'actor missing public key') try: await self.signer.validate_request_async(self.request) except aputils.SignatureFailureError as e: logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) - return Response.new_error(401, str(e), 'json') - - return None + raise HttpError(401, str(e)) @register_route('/outbox') diff --git a/relay/views/api.py b/relay/views/api.py index 73b6a16..76cd1e5 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -1,16 +1,17 @@ +import secrets import traceback from aiohttp.web import Request, middleware from argon2.exceptions import VerifyMismatchError +from blib import convert_to_boolean from collections.abc import Awaitable, Callable, Sequence -from typing import Any from urllib.parse import urlparse from .base import View, register_route from .. import __version__ -from ..database import ConfigData -from ..misc import Message, Response, boolean, get_app +from ..database import ConfigData, schema +from ..misc import HttpError, Message, Response, boolean DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' @@ -22,6 +23,8 @@ ALLOWED_HEADERS: set[str] = { PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( ('GET', '/api/v1/relay'), + ('POST', '/api/v1/app'), + ('POST', '/api/v1/login'), ('POST', '/api/v1/token') ) @@ -37,57 +40,174 @@ def check_api_path(method: str, path: str) -> bool: async def handle_api_path( request: Request, handler: Callable[[Request], Awaitable[Response]]) -> Response: - try: - if (token := request.cookies.get('user-token')): - request['token'] = token - else: - request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() - - with get_app().database.session() as conn: - request['user'] = conn.get_user_by_token(request['token']) - - except (KeyError, ValueError): - request['token'] = None - request['user'] = None + if not request.path.startswith('/api'): + return await handler(request) if request.method != "OPTIONS" and check_api_path(request.method, request.path): - if not request['token']: - return Response.new_error(401, 'Missing token', 'json') + if request['token'] is None: + raise HttpError(401, 'Missing token') - if not request['user']: - return Response.new_error(401, 'Invalid token', 'json') + if request['user'] is None: + raise HttpError(401, 'Invalid token') response = await handler(request) - - if request.path.startswith('/api'): - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) + response.headers['Access-Control-Allow-Origin'] = '*' + response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) return response -@register_route('/api/v1/token') -class Login(View): +@register_route('/oauth/authorize') +class OauthAuthorize(View): async def get(self, request: Request) -> Response: - return Response.new({'message': 'Token valid'}, ctype = 'json') + data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], []) + + if data['response_type'] != 'code': + raise HttpError(400, 'Response type is not "code"') + + with self.database.session(True) as conn: + with conn.select('app', client_id = data['client_id']) as cur: + if (app := cur.one(schema.App)) is None: + raise HttpError(404, 'Could not find app') + + if app.token is not None or app.auth_code is not None: + context = {'application': app} + html = self.template.render( + 'page/authorize_show.haml', self.request, **context + ) + + return Response.new(html, ctype = 'html') + + if data['redirect_uri'] != app.redirect_uri: + raise HttpError(400, 'redirect_uri does not match application') + + context = {'application': app} + html = self.template.render('page/authorize_new.haml', self.request, **context) + return Response.new(html, ctype = 'html') async def post(self, request: Request) -> Response: - data = await self.get_api_data(['username', 'password'], []) + data = await self.get_api_data( + ['client_id', 'client_secret', 'redirect_uri', 'response'], [] + ) - if isinstance(data, Response): - return data + with self.database.session(True) as conn: + if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: + return Response.new_error(404, 'Could not find app', 'json') + + if convert_to_boolean(data['response']): + if app.auth_code is None: + app = conn.update_app(app, request['user'], True) + + if app.redirect_uri == DEFAULT_REDIRECT: + context = {'application': app} + html = self.template.render( + 'page/authorize_show.haml', self.request, **context + ) + + return Response.new(html, ctype = 'html') + + return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}') + + if not conn.del_app(app.client_id, app.client_secret): + raise HttpError(404, 'App not found') + + return Response.new_redir('/') + + +@register_route('/oauth/token') +class OauthToken(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data( + ['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], [] + ) + + if data['grant_type'] != 'authorization_code': + raise HttpError(400, 'Invalid grant type') + + with self.database.session(True) as conn: + if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: + raise HttpError(404, 'Application not found') + + if app.auth_code != data['code']: + raise HttpError(400, 'Invalid authentication code') + + if app.redirect_uri != data['redirect_uri']: + raise HttpError(400, 'Invalid redirect uri') + + app = conn.update_app(app, request['user'], False) + + return Response.new(app.get_api_data(True), ctype = 'json') + + +@register_route('/oauth/revoke') +class OauthRevoke(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['client_id', 'client_secret', 'token'], []) + + with self.database.session(True) as conn: + if (app := conn.get_app(**data)) is None: + raise HttpError(404, 'Could not find token') + + if app.user != request['token'].username: + raise HttpError(403, 'Invalid token') + + if not conn.del_app(**data): + raise HttpError(400, 'Token not removed') + + return Response.new({'msg': 'Token deleted'}, ctype = 'json') + + +@register_route('/api/v1/app') +class App(View): + async def get(self, request: Request) -> Response: + data = await self.get_api_data(['client_id', 'client_secret'], []) + + with self.database.session(False) as conn: + if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: + raise HttpError(404, 'Application cannot be found') + + return Response.new(app.get_api_data(), ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['name', 'redirect_uri'], ['website']) + + with self.database.session(True) as conn: + app = conn.put_app( + name = data['name'], + redirect_uri = data['redirect_uri'], + website = data.get('website') + ) + + return Response.new(app.get_api_data(), ctype = 'json') + + + async def delete(self, request: Request) -> Response: + data = await self.get_api_data(['client_id', 'client_secret'], []) + + with self.database.session(True) as conn: + if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code): + raise HttpError(400, 'Token not removed') + + return Response.new({'msg': 'Token deleted'}, ctype = 'json') + + +@register_route('/api/v1/login') +class Login(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['username', 'password'], []) with self.database.session(True) as conn: if not (user := conn.get_user(data['username'])): - return Response.new_error(401, 'User not found', 'json') + raise HttpError(401, 'User not found') try: conn.hasher.verify(user['hash'], data['password']) except VerifyMismatchError: - return Response.new_error(401, 'Invalid password', 'json') + raise HttpError(401, 'Invalid password') token = conn.put_token(data['username']) @@ -106,11 +226,36 @@ class Login(View): return resp - async def delete(self, request: Request) -> Response: - with self.database.session() as conn: - conn.del_token(request['token']) - return Response.new({'message': 'Token revoked'}, ctype = 'json') + async def post2(self, request: Request) -> Response: + data = await self.get_api_data(['username', 'password'], []) + + with self.database.session(True) as conn: + if not (user := conn.get_user(data['username'])): + raise HttpError(401, 'User not found') + + try: + conn.hasher.verify(user['hash'], data['password']) + + except VerifyMismatchError: + raise HttpError(401, 'Invalid password') + + app = conn.put_app( + data['app_name'], + DEFAULT_REDIRECT, + data.get('website') + ) + + params = { + 'code': secrets.token_hex(20), + 'user': user.username + } + + with conn.update('app', params, client_id = app.client_id) as cur: + if (row := cur.one(schema.App)) is None: + raise HttpError(500, 'Failed to create app') + + return Response.new(row.get_api_data(True), ctype = 'json') @register_route('/api/v1/relay') @@ -155,14 +300,10 @@ class Config(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['key', 'value'], []) - - if isinstance(data, Response): - return data - data['key'] = data['key'].replace('-', '_') if data['key'] not in ConfigData.USER_KEYS(): - return Response.new_error(400, 'Invalid key', 'json') + raise HttpError(400, 'Invalid key') with self.database.session() as conn: conn.put_config(data['key'], data['value']) @@ -173,11 +314,8 @@ class Config(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['key'], []) - if isinstance(data, Response): - return data - if data['key'] not in ConfigData.USER_KEYS(): - return Response.new_error(400, 'Invalid key', 'json') + raise HttpError(400, 'Invalid key') with self.database.session() as conn: conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) @@ -196,15 +334,11 @@ class Inbox(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) - - if isinstance(data, Response): - return data - data['domain'] = urlparse(data["actor"]).netloc with self.database.session() as conn: if conn.get_inbox(data['domain']) is not None: - return Response.new_error(404, 'Instance already in database', 'json') + raise HttpError(404, 'Instance already in database') data['domain'] = data['domain'].encode('idna').decode() @@ -214,7 +348,7 @@ class Inbox(View): except Exception: traceback.print_exc() - return Response.new_error(500, 'Failed to fetch actor', 'json') + raise HttpError(500, 'Failed to fetch actor') data['inbox'] = actor_data.shared_inbox @@ -240,14 +374,10 @@ class Inbox(View): async def patch(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() if (instance := conn.get_inbox(data['domain'])) is None: - return Response.new_error(404, 'Instance with domain not found', 'json') + raise HttpError(404, 'Instance with domain not found') instance = conn.put_inbox( instance.domain, @@ -262,14 +392,10 @@ class Inbox(View): async def delete(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], []) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() if not conn.get_inbox(data['domain']): - return Response.new_error(404, 'Instance with domain not found', 'json') + raise HttpError(404, 'Instance with domain not found') conn.del_inbox(data['domain']) @@ -286,26 +412,21 @@ class RequestView(View): async def post(self, request: Request) -> Response: - data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], []) - - if isinstance(data, Response): - return data - - data['accept'] = boolean(data['accept']) + data = await self.get_api_data(['domain', 'accept'], []) data['domain'] = data['domain'].encode('idna').decode() try: with self.database.session(True) as conn: - instance = conn.put_request_response(data['domain'], data['accept']) + instance = conn.put_request_response(data['domain'], boolean(data['accept'])) except KeyError: - return Response.new_error(404, 'Request not found', 'json') + raise HttpError(404, 'Request not found') message = Message.new_response( host = self.config.domain, actor = instance.actor, followid = instance.followid, - accept = data['accept'] + accept = boolean(data['accept']) ) self.app.push_message(instance.inbox, message, instance) @@ -333,15 +454,11 @@ class DomainBan(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['domain'], ['note', 'reason']) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() with self.database.session() as conn: if conn.get_domain_ban(data['domain']) is not None: - return Response.new_error(400, 'Domain already banned', 'json') + raise HttpError(400, 'Domain already banned') ban = conn.put_domain_ban( domain = data['domain'], @@ -356,16 +473,13 @@ class DomainBan(View): with self.database.session() as conn: data = await self.get_api_data(['domain'], ['note', 'reason']) - if isinstance(data, Response): - return data - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + raise HttpError(400, 'Must include note and/or reason parameters') data['domain'] = data['domain'].encode('idna').decode() if conn.get_domain_ban(data['domain']) is None: - return Response.new_error(404, 'Domain not banned', 'json') + raise HttpError(404, 'Domain not banned') ban = conn.update_domain_ban( domain = data['domain'], @@ -379,14 +493,10 @@ class DomainBan(View): async def delete(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], []) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() if conn.get_domain_ban(data['domain']) is None: - return Response.new_error(404, 'Domain not banned', 'json') + raise HttpError(404, 'Domain not banned') conn.del_domain_ban(data['domain']) @@ -405,12 +515,9 @@ class SoftwareBan(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['name'], ['note', 'reason']) - if isinstance(data, Response): - return data - with self.database.session() as conn: if conn.get_software_ban(data['name']) is not None: - return Response.new_error(400, 'Domain already banned', 'json') + raise HttpError(400, 'Domain already banned') ban = conn.put_software_ban( name = data['name'], @@ -424,15 +531,12 @@ class SoftwareBan(View): async def patch(self, request: Request) -> Response: data = await self.get_api_data(['name'], ['note', 'reason']) - if isinstance(data, Response): - return data - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + raise HttpError(400, 'Must include note and/or reason parameters') with self.database.session() as conn: if conn.get_software_ban(data['name']) is None: - return Response.new_error(404, 'Software not banned', 'json') + raise HttpError(404, 'Software not banned') ban = conn.update_software_ban( name = data['name'], @@ -446,12 +550,9 @@ class SoftwareBan(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['name'], []) - if isinstance(data, Response): - return data - with self.database.session() as conn: if conn.get_software_ban(data['name']) is None: - return Response.new_error(404, 'Software not banned', 'json') + raise HttpError(404, 'Software not banned') conn.del_software_ban(data['name']) @@ -474,12 +575,9 @@ class User(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['username', 'password'], ['handle']) - if isinstance(data, Response): - return data - with self.database.session() as conn: if conn.get_user(data['username']) is not None: - return Response.new_error(404, 'User already exists', 'json') + raise HttpError(404, 'User already exists') user = conn.put_user( username = data['username'], @@ -494,9 +592,6 @@ class User(View): async def patch(self, request: Request) -> Response: data = await self.get_api_data(['username'], ['password', 'handle']) - if isinstance(data, Response): - return data - with self.database.session(True) as conn: user = conn.put_user( username = data['username'], @@ -511,12 +606,9 @@ class User(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['username'], []) - if isinstance(data, Response): - return data - with self.database.session(True) as conn: if conn.get_user(data['username']) is None: - return Response.new_error(404, 'User does not exist', 'json') + raise HttpError(404, 'User does not exist') conn.del_user(data['username']) @@ -535,14 +627,11 @@ class Whitelist(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['domain'], []) - if isinstance(data, Response): - return data - domain = data['domain'].encode('idna').decode() with self.database.session() as conn: if conn.get_domain_whitelist(domain) is not None: - return Response.new_error(400, 'Domain already added to whitelist', 'json') + raise HttpError(400, 'Domain already added to whitelist') item = conn.put_domain_whitelist(domain) @@ -552,14 +641,11 @@ class Whitelist(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['domain'], []) - if isinstance(data, Response): - return data - domain = data['domain'].encode('idna').decode() with self.database.session() as conn: if conn.get_domain_whitelist(domain) is None: - return Response.new_error(404, 'Domain not in whitelist', 'json') + raise HttpError(404, 'Domain not in whitelist') conn.del_domain_whitelist(domain) diff --git a/relay/views/base.py b/relay/views/base.py index e102896..1b2d405 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -1,10 +1,8 @@ from __future__ import annotations -from Crypto.Random import get_random_bytes from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS -from aiohttp.web import HTTPMethodNotAllowed, Request -from base64 import b64encode +from aiohttp.web import Request from bsql import Database from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from functools import cached_property @@ -15,7 +13,7 @@ from ..cache import Cache from ..config import Config from ..database import Connection from ..http_client import HttpClient -from ..misc import Response, get_app +from ..misc import HttpError, Response, get_app if TYPE_CHECKING: from typing import Self @@ -43,10 +41,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]: class View(AbstractView): def __await__(self) -> Generator[Any, None, Response]: if self.request.method not in METHODS: - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + raise HttpError(405, f'"{self.request.method}" method not allowed') if not (handler := self.handlers.get(self.request.method)): - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + raise HttpError(405, f'"{self.request.method}" method not allowed') return self._run_handler(handler).__await__() @@ -58,7 +56,6 @@ class View(AbstractView): async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: - self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') return await handler(self.request, **self.request.match_info, **kwargs) @@ -117,17 +114,18 @@ class View(AbstractView): async def get_api_data(self, required: list[str], - optional: list[str]) -> dict[str, str] | Response: + optional: list[str]) -> dict[str, str]: - if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: + if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}: post_data = convert_data(await self.request.post()) + # post_data = {key: value for key, value in parse_qsl(await self.request.text())} elif self.request.content_type == 'application/json': try: post_data = convert_data(await self.request.json()) except JSONDecodeError: - return Response.new_error(400, 'Invalid JSON data', 'json') + raise HttpError(400, 'Invalid JSON data') else: post_data = convert_data(self.request.query) @@ -139,9 +137,9 @@ class View(AbstractView): data[key] = post_data[key] except KeyError as e: - return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') + raise HttpError(400, f'Missing {str(e)} pararmeter') for key in optional: - data[key] = post_data.get(key, '') + data[key] = post_data.get(key) # type: ignore[assignment] return data diff --git a/relay/views/frontend.py b/relay/views/frontend.py index cf6b338..a383d20 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -1,18 +1,13 @@ from aiohttp import web from collections.abc import Awaitable, Callable from typing import Any +from urllib.parse import unquote from .base import View, register_route from ..database import THEMES from ..logger import LogLevel -from ..misc import Response, get_app - - -UNAUTH_ROUTES = { - '/', - '/login' -} +from ..misc import TOKEN_PATHS, Response @web.middleware @@ -20,28 +15,25 @@ async def handle_frontend_path( request: web.Request, handler: Callable[[web.Request], Awaitable[Response]]) -> Response: - app = get_app() + if request['user'] is not None and request.path == '/login': + return Response.new_redir('/') - if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): - request['token'] = request.cookies.get('user-token') - request['user'] = None + if request.path.startswith(TOKEN_PATHS) and request['user'] is None: + if request.path == '/logout': + return Response.new_redir('/') - if request['token']: - with app.database.session(False) as conn: - request['user'] = conn.get_user_by_token(request['token']) + response = Response.new_redir(f'/login?redir={request.path}') - if request['user'] and request.path == '/login': - return Response.new('', 302, {'Location': '/'}) - - if not request['user'] and request.path.startswith('/admin'): - response = Response.new('', 302, {'Location': f'/login?redir={request.path}'}) + if request['token'] is not None: response.del_cookie('user-token') - return response + + return response response = await handler(request) - if not request.path.startswith('/api') and not request['user'] and request['token']: - response.del_cookie('user-token') + if not request.path.startswith('/api'): + if request['user'] is None and request['token'] is not None: + response.del_cookie('user-token') return response @@ -54,14 +46,15 @@ class HomeView(View): 'instances': tuple(conn.get_inboxes()) } - data = self.template.render('page/home.haml', self, **context) + data = self.template.render('page/home.haml', self.request, **context) return Response.new(data, ctype='html') @register_route('/login') class Login(View): async def get(self, request: web.Request) -> Response: - data = self.template.render('page/login.haml', self) + redir = unquote(request.query.get('redir', '/')) + data = self.template.render('page/login.haml', self.request, redir = redir) return Response.new(data, ctype = 'html') @@ -69,7 +62,7 @@ class Login(View): class Logout(View): async def get(self, request: web.Request) -> Response: with self.database.session(True) as conn: - conn.del_token(request['token']) + conn.del_token(request['token'].code) resp = Response.new_redir('/') resp.del_cookie('user-token', domain = self.config.domain, path = '/') @@ -79,7 +72,7 @@ class Logout(View): @register_route('/admin') class Admin(View): async def get(self, request: web.Request) -> Response: - return Response.new('', 302, {'Location': '/admin/instances'}) + return Response.new_redir(f'/login?redir={request.path}', 301) @register_route('/admin/instances') @@ -101,7 +94,7 @@ class AdminInstances(View): if message: context['message'] = message - data = self.template.render('page/admin-instances.haml', self, **context) + data = self.template.render('page/admin-instances.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -123,7 +116,7 @@ class AdminWhitelist(View): if message: context['message'] = message - data = self.template.render('page/admin-whitelist.haml', self, **context) + data = self.template.render('page/admin-whitelist.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -145,7 +138,7 @@ class AdminDomainBans(View): if message: context['message'] = message - data = self.template.render('page/admin-domain_bans.haml', self, **context) + data = self.template.render('page/admin-domain_bans.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -167,7 +160,7 @@ class AdminSoftwareBans(View): if message: context['message'] = message - data = self.template.render('page/admin-software_bans.haml', self, **context) + data = self.template.render('page/admin-software_bans.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -189,7 +182,7 @@ class AdminUsers(View): if message: context['message'] = message - data = self.template.render('page/admin-users.haml', self, **context) + data = self.template.render('page/admin-users.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -213,7 +206,7 @@ class AdminConfig(View): } } - data = self.template.render('page/admin-config.haml', self, **context) + data = self.template.render('page/admin-config.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -251,5 +244,5 @@ class ThemeCss(View): except KeyError: return Response.new('Invalid theme', 404) - data = self.template.render('variables.css', self, **context) + data = self.template.render('variables.css', self.request, **context) return Response.new(data, ctype = 'css') From 773922e2630ad355340f5497ac5a86f899817985 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Thu, 4 Jul 2024 22:00:54 -0400 Subject: [PATCH 14/21] remove tokens table and fix auth handling --- relay/application.py | 18 ++++++---- relay/data/statements.sql | 41 ++++----------------- relay/database/connection.py | 70 +++++++++++++----------------------- relay/database/schema.py | 30 +++------------- relay/misc.py | 7 ++-- relay/views/api.py | 39 ++------------------ relay/views/frontend.py | 4 +-- 7 files changed, 55 insertions(+), 154 deletions(-) diff --git a/relay/application.py b/relay/application.py index 6ab481b..5ee1a73 100644 --- a/relay/application.py +++ b/relay/application.py @@ -310,15 +310,21 @@ async def handle_response_headers( if request.path == "/" or request.path.startswith(TOKEN_PATHS): with app.database.session() as conn: - if (token := request.headers.get('Authorization')) is not None: - token = token.replace('Bearer', '').strip() + tokens = ( + request.headers.get('Authorization', '').replace('Bearer', '').strip(), + request.cookies.get('user-token') + ) + + for token in tokens: + if not token: + continue request['token'] = conn.get_app_by_token(token) - request['user'] = conn.get_user_by_app_token(token) - elif (token := request.cookies.get('user-token')) is not None: - request['token'] = conn.get_token(token) - request['user'] = conn.get_user_by_token(token) + if request['token'] is not None: + request['user'] = conn.get_user(request['token'].user) + + break try: resp = await handler(request) diff --git a/relay/data/statements.sql b/relay/data/statements.sql index e8694ae..0097252 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -50,17 +50,9 @@ WHERE username = :value or handle = :value; -- name: get-user-by-token SELECT * FROM users -WHERE username = ( - SELECT user FROM tokens - WHERE code = :code -); - - --- name: get-user-by-app-token -SELECT * FROM users WHERE username = ( SELECT user FROM app - WHERE code = :code + WHERE token = :token ); @@ -80,46 +72,25 @@ SELECT * FROM app WHERE client_id = :id and client_secret = :secret; --- name: get-app-token +-- name: get-app-with-token SELECT * FROM app WHERE client_id = :id and client_secret = :secret and token = :token; -- name: get-app-by-token -SELECT * FROM app +SELECT * FROM apps WHERE token = :token; -- name: del-app -DELETE FROM users +DELETE FROM apps WHERE client_id = :id and client_secret = :secret; --- name: del-app-token -DELETE FROM users +-- name: del-app-with-token +DELETE FROM apps WHERE client_id = :id and client_secret = :secret and token = :token; --- name: get-token -SELECT * FROM tokens -WHERE code = :code; - - --- name: put-token -INSERT INTO tokens (code, user, created) -VALUES (:code, :user, :created) -RETURNING *; - - --- name: del-token -DELETE FROM tokens -WHERE code = :code; - - --- name: del-token-user -DELETE FROM tokens -WHERE user = :username; - - -- name: get-software-ban SELECT * FROM software_bans WHERE name = :name; diff --git a/relay/database/connection.py b/relay/database/connection.py index 3c973b8..603e63a 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -9,7 +9,6 @@ from collections.abc import Iterator from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from uuid import uuid4 from . import schema from .config import ( @@ -73,10 +72,6 @@ class Connection(SqlConnection): data = {'created': sban.created.timestamp()} self.update('software_bans', data, name = sban.name) - for token in self.select('tokens').all(schema.Token): - data = {'created': token.created.timestamp(), 'accessed': token.accessed.timestamp()} - self.update('tokens', data, code = token.code) - for user in self.select('users').all(schema.User): data = {'created': user.created.timestamp()} self.update('users', data, username = user.username) @@ -230,13 +225,8 @@ class Connection(SqlConnection): return cur.one(schema.User) - def get_user_by_token(self, code: str) -> schema.User | None: - with self.run('get-user-by-token', {'code': code}) as cur: - return cur.one(schema.User) - - - def get_user_by_app_token(self, code: str) -> schema.User | None: - with self.run('get-user-by-app-token', {'code': code}) as cur: + def get_user_by_token(self, token: str) -> schema.User | None: + with self.run('get-user-by-token', {'token': token}) as cur: return cur.one(schema.User) @@ -328,13 +318,34 @@ class Connection(SqlConnection): 'accessed': Date.new_utc().timestamp() } - with self.insert('app', params) as cur: + with self.insert('apps', params) as cur: if (row := cur.one(schema.App)) is None: raise RuntimeError(f'Failed to insert app: {name}') return row + def put_app_login(self, user: schema.User) -> schema.App: + params = { + 'name': 'Web', + 'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob', + 'website': None, + 'user': user.username, + 'client_id': secrets.token_hex(20), + 'client_secret': secrets.token_hex(20), + 'auth_code': None, + 'token': secrets.token_hex(20), + 'created': Date.new_utc().timestamp(), + 'accessed': Date.new_utc().timestamp() + } + + with self.insert('apps', params) as cur: + if (row := cur.one(schema.App)) is None: + raise RuntimeError(f'Failed to create app for "{user.username}"') + + return row + + def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App: data: dict[str, str | None] = {} @@ -367,7 +378,7 @@ class Connection(SqlConnection): } if token is not None: - command = 'del-app-token' + command = 'del-app-with-token' params['token'] = token else: @@ -380,37 +391,6 @@ class Connection(SqlConnection): return cur.row_count == 0 - def get_token(self, code: str) -> schema.Token | None: - with self.run('get-token', {'code': code}) as cur: - return cur.one(schema.Token) - - - def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]: - if username is None: - return self.select('tokens').all(schema.Token) - - return self.select('tokens', username = username).all(schema.Token) - - - def put_token(self, username: str) -> schema.Token: - data = { - 'code': uuid4().hex, - 'user': username, - 'created': datetime.now(tz = timezone.utc) - } - - with self.run('put-token', data) as cur: - if (row := cur.one(schema.Token)) is None: - raise RuntimeError(f"Failed to insert token for user: {username}") - - return row - - - def del_token(self, code: str) -> None: - with self.run('del-token', {'code': code}): - pass - - def get_domain_ban(self, domain: str) -> schema.DomainBan | None: if domain.startswith('http'): domain = urlparse(domain).netloc diff --git a/relay/database/schema.py b/relay/database/schema.py index 660e527..a6016bb 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -105,23 +105,6 @@ class User(Row): ) -@TABLES.add_row -class Token(Row): - table_name: str = 'tokens' - - - code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False) - user: Column[str] = Column('user', 'text', nullable = False) - created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) - accessed: Column[Date] = Column( - 'accessed', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) - - @TABLES.add_row class App(Row): table_name: str = 'apps' @@ -148,9 +131,8 @@ class App(Row): def get_api_data(self, include_token: bool = False) -> dict[str, Any]: data = deepcopy(self) + data.pop('user') data.pop('auth_code') - data.pop('created') - data.pop('accessed') if not include_token: data.pop('token') @@ -176,15 +158,11 @@ def migrate_20240206(conn: Connection) -> None: @migration def migrate_20240310(conn: Connection) -> None: - conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN') - conn.execute('UPDATE "inboxes" SET accepted = 1') + conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close() + conn.execute('UPDATE "inboxes" SET accepted = 1').close() @migration def migrate_20240625(conn: Connection) -> None: - conn.execute('ALTER TABLE "tokens" ADD "accessed" timestamp') - - for token in conn.get_tokens(): - conn.update('tokens', {'accessed': token.created}, code = token.code).one() - conn.create_tables() + conn.execute('DROP TABLE tokens').close() diff --git a/relay/misc.py b/relay/misc.py index b27c89a..cb35339 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -76,12 +76,11 @@ JSON_PATHS: tuple[str, ...] = ( ) TOKEN_PATHS: tuple[str, ...] = ( - '/api', - '/login', '/logout', + '/admin', + '/api', '/oauth/authorize', - '/oauth/revoke', - '/admin' + '/oauth/revoke' ) diff --git a/relay/views/api.py b/relay/views/api.py index 76cd1e5..f8fe828 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -1,4 +1,3 @@ -import secrets import traceback from aiohttp.web import Request, middleware @@ -209,12 +208,12 @@ class Login(View): except VerifyMismatchError: raise HttpError(401, 'Invalid password') - token = conn.put_token(data['username']) + app = conn.put_app_login(user) - resp = Response.new({'token': token.code}, ctype = 'json') + resp = Response.new({'token': app.token}, ctype = 'json') resp.set_cookie( 'user-token', - token.code, + app.token, # type: ignore[arg-type] max_age = 60 * 60 * 24 * 365, domain = self.config.domain, path = '/', @@ -226,38 +225,6 @@ class Login(View): return resp - - async def post2(self, request: Request) -> Response: - data = await self.get_api_data(['username', 'password'], []) - - with self.database.session(True) as conn: - if not (user := conn.get_user(data['username'])): - raise HttpError(401, 'User not found') - - try: - conn.hasher.verify(user['hash'], data['password']) - - except VerifyMismatchError: - raise HttpError(401, 'Invalid password') - - app = conn.put_app( - data['app_name'], - DEFAULT_REDIRECT, - data.get('website') - ) - - params = { - 'code': secrets.token_hex(20), - 'user': user.username - } - - with conn.update('app', params, client_id = app.client_id) as cur: - if (row := cur.one(schema.App)) is None: - raise HttpError(500, 'Failed to create app') - - return Response.new(row.get_api_data(True), ctype = 'json') - - @register_route('/api/v1/relay') class RelayInfo(View): async def get(self, request: Request) -> Response: diff --git a/relay/views/frontend.py b/relay/views/frontend.py index a383d20..b6dba7b 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -18,7 +18,7 @@ async def handle_frontend_path( if request['user'] is not None and request.path == '/login': return Response.new_redir('/') - if request.path.startswith(TOKEN_PATHS) and request['user'] is None: + if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None: if request.path == '/logout': return Response.new_redir('/') @@ -62,7 +62,7 @@ class Login(View): class Logout(View): async def get(self, request: web.Request) -> Response: with self.database.session(True) as conn: - conn.del_token(request['token'].code) + conn.del_app(request['token'].client_id, request['token'].client_secret) resp = Response.new_redir('/') resp.del_cookie('user-token', domain = self.config.domain, path = '/') From 9f5df5f95c9f84ca01ccda5dc7ee38f482eff37a Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Thu, 4 Jul 2024 22:58:50 -0400 Subject: [PATCH 15/21] update swagger docs --- relay/application.py | 2 +- relay/data/statements.sql | 6 +- relay/data/swagger.yaml | 241 +++++++++++++++++++++++++++++------ relay/database/connection.py | 2 +- relay/views/api.py | 25 ++-- 5 files changed, 220 insertions(+), 56 deletions(-) diff --git a/relay/application.py b/relay/application.py index 5ee1a73..38c39a6 100644 --- a/relay/application.py +++ b/relay/application.py @@ -352,7 +352,7 @@ async def handle_response_headers( resp.headers['Server'] = 'ActivityRelay' # Still have to figure out how csp headers work - if resp.content_type == 'text/html': + if resp.content_type == 'text/html' and not request.path.startswith("/api"): resp.headers['Content-Security-Policy'] = get_csp(request) if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')): diff --git a/relay/data/statements.sql b/relay/data/statements.sql index 0097252..dde6a29 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -51,7 +51,7 @@ WHERE username = :value or handle = :value; -- name: get-user-by-token SELECT * FROM users WHERE username = ( - SELECT user FROM app + SELECT user FROM apps WHERE token = :token ); @@ -68,12 +68,12 @@ WHERE username = :value or handle = :value; -- name: get-app -SELECT * FROM app +SELECT * FROM apps WHERE client_id = :id and client_secret = :secret; -- name: get-app-with-token -SELECT * FROM app +SELECT * FROM apps WHERE client_id = :id and client_secret = :secret and token = :token; diff --git a/relay/data/swagger.yaml b/relay/data/swagger.yaml index a2a51dc..ac7b728 100644 --- a/relay/data/swagger.yaml +++ b/relay/data/swagger.yaml @@ -18,10 +18,12 @@ securityDefinitions: in: cookie name: user-token Bearer: - type: apiKey + type: oauth2 name: Authorization in: header - description: "Enter the token with the `Bearer ` prefix" + flow: accessCode + authorizationUrl: /oauth/authorize + tokenUrl: /oauth/token paths: /: @@ -35,6 +37,161 @@ paths: schema: $ref: "#/definitions/Error" + /oauth/authorize: + get: + tags: + - OAuth + description: Get an authorization code + parameters: + - in: query + name: response-type + required: true + type: string + - in: query + name: client_id + required: true + type: string + - in: query + name: redirect_uri + required: true + type: string + + /oauth/token: + post: + tags: + - OAuth + description: Get a token for an authorized app + parameters: + - in: formData + name: grant_type + required: true + type: string + - in: formData + name: code + required: true + type: string + - in: formData + name: client_id + required: true + type: string + - in: formData + name: client_secret + required: true + type: string + - in: formData + name: redirect_uri + required: true + type: string + consumes: + - application/x-www-form-urlencoded + - application/json + - multipart/form-data + produces: + - application/json + responses: + "200": + description: Application + schema: + $ref: "#/definitions/Application" + + /oauth/revoke: + post: + tags: + - OAuth + description: Get a token for an authorized app + parameters: + - in: formData + name: client_id + required: true + type: string + - in: formData + name: client_secret + required: true + type: string + - in: formData + name: token + required: true + type: string + consumes: + - application/json + - multipart/form-data + - application/x-www-form-urlencoded + produces: + - application/json + responses: + "200": + description: Message confirming application deletion + schema: + $ref: "#/definitions/Message" + + /v1/app: + get: + tags: + - Applications + description: Verify the token is valid + produces: + - application/json + responses: + "200": + description: Application with the associated token + schema: + $ref: "#/definitions/Application" + + post: + tags: + - Applications + description: Create a new application + parameters: + - in: query + name: name + required: true + type: string + - in: query + name: redirect_uri + required: true + type: string + - in: query + name: website + required: false + type: string + format: url + consumes: + - application/json + - multipart/form-data + - application/x-www-form-urlencoded + produces: + - application/json + responses: + "200": + description: Newly created application + schema: + $ref: "#/definitions/Application" + + delete: + tags: + - Applications + description: Deletes an application + parameters: + - in: formData + name: client_id + required: true + type: string + - in: formData + name: client_secret + required: true + type: string + consumes: + - application/json + - multipart/form-data + - application/x-www-form-urlencoded + produces: + - application/json + responses: + "200": + description: Confirmation of application deletion + schema: + $ref: "#/definitions/Message" + /v1/relay: get: tags: @@ -48,23 +205,11 @@ paths: schema: $ref: "#/definitions/Info" - /v1/token: - get: - tags: - - Token - description: Verify API token - produces: - - application/json - responses: - "200": - description: Valid token - schema: - $ref: "#/definitions/Message" - + /v1/login: post: tags: - - Token - description: Get a new token + - Login + description: Login with a username and password parameters: - in: formData name: username @@ -74,7 +219,6 @@ paths: name: password required: true type: string - format: password consumes: - application/json - multipart/form-data @@ -83,22 +227,9 @@ paths: - application/json responses: "200": - description: Created token + description: A new Application schema: - $ref: "#/definitions/Token" - - - delete: - tags: - - Token - description: Revoke a token - produces: - - application/json - responses: - "200": - description: Revoked token - schema: - $ref: "#/definitions/Message" + $ref: "#/definitions/Application" /v1/config: get: @@ -731,9 +862,43 @@ definitions: description: Human-readable message text type: string + Application: + type: object + properties: + client_id: + description: Identifier for the application + type: string + client_secret: + description: Secret string for the application + type: string + name: + description: Human-readable name of the application + type: string + website: + description: Website for the application + type: string + format: url + redirect_uri: + description: URL to redirect to when authorizing an app + type: string + token: + description: String to use in the Authorization header for client requests + type: string + created: + description: Date the application was created + type: string + format: date-time + accessed: + description: Date the application was last used + type: string + format: date-time + Config: type: object properties: + approval-required: + description: Require instances to be approved when following + type: bool log-level: description: Maximum level of log messages to print to the console type: string @@ -743,6 +908,9 @@ definitions: note: description: Blurb to display on the home page type: string + theme: + description: Name of the color scheme to use for the frontend + type: string whitelist-enabled: description: Only allow specific instances to join the relay when enabled type: boolean @@ -843,13 +1011,6 @@ definitions: type: string format: date-time - Token: - type: object - properties: - token: - description: Character string used for authenticating with the api - type: string - User: type: object properties: diff --git a/relay/database/connection.py b/relay/database/connection.py index 603e63a..e18278a 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -364,7 +364,7 @@ class Connection(SqlConnection): 'client_secret': app.client_secret } - with self.update('app', data, **params) as cur: # type: ignore[arg-type] + with self.update('apps', data, **params) as cur: # type: ignore[arg-type] if (row := cur.one(schema.App)) is None: raise RuntimeError('Failed to update row') diff --git a/relay/views/api.py b/relay/views/api.py index f8fe828..b1b820b 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -40,7 +40,7 @@ async def handle_api_path( request: Request, handler: Callable[[Request], Awaitable[Response]]) -> Response: - if not request.path.startswith('/api'): + if not request.path.startswith('/api') or request.path == '/api/doc': return await handler(request) if request.method != "OPTIONS" and check_api_path(request.method, request.path): @@ -58,6 +58,7 @@ async def handle_api_path( @register_route('/oauth/authorize') +@register_route('/api/oauth/authorize') class OauthAuthorize(View): async def get(self, request: Request) -> Response: data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], []) @@ -66,11 +67,14 @@ class OauthAuthorize(View): raise HttpError(400, 'Response type is not "code"') with self.database.session(True) as conn: - with conn.select('app', client_id = data['client_id']) as cur: + with conn.select('apps', client_id = data['client_id']) as cur: if (app := cur.one(schema.App)) is None: raise HttpError(404, 'Could not find app') - if app.token is not None or app.auth_code is not None: + if app.token is not None: + raise HttpError(400, 'Application has already been authorized') + + if app.auth_code is not None: context = {'application': app} html = self.template.render( 'page/authorize_show.haml', self.request, **context @@ -96,6 +100,9 @@ class OauthAuthorize(View): return Response.new_error(404, 'Could not find app', 'json') if convert_to_boolean(data['response']): + if app.token is not None: + raise HttpError(400, 'Application has already been authorized') + if app.auth_code is None: app = conn.update_app(app, request['user'], True) @@ -116,6 +123,7 @@ class OauthAuthorize(View): @register_route('/oauth/token') +@register_route('/api/oauth/token') class OauthToken(View): async def post(self, request: Request) -> Response: data = await self.get_api_data( @@ -141,6 +149,7 @@ class OauthToken(View): @register_route('/oauth/revoke') +@register_route('/api/oauth/revoke') class OauthRevoke(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['client_id', 'client_secret', 'token'], []) @@ -161,13 +170,7 @@ class OauthRevoke(View): @register_route('/api/v1/app') class App(View): async def get(self, request: Request) -> Response: - data = await self.get_api_data(['client_id', 'client_secret'], []) - - with self.database.session(False) as conn: - if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: - raise HttpError(404, 'Application cannot be found') - - return Response.new(app.get_api_data(), ctype = 'json') + return Response.new(request['token'].get_api_data(), ctype = 'json') async def post(self, request: Request) -> Response: @@ -210,7 +213,7 @@ class Login(View): app = conn.put_app_login(user) - resp = Response.new({'token': app.token}, ctype = 'json') + resp = Response.new(app.get_api_data(), ctype = 'json') resp.set_cookie( 'user-token', app.token, # type: ignore[arg-type] From baae0b46acc414e8e41ea6dcffd4b0aa1d7765bc Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Thu, 22 Aug 2024 23:25:43 -0400 Subject: [PATCH 16/21] minor styling fixes * don't allow line wrapping in tables * replace legend for follow required message with div * set correct font family --- relay/frontend/page/home.haml | 4 +--- relay/frontend/static/style.css | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/relay/frontend/page/home.haml b/relay/frontend/page/home.haml index 7db7551..1de9b14 100644 --- a/relay/frontend/page/home.haml +++ b/relay/frontend/page/home.haml @@ -15,9 +15,7 @@ %a(href="https://{{domain}}/actor") << https://{{domain}}/actor -if config.approval_required - %fieldset.section.message - %legend << Require Approval - + %div.section.message Follow requests require approval. You will need to wait for an admin to accept or deny your request. diff --git a/relay/frontend/static/style.css b/relay/frontend/static/style.css index c9bcd43..ac4eaf5 100644 --- a/relay/frontend/static/style.css +++ b/relay/frontend/static/style.css @@ -12,7 +12,7 @@ body { color: var(--text); background-color: #222; margin: var(--spacing); - font-family: sans serif; + font-family: sans-serif; } details *:nth-child(2) { @@ -88,6 +88,7 @@ tbody tr:last-child td:last-child { table td { padding: 5px; + white-space: nowrap; } table thead td { @@ -282,8 +283,11 @@ textarea { width: 100%; } -.data-table .date { +.data-table td:not(:first-child) { width: max-content; +} + +.data-table .date { text-align: right; } From de190fcdd3b81159795a7f5d8e917390663c4cf6 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Thu, 22 Aug 2024 23:40:39 -0400 Subject: [PATCH 17/21] update dev dependencies and barkshark-lib --- pyproject.toml | 8 ++++---- relay/http_client.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1249d5..05a9b15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "aiohttp >= 3.9.5", "aiohttp-swagger[performance] == 1.0.16", "argon2-cffi == 23.1.0", - "barkshark-lib >= 0.1.5rc1, < 0.2.0", + "barkshark-lib >= 0.1.6, < 0.2.0", "barkshark-sql >= 0.2.0rc2, < 0.3.0", "click == 8.1.2", "hiredis == 2.3.2", @@ -47,9 +47,9 @@ activityrelay = "relay.manage:main" [project.optional-dependencies] dev = [ "flake8 == 7.1.0", - "mypy == 1.10.1", - "pyinstaller == 6.8.0", - "watchdog == 4.0.1" + "mypy == 1.11.1", + "pyinstaller == 6.10.0", + "watchdog == 4.0.2" ] [tool.setuptools] diff --git a/relay/http_client.py b/relay/http_client.py index 05a6565..e5b58f6 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -146,7 +146,7 @@ class HttpClient: @overload - async def get(self, # type: ignore[overload-overlap] + async def get(self, url: str, sign_headers: bool, cls: None = None, From 1516f27b76847cb2c6ff3b647db4a499cfa71a7d Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Fri, 23 Aug 2024 00:35:00 -0400 Subject: [PATCH 18/21] properly convert Date objects for sqlite databases --- relay/database/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/relay/database/__init__.py b/relay/database/__init__.py index 545f822..03198ab 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -1,3 +1,6 @@ +import sqlite3 + +from blib import Date from bsql import Database from .config import THEMES, ConfigData @@ -9,6 +12,9 @@ from ..config import Config from ..misc import get_resource +sqlite3.register_adapter(Date, Date.timestamp) + + def get_database(config: Config, migrate: bool = True) -> Database[Connection]: options = { 'connection_class': Connection, From 4203355d7df7ba87361951ef1ff4b6412d2c9860 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Fri, 23 Aug 2024 01:13:32 -0400 Subject: [PATCH 19/21] better error message when trying to set invalid config key --- relay/manage.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/relay/manage.py b/relay/manage.py index 5ae8238..d2525a0 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -362,10 +362,15 @@ def cli_config_list(ctx: click.Context) -> None: def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: 'Set a config value' - with ctx.obj.database.session() as conn: - new_value = conn.put_config(key, value) + try: + with ctx.obj.database.session() as conn: + new_value = conn.put_config(key, value) - print(f'{key}: {repr(new_value)}') + except: + click.echo('Invalid config name:', key) + return + + click.echo(f'{key}: {repr(new_value)}') @cli.group('user') From c9598ff273e2abc3b53a2dccda44de005c602848 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Fri, 23 Aug 2024 01:13:55 -0400 Subject: [PATCH 20/21] update barkshark-lib and activitypub-utils --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 05a9b15..95a17a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,11 +14,11 @@ classifiers = [ "Programming Language :: Python :: 3.12" ] dependencies = [ - "activitypub-utils >= 0.3.1.post1, < 0.4.0", + "activitypub-utils >= 0.3.2, < 0.4", "aiohttp >= 3.9.5", "aiohttp-swagger[performance] == 1.0.16", "argon2-cffi == 23.1.0", - "barkshark-lib >= 0.1.6, < 0.2.0", + "barkshark-lib >= 0.2.2.post2, < 0.3.0", "barkshark-sql >= 0.2.0rc2, < 0.3.0", "click == 8.1.2", "hiredis == 2.3.2", From 85825f6de1b1edba19a41122a4104dc99f7e43ce Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Fri, 23 Aug 2024 01:37:15 -0400 Subject: [PATCH 21/21] don't use cached nodeinfo data when instance follows --- relay/http_client.py | 8 +++----- relay/processors.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/relay/http_client.py b/relay/http_client.py index e5b58f6..26f2ba8 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -230,12 +230,10 @@ class HttpClient: return - async def fetch_nodeinfo(self, domain: str) -> Nodeinfo: + async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo: nodeinfo_url = None wk_nodeinfo = await self.get( - f'https://{domain}/.well-known/nodeinfo', - False, - WellKnownNodeinfo + f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force ) for version in ('20', '21'): @@ -248,7 +246,7 @@ class HttpClient: if nodeinfo_url is None: raise ValueError(f'Failed to fetch nodeinfo url for {domain}') - return await self.get(nodeinfo_url, False, Nodeinfo) + return await self.get(nodeinfo_url, False, Nodeinfo, force) async def get(*args: Any, **kwargs: Any) -> Any: diff --git a/relay/processors.py b/relay/processors.py index 4e4d96f..57e9222 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -58,7 +58,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None: async def handle_follow(view: ActorView, conn: Connection) -> None: - nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) + nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain, force = True) software = nodeinfo.sw_name if nodeinfo else None config = conn.get_config_all()