Merge branch 'dev' into 'main'

version 0.3.3

See merge request pleroma/relay!59
This commit is contained in:
Izalia Mae 2024-09-21 03:21:13 +00:00
commit 49e77ee730
47 changed files with 2545 additions and 1646 deletions

47
dev.py
View file

@ -1,25 +1,38 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import click
import platform import platform
import shutil import shutil
import subprocess import subprocess
import sys import sys
import time import time
import tomllib
from datetime import datetime, timedelta from datetime import datetime, timedelta
from importlib.util import find_spec
from pathlib import Path from pathlib import Path
from relay import __version__, logger as logging
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Sequence from typing import Any, Sequence
try: try:
from watchdog.observers import Observer import tomllib
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
except ImportError: except ImportError:
class PatternMatchingEventHandler: # type: ignore if find_spec("toml") is None:
pass subprocess.run([sys.executable, "-m", "pip", "install", "toml"])
import toml as tomllib # type: ignore[no-redef]
if None in [find_spec("click"), find_spec("watchdog")]:
CMD = [sys.executable, "-m", "pip", "install", "click >= 8.1.0", "watchdog >= 4.0.0"]
PROC = subprocess.run(CMD, check = False)
if PROC.returncode != 0:
sys.exit()
print("Successfully installed dependencies")
import click
from watchdog.observers import Observer
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
REPO = Path(__file__).parent REPO = Path(__file__).parent
@ -37,12 +50,10 @@ def cli() -> None:
@cli.command('install') @cli.command('install')
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies') @click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
def cli_install(no_dev: bool) -> None: def cli_install(no_dev: bool) -> None:
with open('pyproject.toml', 'rb') as fd: with open('pyproject.toml', 'r', encoding = 'utf-8') as fd:
data = tomllib.load(fd) data = tomllib.loads(fd.read())
deps = data['project']['dependencies'] deps = data['project']['dependencies']
if not no_dev:
deps.extend(data['project']['optional-dependencies']['dev']) deps.extend(data['project']['optional-dependencies']['dev'])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False) subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
@ -60,7 +71,7 @@ def cli_lint(path: Path, watch: bool) -> None:
return return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)] mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)]
click.echo('----- flake8 -----') click.echo('----- flake8 -----')
subprocess.run(flake8) subprocess.run(flake8)
@ -89,6 +100,8 @@ def cli_clean() -> None:
@cli.command('build') @cli.command('build')
def cli_build() -> None: def cli_build() -> None:
from relay import __version__
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [ cmd = [
@ -171,7 +184,7 @@ class WatchHandler(PatternMatchingEventHandler):
if proc.poll() is not None: if proc.poll() is not None:
continue continue
logging.info(f'Terminating process {proc.pid}') print(f'Terminating process {proc.pid}')
proc.terminate() proc.terminate()
sec = 0.0 sec = 0.0
@ -180,11 +193,11 @@ class WatchHandler(PatternMatchingEventHandler):
sec += 0.1 sec += 0.1
if sec >= 5: if sec >= 5:
logging.error('Failed to terminate. Killing process...') print('Failed to terminate. Killing process...')
proc.kill() proc.kill()
break break
logging.info('Process terminated') print('Process terminated')
def run_procs(self, restart: bool = False) -> None: def run_procs(self, restart: bool = False) -> None:
@ -200,13 +213,13 @@ class WatchHandler(PatternMatchingEventHandler):
self.procs = [] self.procs = []
for cmd in self.commands: for cmd in self.commands:
logging.info('Running command: %s', ' '.join(cmd)) print('Running command:', ' '.join(cmd))
subprocess.run(cmd) subprocess.run(cmd)
else: else:
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands) self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
pids = (str(proc.pid) for proc in self.procs) pids = (str(proc.pid) for proc in self.procs)
logging.info('Started processes with PIDs: %s', ', '.join(pids)) print('Started processes with PIDs:', ', '.join(pids))
def on_any_event(self, event: FileSystemEvent) -> None: def on_any_event(self, event: FileSystemEvent) -> None:

View file

@ -9,30 +9,27 @@ license = {text = "AGPLv3"}
classifiers = [ classifiers = [
"Environment :: Console", "Environment :: Console",
"License :: OSI Approved :: GNU Affero General Public License v3", "License :: OSI Approved :: GNU Affero General Public License v3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12"
] ]
dependencies = [ dependencies = [
"activitypub-utils == 0.3.1", "activitypub-utils >= 0.3.2, < 0.4",
"aiohttp >= 3.9.5", "aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.1.3-1", "barkshark-lib >= 0.2.3, < 0.3.0",
"barkshark-sql == 0.1.4-1", "barkshark-sql >= 0.2.0, < 0.3.0",
"click >= 8.1.2", "click == 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",
"idna == 3.4", "idna == 3.4",
"jinja2-haml == 0.3.5", "jinja2-haml == 0.3.5",
"markdown == 3.6", "markdown == 3.6",
"platformdirs == 4.2.2", "platformdirs == 4.2.2",
"pyyaml >= 6.0", "pyyaml == 6.0.1",
"redis == 5.0.5", "redis == 5.0.7"
"importlib-resources == 6.4.0; python_version < '3.9'"
] ]
requires-python = ">=3.8" requires-python = ">=3.10"
dynamic = ["version"] dynamic = ["version"]
[project.readme] [project.readme]
@ -49,11 +46,10 @@ activityrelay = "relay.manage:main"
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"flake8 == 7.0.0", "flake8 == 7.1.0",
"mypy == 1.10.0", "mypy == 1.11.1",
"pyinstaller == 6.8.0", "pyinstaller == 6.10.0",
"watchdog == 4.0.1", "watchdog == 4.0.2"
"typing-extensions >= 4.12.2; python_version < '3.11.0'"
] ]
[tool.setuptools] [tool.setuptools]
@ -104,7 +100,3 @@ implicit_reexport = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "blib" module = "blib"
implicit_reexport = true implicit_reexport = true
[[tool.mypy.overrides]]
module = "bsql"
implicit_reexport = true

View file

@ -1 +1 @@
__version__ = '0.3.2' __version__ = '0.3.3'

View file

@ -6,29 +6,33 @@ import signal
import time import time
import traceback import traceback
from Crypto.Random import get_random_bytes
from aiohttp import web from aiohttp import web
from aiohttp.web import StaticResource from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from bsql import Database, Row from base64 import b64encode
from blib import File, HttpError, port_check
from bsql import Database
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mimetypes import guess_type from mimetypes import guess_type
from pathlib import Path from pathlib import Path
from queue import Empty
from threading import Event, Thread from threading import Event, Thread
from typing import Any from typing import Any, cast
from . import logger as logging from . import logger as logging
from .cache import Cache, get_cache from .cache import Cache, get_cache
from .config import Config from .config import Config
from .database import Connection, get_database from .database import Connection, get_database
from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
from .views.frontend import handle_frontend_path from .views.frontend import handle_frontend_path
from .workers import PushWorkers
def get_csp(request: web.Request) -> str: def get_csp(request: web.Request) -> str:
@ -54,9 +58,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False): def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self, web.Application.__init__(self,
middlewares = [ middlewares = [
handle_api_path, # type: ignore[list-item] handle_response_headers, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item] handle_frontend_path, # type: ignore[list-item]
handle_response_headers # type: ignore[list-item] handle_api_path # type: ignore[list-item]
] ]
) )
@ -75,7 +79,7 @@ class Application(web.Application):
self['cache'].setup() self['cache'].setup()
self['template'] = Template(self) self['template'] = Template(self)
self['push_queue'] = multiprocessing.Queue() self['push_queue'] = multiprocessing.Queue()
self['workers'] = [] self['workers'] = PushWorkers(self.config.workers)
self.cache.setup() self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore self.on_cleanup.append(handle_cleanup) # type: ignore
@ -86,33 +90,33 @@ class Application(web.Application):
setup_swagger( setup_swagger(
self, self,
ui_version = 3, ui_version = 3,
swagger_from_file = get_resource('data/swagger.yaml') swagger_from_file = File.from_resource('relay', 'data/swagger.yaml')
) )
@property @property
def cache(self) -> Cache: def cache(self) -> Cache:
return self['cache'] # type: ignore[no-any-return] return cast(Cache, self['cache'])
@property @property
def client(self) -> HttpClient: def client(self) -> HttpClient:
return self['client'] # type: ignore[no-any-return] return cast(HttpClient, self['client'])
@property @property
def config(self) -> Config: def config(self) -> Config:
return self['config'] # type: ignore[no-any-return] return cast(Config, self['config'])
@property @property
def database(self) -> Database[Connection]: def database(self) -> Database[Connection]:
return self['database'] # type: ignore[no-any-return] return cast(Database[Connection], self['database'])
@property @property
def signer(self) -> Signer: def signer(self) -> Signer:
return self['signer'] # type: ignore[no-any-return] return cast(Signer, self['signer'])
@signer.setter @signer.setter
@ -126,7 +130,7 @@ class Application(web.Application):
@property @property
def template(self) -> Template: def template(self) -> Template:
return self['template'] # type: ignore[no-any-return] return cast(Template, self['template'])
@property @property
@ -139,16 +143,23 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message, instance: Row) -> None: @property
self['push_queue'].put((inbox, message, instance)) def workers(self) -> PushWorkers:
return cast(PushWorkers, self['workers'])
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self['workers'].push_message(inbox, message, instance)
def register_static_routes(self) -> None: def register_static_routes(self) -> None:
if self['dev']: if self['dev']:
static = StaticResource('/static', get_resource('frontend/static')) static = StaticResource('/static', File.from_resource('relay', 'frontend/static'))
else: else:
static = CachedStaticResource('/static', get_resource('frontend/static')) static = CachedStaticResource(
'/static', Path(File.from_resource('relay', 'frontend/static'))
)
self.router.register_resource(static) self.router.register_resource(static)
@ -161,7 +172,7 @@ class Application(web.Application):
host = self.config.listen host = self.config.listen
port = self.config.port port = self.config.port
if not check_open_port(host, port): if port_check(port, '127.0.0.1' if host == '0.0.0.0' else host):
logging.error(f'A server is already running on {host}:{port}') logging.error(f'A server is already running on {host}:{port}')
return return
@ -195,12 +206,7 @@ class Application(web.Application):
self['cache'].setup() self['cache'].setup()
self['cleanup_thread'] = CacheCleanupThread(self) self['cleanup_thread'] = CacheCleanupThread(self)
self['cleanup_thread'].start() self['cleanup_thread'].start()
self['workers'].start()
for _ in range(self.config.workers):
worker = PushWorker(self['push_queue'])
worker.start()
self['workers'].append(worker)
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup() await runner.setup()
@ -220,15 +226,13 @@ class Application(web.Application):
await site.stop() await site.stop()
for worker in self['workers']: self['workers'].stop()
worker.stop()
self.set_signal_handler(False) self.set_signal_handler(False)
self['starttime'] = None self['starttime'] = None
self['running'] = False self['running'] = False
self['cleanup_thread'].stop() self['cleanup_thread'].stop()
self['workers'].clear()
self['database'].disconnect() self['database'].disconnect()
self['cache'].close() self['cache'].close()
@ -290,56 +294,15 @@ class CacheCleanupThread(Thread):
self.running.clear() self.running.clear()
class PushWorker(multiprocessing.Process): def format_error(request: web.Request, error: HttpError) -> Response:
def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None: app: Application = request.app # type: ignore[assignment]
if Application.DEFAULT is None:
raise RuntimeError('Application not setup yet')
multiprocessing.Process.__init__(self) if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.message}, error.status, ctype = 'json')
self.queue = queue
self.shutdown = multiprocessing.Event()
self.path = Application.DEFAULT.config.path
def stop(self) -> None:
self.shutdown.set()
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
if IS_WINDOWS:
app = Application(self.path)
client = app.client
client.open()
app.database.connect()
app.cache.setup()
else: else:
client = HttpClient() body = app.template.render('page/error.haml', request, e = error)
client.open() return Response.new(body, error.status, ctype = 'html')
while not self.shutdown.is_set():
try:
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
asyncio.create_task(client.post(inbox, message, instance))
except Empty:
await asyncio.sleep(0)
# make sure an exception doesn't bring down the worker
except Exception:
traceback.print_exc()
if IS_WINDOWS:
app.database.disconnect()
app.cache.close()
await client.close()
@web.middleware @web.middleware
@ -347,14 +310,60 @@ async def handle_response_headers(
request: web.Request, request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
request['token'] = None
request['user'] = None
app: Application = request.app # type: ignore[assignment]
if request.path == "/" or request.path.startswith(TOKEN_PATHS):
with app.database.session() as conn:
tokens = (
request.headers.get('Authorization', '').replace('Bearer', '').strip(),
request.cookies.get('user-token')
)
for token in tokens:
if not token:
continue
request['token'] = conn.get_app_by_token(token)
if request['token'] is not None:
request['user'] = conn.get_user(request['token'].user)
break
try:
resp = await handler(request) resp = await handler(request)
except HttpError as e:
resp = format_error(request, e)
except HTTPException as e:
if e.status == 404:
try:
text = (e.text or "").split(":")[1].strip()
except IndexError:
text = e.text or ""
resp = format_error(request, HttpError(e.status, text))
else:
raise
except Exception:
resp = format_error(request, HttpError(500, 'Internal server error'))
traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'
# Still have to figure out how csp headers work # Still have to figure out how csp headers work
if resp.content_type == 'text/html' and not request.path.startswith("/api"): if resp.content_type == 'text/html' and not request.path.startswith("/api"):
resp.headers['Content-Security-Policy'] = get_csp(request) resp.headers['Content-Security-Policy'] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js')): if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')):
# cache for 2 weeks # cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'

View file

@ -4,15 +4,16 @@ import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from bsql import Database from blib import Date, convert_to_boolean
from bsql import Database, Row
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone from datetime import timedelta, timezone
from redis import Redis from redis import Redis
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, TypedDict
from .database import Connection, get_database from .database import Connection, get_database
from .misc import Message, boolean from .misc import Message
if TYPE_CHECKING: if TYPE_CHECKING:
from .application import Application from .application import Application
@ -25,12 +26,20 @@ BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = { CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
'bool': (str, boolean), 'bool': (str, convert_to_boolean),
'json': (json.dumps, json.loads), 'json': (json.dumps, json.loads),
'message': (lambda x: x.to_json(), Message.parse) 'message': (lambda x: x.to_json(), Message.parse)
} }
class RedisConnectType(TypedDict):
client_name: str
decode_responses: bool
username: str | None
password: str | None
db: int
def get_cache(app: Application) -> Cache: def get_cache(app: Application) -> Cache:
return BACKENDS[app.config.ca_type](app) return BACKENDS[app.config.ca_type](app)
@ -57,12 +66,14 @@ class Item:
key: str key: str
value: Any value: Any
value_type: str value_type: str
updated: datetime updated: Date
def __post_init__(self) -> None: def __post_init__(self) -> None:
if isinstance(self.updated, str): # type: ignore[unreachable] self.updated = Date.parse(self.updated)
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
if self.updated.tzinfo is None:
self.updated = self.updated.replace(tzinfo = timezone.utc)
@classmethod @classmethod
@ -70,15 +81,11 @@ class Item:
data = cls(*args) data = cls(*args)
data.value = deserialize_value(data.value, data.value_type) data.value = deserialize_value(data.value, data.value_type)
if not isinstance(data.updated, datetime):
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore
return data return data
def older_than(self, hours: int) -> bool: def older_than(self, hours: int) -> bool:
delta = datetime.now(tz = timezone.utc) - self.updated return self.updated + timedelta(hours = hours) < Date.new_utc()
return (delta.total_seconds()) > hours * 3600
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
@ -172,7 +179,7 @@ class SqlCache(Cache):
with self._db.session(False) as conn: with self._db.session(False) as conn:
with conn.run('get-cache-item', params) as cur: with conn.run('get-cache-item', params) as cur:
if not (row := cur.one()): if not (row := cur.one(Row)):
raise KeyError(f'{namespace}:{key}') raise KeyError(f'{namespace}:{key}')
row.pop('id', None) row.pop('id', None)
@ -206,14 +213,16 @@ class SqlCache(Cache):
'key': key, 'key': key,
'value': serialize_value(value, value_type), 'value': serialize_value(value, value_type),
'type': value_type, 'type': value_type,
'date': datetime.now(tz = timezone.utc) 'date': Date.new_utc()
} }
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as cur: with conn.run('set-cache-item', params) as cur:
row = cur.one() if (row := cur.one(Row)) is None:
row.pop('id', None) # type: ignore[union-attr] raise RuntimeError("Cache item not set")
return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
row.pop('id', None)
return Item.from_data(*tuple(row.values()))
def delete(self, namespace: str, key: str) -> None: def delete(self, namespace: str, key: str) -> None:
@ -234,11 +243,10 @@ class SqlCache(Cache):
if self._db is None: if self._db is None:
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
limit = datetime.now(tz = timezone.utc) - timedelta(days = days) date = Date.new_utc() - timedelta(days = days)
params = {"limit": limit.timestamp()}
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache WHERE updated < :limit", params): with conn.execute("DELETE FROM cache WHERE updated < :limit", {"limit": date}):
pass pass
@ -278,7 +286,7 @@ class RedisCache(Cache):
def __init__(self, app: Application): def __init__(self, app: Application):
Cache.__init__(self, app) Cache.__init__(self, app)
self._rd: Redis = None # type: ignore self._rd: Redis | None = None
@property @property
@ -291,28 +299,38 @@ class RedisCache(Cache):
def get(self, namespace: str, key: str) -> Item: def get(self, namespace: str, key: str) -> Item:
if self._rd is None:
raise ConnectionError("Not connected")
key_name = self.get_key_name(namespace, key) key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)): if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}') raise KeyError(f'{namespace}:{key}')
value_type, updated, value = raw_value.split(':', 2) # type: ignore value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr]
return Item.from_data( return Item.from_data(
namespace, namespace,
key, key,
value, value,
value_type, value_type,
datetime.fromtimestamp(float(updated), tz = timezone.utc) Date.parse(float(updated))
) )
def get_keys(self, namespace: str) -> Iterator[str]: def get_keys(self, namespace: str) -> Iterator[str]:
if self._rd is None:
raise ConnectionError("Not connected")
for key in self._rd.scan_iter(self.get_key_name(namespace, '*')): for key in self._rd.scan_iter(self.get_key_name(namespace, '*')):
*_, key_name = key.split(':', 2) *_, key_name = key.split(':', 2)
yield key_name yield key_name
def get_namespaces(self) -> Iterator[str]: def get_namespaces(self) -> Iterator[str]:
if self._rd is None:
raise ConnectionError("Not connected")
namespaces = [] namespaces = []
for key in self._rd.scan_iter(f'{self.prefix}:*'): for key in self._rd.scan_iter(f'{self.prefix}:*'):
@ -324,7 +342,10 @@ class RedisCache(Cache):
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
date = datetime.now(tz = timezone.utc).timestamp() if self._rd is None:
raise ConnectionError("Not connected")
date = Date.new_utc().timestamp()
value = serialize_value(value, value_type) value = serialize_value(value, value_type)
self._rd.set( self._rd.set(
@ -336,11 +357,17 @@ class RedisCache(Cache):
def delete(self, namespace: str, key: str) -> None: def delete(self, namespace: str, key: str) -> None:
if self._rd is None:
raise ConnectionError("Not connected")
self._rd.delete(self.get_key_name(namespace, key)) self._rd.delete(self.get_key_name(namespace, key))
def delete_old(self, days: int = 14) -> None: def delete_old(self, days: int = 14) -> None:
limit = datetime.now(tz = timezone.utc) - timedelta(days = days) if self._rd is None:
raise ConnectionError("Not connected")
limit = Date.new_utc() - timedelta(days = days)
for full_key in self._rd.scan_iter(f'{self.prefix}:*'): for full_key in self._rd.scan_iter(f'{self.prefix}:*'):
_, namespace, key = full_key.split(':', 2) _, namespace, key = full_key.split(':', 2)
@ -351,14 +378,17 @@ class RedisCache(Cache):
def clear(self) -> None: def clear(self) -> None:
if self._rd is None:
raise ConnectionError("Not connected")
self._rd.delete(f"{self.prefix}:*") self._rd.delete(f"{self.prefix}:*")
def setup(self) -> None: def setup(self) -> None:
if self._rd: if self._rd is not None:
return return
options = { options: RedisConnectType = {
'client_name': f'ActivityRelay_{self.app.config.domain}', 'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True, 'decode_responses': True,
'username': self.app.config.rd_user, 'username': self.app.config.rd_user,
@ -367,18 +397,22 @@ class RedisCache(Cache):
} }
if os.path.exists(self.app.config.rd_host): if os.path.exists(self.app.config.rd_host):
options['unix_socket_path'] = self.app.config.rd_host self._rd = Redis(
unix_socket_path = self.app.config.rd_host,
**options
)
return
else: self._rd = Redis(
options['host'] = self.app.config.rd_host host = self.app.config.rd_host,
options['port'] = self.app.config.rd_port port = self.app.config.rd_port,
**options
self._rd = Redis(**options) # type: ignore )
def close(self) -> None: def close(self) -> None:
if not self._rd: if not self._rd:
return return
self._rd.close() # type: ignore self._rd.close() # type: ignore[no-untyped-call]
self._rd = None # type: ignore self._rd = None

View file

@ -2,13 +2,12 @@ import json
import os import os
import yaml import yaml
from blib import convert_to_boolean
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .misc import boolean
class RelayConfig(dict[str, Any]): class RelayConfig(dict[str, Any]):
def __init__(self, path: str): def __init__(self, path: str):
@ -31,7 +30,7 @@ class RelayConfig(dict[str, Any]):
elif key == 'whitelist_enabled': elif key == 'whitelist_enabled':
if not isinstance(value, bool): if not isinstance(value, bool):
value = boolean(value) value = convert_to_boolean(value)
super().__setitem__(key, value) super().__setitem__(key, value)

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import getpass import getpass
import os import os
import platform import platform
@ -6,16 +8,13 @@ import yaml
from dataclasses import asdict, dataclass, fields from dataclasses import asdict, dataclass, fields
from pathlib import Path from pathlib import Path
from platformdirs import user_config_dir from platformdirs import user_config_dir
from typing import Any from typing import TYPE_CHECKING, Any
from .misc import IS_DOCKER from .misc import IS_DOCKER
try: if TYPE_CHECKING:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
if platform.system() == 'Windows': if platform.system() == 'Windows':
import multiprocessing import multiprocessing
@ -61,7 +60,7 @@ class Config:
def __init__(self, path: Path | None = None, load: bool = False): def __init__(self, path: Path | None = None, load: bool = False):
self.path = Config.get_config_dir(path) self.path: Path = Config.get_config_dir(path)
self.reset() self.reset()
if load: if load:
@ -81,7 +80,7 @@ class Config:
def DEFAULT(cls: type[Self], key: str) -> str | int | None: def DEFAULT(cls: type[Self], key: str) -> str | int | None:
for field in fields(cls): for field in fields(cls):
if field.name == key: if field.name == key:
return field.default # type: ignore return field.default # type: ignore[return-value]
raise KeyError(key) raise KeyError(key)
@ -146,7 +145,7 @@ class Config:
if not config: if not config:
raise ValueError('Config is empty') raise ValueError('Config is empty')
pgcfg = config.get('postgresql', {}) pgcfg = config.get('postgres', {})
rdcfg = config.get('redis', {}) rdcfg = config.get('redis', {})
for key in type(self).KEYS(): for key in type(self).KEYS():

View file

@ -40,7 +40,7 @@ WHERE domain = :value or inbox = :value or actor = :value;
-- name: get-request -- name: get-request
SELECT * FROM inboxes WHERE accepted = 0 and domain = :domain; SELECT * FROM inboxes WHERE accepted = false and domain = :domain;
-- name: get-user -- name: get-user
@ -51,8 +51,8 @@ WHERE username = :value or handle = :value;
-- name: get-user-by-token -- name: get-user-by-token
SELECT * FROM users SELECT * FROM users
WHERE username = ( WHERE username = (
SELECT user FROM tokens SELECT user FROM apps
WHERE code = :code WHERE token = :token
); );
@ -64,28 +64,35 @@ RETURNING *;
-- name: del-user -- name: del-user
DELETE FROM users DELETE FROM users
WHERE username = :value or handle = :value; WHERE username = :username or handle = :username;
-- name: get-token -- name: get-app
SELECT * FROM tokens SELECT * FROM apps
WHERE code = :code; WHERE client_id = :id and client_secret = :secret;
-- name: put-token -- name: get-app-with-token
INSERT INTO tokens (code, user, created) SELECT * FROM apps
VALUES (:code, :user, :created) WHERE client_id = :id and client_secret = :secret and token = :token;
RETURNING *;
-- name: del-token -- name: get-app-by-token
DELETE FROM tokens SELECT * FROM apps
WHERE code = :code; WHERE token = :token;
-- name: del-app
DELETE FROM apps
WHERE client_id = :id and client_secret = :secret;
-- name: del-app-with-token
DELETE FROM apps
WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: del-token-user -- name: del-token-user
DELETE FROM tokens DELETE FROM apps WHERE "user" = :username;
WHERE user = :username;
-- name: get-software-ban -- name: get-software-ban

View file

@ -18,10 +18,12 @@ securityDefinitions:
in: cookie in: cookie
name: user-token name: user-token
Bearer: Bearer:
type: apiKey type: oauth2
name: Authorization name: Authorization
in: header in: header
description: "Enter the token with the `Bearer ` prefix" flow: accessCode
authorizationUrl: /oauth/authorize
tokenUrl: /oauth/token
paths: paths:
/: /:
@ -35,6 +37,161 @@ paths:
schema: schema:
$ref: "#/definitions/Error" $ref: "#/definitions/Error"
/oauth/authorize:
get:
tags:
- OAuth
description: Get an authorization code
parameters:
- in: query
name: response-type
required: true
type: string
- in: query
name: client_id
required: true
type: string
- in: query
name: redirect_uri
required: true
type: string
/oauth/token:
post:
tags:
- OAuth
description: Get a token for an authorized app
parameters:
- in: formData
name: grant_type
required: true
type: string
- in: formData
name: code
required: true
type: string
- in: formData
name: client_id
required: true
type: string
- in: formData
name: client_secret
required: true
type: string
- in: formData
name: redirect_uri
required: true
type: string
consumes:
- application/x-www-form-urlencoded
- application/json
- multipart/form-data
produces:
- application/json
responses:
"200":
description: Application
schema:
$ref: "#/definitions/Application"
/oauth/revoke:
post:
tags:
- OAuth
description: Get a token for an authorized app
parameters:
- in: formData
name: client_id
required: true
type: string
- in: formData
name: client_secret
required: true
type: string
- in: formData
name: token
required: true
type: string
consumes:
- application/json
- multipart/form-data
- application/x-www-form-urlencoded
produces:
- application/json
responses:
"200":
description: Message confirming application deletion
schema:
$ref: "#/definitions/Message"
/v1/app:
get:
tags:
- Applications
description: Verify the token is valid
produces:
- application/json
responses:
"200":
description: Application with the associated token
schema:
$ref: "#/definitions/Application"
post:
tags:
- Applications
description: Create a new application
parameters:
- in: query
name: name
required: true
type: string
- in: query
name: redirect_uri
required: true
type: string
- in: query
name: website
required: false
type: string
format: url
consumes:
- application/json
- multipart/form-data
- application/x-www-form-urlencoded
produces:
- application/json
responses:
"200":
description: Newly created application
schema:
$ref: "#/definitions/Application"
delete:
tags:
- Applications
description: Deletes an application
parameters:
- in: formData
name: client_id
required: true
type: string
- in: formData
name: client_secret
required: true
type: string
consumes:
- application/json
- multipart/form-data
- application/x-www-form-urlencoded
produces:
- application/json
responses:
"200":
description: Confirmation of application deletion
schema:
$ref: "#/definitions/Message"
/v1/relay: /v1/relay:
get: get:
tags: tags:
@ -48,23 +205,11 @@ paths:
schema: schema:
$ref: "#/definitions/Info" $ref: "#/definitions/Info"
/v1/token: /v1/login:
get:
tags:
- Token
description: Verify API token
produces:
- application/json
responses:
"200":
description: Valid token
schema:
$ref: "#/definitions/Message"
post: post:
tags: tags:
- Token - Login
description: Get a new token description: Login with a username and password
parameters: parameters:
- in: formData - in: formData
name: username name: username
@ -74,7 +219,6 @@ paths:
name: password name: password
required: true required: true
type: string type: string
format: password
consumes: consumes:
- application/json - application/json
- multipart/form-data - multipart/form-data
@ -83,22 +227,9 @@ paths:
- application/json - application/json
responses: responses:
"200": "200":
description: Created token description: A new Application
schema: schema:
$ref: "#/definitions/Token" $ref: "#/definitions/Application"
delete:
tags:
- Token
description: Revoke a token
produces:
- application/json
responses:
"200":
description: Revoked token
schema:
$ref: "#/definitions/Message"
/v1/config: /v1/config:
get: get:
@ -731,9 +862,43 @@ definitions:
description: Human-readable message text description: Human-readable message text
type: string type: string
Application:
type: object
properties:
client_id:
description: Identifier for the application
type: string
client_secret:
description: Secret string for the application
type: string
name:
description: Human-readable name of the application
type: string
website:
description: Website for the application
type: string
format: url
redirect_uri:
description: URL to redirect to when authorizing an app
type: string
token:
description: String to use in the Authorization header for client requests
type: string
created:
description: Date the application was created
type: string
format: date-time
accessed:
description: Date the application was last used
type: string
format: date-time
Config: Config:
type: object type: object
properties: properties:
approval-required:
description: Require instances to be approved when following
type: bool
log-level: log-level:
description: Maximum level of log messages to print to the console description: Maximum level of log messages to print to the console
type: string type: string
@ -743,6 +908,9 @@ definitions:
note: note:
description: Blurb to display on the home page description: Blurb to display on the home page
type: string type: string
theme:
description: Name of the color scheme to use for the frontend
type: string
whitelist-enabled: whitelist-enabled:
description: Only allow specific instances to join the relay when enabled description: Only allow specific instances to join the relay when enabled
type: boolean type: boolean
@ -843,13 +1011,6 @@ definitions:
type: string type: string
format: date-time format: date-time
Token:
type: object
properties:
token:
description: Character string used for authenticating with the api
type: string
User: User:
type: object type: object
properties: properties:

View file

@ -1,3 +1,6 @@
import sqlite3
from blib import Date, File
from bsql import Database from bsql import Database
from .config import THEMES, ConfigData from .config import THEMES, ConfigData
@ -6,7 +9,9 @@ from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
from ..config import Config from ..config import Config
from ..misc import get_resource
sqlite3.register_adapter(Date, Date.timestamp)
def get_database(config: Config, migrate: bool = True) -> Database[Connection]: def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
@ -16,6 +21,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
'tables': TABLES 'tables': TABLES
} }
db: Database[Connection]
if config.db_type == 'sqlite': if config.db_type == 'sqlite':
db = Database.sqlite(config.sqlite_path, **options) db = Database.sqlite(config.sqlite_path, **options)
@ -29,7 +36,7 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
**options **options
) )
db.load_prepared_statements(get_resource('data/statements.sql')) db.load_prepared_statements(File.from_resource('relay', 'data/statements.sql'))
db.connect() db.connect()
if not migrate: if not migrate:

View file

@ -2,20 +2,17 @@ from __future__ import annotations
# removing the above line turns annotations into types instead of str objects which messes with # removing the above line turns annotations into types instead of str objects which messes with
# `Field.type` # `Field.type`
from blib import convert_to_boolean
from bsql import Row from bsql import Row
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from dataclasses import Field, asdict, dataclass, fields from dataclasses import Field, asdict, dataclass, fields
from typing import Any from typing import TYPE_CHECKING, Any
from .. import logger as logging from .. import logger as logging
from ..misc import boolean
try: if TYPE_CHECKING:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
THEMES = { THEMES = {
'default': { 'default': {
@ -69,14 +66,14 @@ THEMES = {
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
'bool': (str, boolean), 'bool': (str, convert_to_boolean),
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse) 'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse)
} }
@dataclass() @dataclass()
class ConfigData: class ConfigData:
schema_version: int = 20240310 schema_version: int = 20240625
private_key: str = '' private_key: str = ''
approval_required: bool = False approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO log_level: logging.LogLevel = logging.LogLevel.INFO
@ -114,11 +111,11 @@ class ConfigData:
@classmethod @classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool: def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value]
@classmethod @classmethod
def FIELD(cls: type[Self], key: str) -> Field[Any]: def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
for field in fields(cls): for field in fields(cls):
if field.name == key.replace('-', '_'): if field.name == key.replace('-', '_'):
return field return field

View file

@ -1,20 +1,23 @@
from __future__ import annotations from __future__ import annotations
import secrets
from argon2 import PasswordHasher from argon2 import PasswordHasher
from blib import Date, convert_to_boolean
from bsql import Connection as SqlConnection, Row, Update from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator, Sequence from collections.abc import Iterator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4
from . import schema
from .config import ( from .config import (
THEMES, THEMES,
ConfigData ConfigData
) )
from .. import logger as logging from .. import logger as logging
from ..misc import Message, boolean, get_app from ..misc import Message, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from ..application import Application from ..application import Application
@ -37,22 +40,52 @@ class Connection(SqlConnection):
return get_app() return get_app()
def distill_inboxes(self, message: Message) -> Iterator[Row]: def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]:
src_domains = { src_domains = {
message.domain, message.domain,
urlparse(message.object_id).netloc urlparse(message.object_id).netloc
} }
for instance in self.get_inboxes(): for instance in self.get_inboxes():
if instance['domain'] not in src_domains: if instance.domain not in src_domains:
yield instance yield instance
def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
self.update('apps', data, client_id = app.client_id)
for item in self.select('cache'):
data = {'updated': Date.parse(item['updated']).timestamp()}
self.update('cache', data, id = item['id'])
for dban in self.select('domain_bans').all(schema.DomainBan):
data = {'created': dban.created.timestamp()}
self.update('domain_bans', data, domain = dban.domain)
for instance in self.select('inboxes').all(schema.Instance):
data = {'created': instance.created.timestamp()}
self.update('inboxes', data, domain = instance.domain)
for sban in self.select('software_bans').all(schema.SoftwareBan):
data = {'created': sban.created.timestamp()}
self.update('software_bans', data, name = sban.name)
for user in self.select('users').all(schema.User):
data = {'created': user.created.timestamp()}
self.update('users', data, username = user.username)
for wlist in self.select('whitelist').all(schema.Whitelist):
data = {'created': wlist.created.timestamp()}
self.update('whitelist', data, domain = wlist.domain)
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
key = key.replace('_', '-') key = key.replace('_', '-')
with self.run('get-config', {'key': key}) as cur: with self.run('get-config', {'key': key}) as cur:
if not (row := cur.one()): if (row := cur.one(Row)) is None:
return ConfigData.DEFAULT(key) return ConfigData.DEFAULT(key)
data = ConfigData() data = ConfigData()
@ -61,8 +94,8 @@ class Connection(SqlConnection):
def get_config_all(self) -> ConfigData: def get_config_all(self) -> ConfigData:
with self.run('get-config-all', None) as cur: rows = tuple(self.run('get-config-all', None).all(schema.Row))
return ConfigData.from_rows(tuple(cur.all())) return ConfigData.from_rows(rows)
def put_config(self, key: str, value: Any) -> Any: def put_config(self, key: str, value: Any) -> Any:
@ -75,9 +108,10 @@ class Connection(SqlConnection):
elif key == 'log-level': elif key == 'log-level':
value = logging.LogLevel.parse(value) value = logging.LogLevel.parse(value)
logging.set_level(value) logging.set_level(value)
self.app['workers'].set_log_level(value)
elif key in {'approval-required', 'whitelist-enabled'}: elif key in {'approval-required', 'whitelist-enabled'}:
value = boolean(value) value = convert_to_boolean(value)
elif key == 'theme': elif key == 'theme':
if value not in THEMES: if value not in THEMES:
@ -98,23 +132,23 @@ class Connection(SqlConnection):
return data.get(key) return data.get(key)
def get_inbox(self, value: str) -> Row: def get_inbox(self, value: str) -> schema.Instance | None:
with self.run('get-inbox', {'value': value}) as cur: with self.run('get-inbox', {'value': value}) as cur:
return cur.one() # type: ignore return cur.one(schema.Instance)
def get_inboxes(self) -> Sequence[Row]: def get_inboxes(self) -> Iterator[schema.Instance]:
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur: return self.execute("SELECT * FROM inboxes WHERE accepted = true").all(schema.Instance)
return tuple(cur.all())
def put_inbox(self, # todo: check if software is different than stored row
def put_inbox(self, # noqa: E301
domain: str, domain: str,
inbox: str | None = None, inbox: str | None = None,
actor: str | None = None, actor: str | None = None,
followid: str | None = None, followid: str | None = None,
software: str | None = None, software: str | None = None,
accepted: bool = True) -> Row: accepted: bool = True) -> schema.Instance:
params: dict[str, Any] = { params: dict[str, Any] = {
'inbox': inbox, 'inbox': inbox,
@ -124,7 +158,7 @@ class Connection(SqlConnection):
'accepted': accepted 'accepted': accepted
} }
if not self.get_inbox(domain): if self.get_inbox(domain) is None:
if not inbox: if not inbox:
raise ValueError("Missing inbox") raise ValueError("Missing inbox")
@ -132,14 +166,20 @@ class Connection(SqlConnection):
params['created'] = datetime.now(tz = timezone.utc) params['created'] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur: with self.run('put-inbox', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert instance: {domain}")
return row
for key, value in tuple(params.items()): for key, value in tuple(params.items()):
if value is None: if value is None:
del params[key] del params[key]
with self.update('inboxes', params, domain = domain) as cur: with self.update('inboxes', params, domain = domain) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to update instance: {domain}")
return row
def del_inbox(self, value: str) -> bool: def del_inbox(self, value: str) -> bool:
@ -150,24 +190,23 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_request(self, domain: str) -> Row: def get_request(self, domain: str) -> schema.Instance | None:
with self.run('get-request', {'domain': domain}) as cur: with self.run('get-request', {'domain': domain}) as cur:
if not (row := cur.one()): return cur.one(schema.Instance)
def get_requests(self) -> Iterator[schema.Instance]:
return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance)
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
if (instance := self.get_request(domain)) is None:
raise KeyError(domain) raise KeyError(domain)
return row
def get_requests(self) -> Sequence[Row]:
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
return tuple(cur.all())
def put_request_response(self, domain: str, accepted: bool) -> Row:
instance = self.get_request(domain)
if not accepted: if not accepted:
self.del_inbox(domain) if not self.del_inbox(domain):
raise RuntimeError(f'Failed to delete request: {domain}')
return instance return instance
params = { params = {
@ -176,21 +215,28 @@ class Connection(SqlConnection):
} }
with self.run('put-inbox-accept', params) as cur: with self.run('put-inbox-accept', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert response for domain: {domain}")
return row
def get_user(self, value: str) -> Row: def get_user(self, value: str) -> schema.User | None:
with self.run('get-user', {'value': value}) as cur: with self.run('get-user', {'value': value}) as cur:
return cur.one() # type: ignore return cur.one(schema.User)
def get_user_by_token(self, code: str) -> Row: def get_user_by_token(self, token: str) -> schema.User | None:
with self.run('get-user-by-token', {'code': code}) as cur: with self.run('get-user-by-token', {'token': token}) as cur:
return cur.one() # type: ignore return cur.one(schema.User)
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row: def get_users(self) -> Iterator[schema.User]:
if self.get_user(username): return self.execute("SELECT * FROM users").all(schema.User)
def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User:
if self.get_user(username) is not None:
data: dict[str, str | datetime | None] = {} data: dict[str, str | datetime | None] = {}
if password: if password:
@ -203,7 +249,10 @@ class Connection(SqlConnection):
stmt.set_where("username", username) stmt.set_where("username", username)
with self.query(stmt) as cur: with self.query(stmt) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to update user: {username}")
return row
if password is None: if password is None:
raise ValueError('Password cannot be empty') raise ValueError('Password cannot be empty')
@ -216,52 +265,149 @@ class Connection(SqlConnection):
} }
with self.run('put-user', data) as cur: with self.run('put-user', data) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to insert user: {username}")
return row
def del_user(self, username: str) -> None: def del_user(self, username: str) -> None:
user = self.get_user(username) if (user := self.get_user(username)) is None:
raise KeyError(username)
with self.run('del-user', {'value': user['username']}): with self.run('del-token-user', {'username': user.username}):
pass pass
with self.run('del-token-user', {'username': user['username']}): with self.run('del-user', {'username': user.username}):
pass pass
def get_token(self, code: str) -> Row: def get_app(self,
with self.run('get-token', {'code': code}) as cur: client_id: str,
return cur.one() # type: ignore client_secret: str,
token: str | None = None) -> schema.App | None:
params = {
def put_token(self, username: str) -> Row: 'id': client_id,
data = { 'secret': client_secret
'code': uuid4().hex,
'user': username,
'created': datetime.now(tz = timezone.utc)
} }
with self.run('put-token', data) as cur: if token is not None:
return cur.one() # type: ignore command = 'get-app-with-token'
params['token'] = token
else:
command = 'get-app'
with self.run(command, params) as cur:
return cur.one(schema.App)
def del_token(self, code: str) -> None: def get_app_by_token(self, token: str) -> schema.App | None:
with self.run('del-token', {'code': code}): with self.run('get-app-by-token', {'token': token}) as cur:
pass return cur.one(schema.App)
def get_domain_ban(self, domain: str) -> Row: def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
params = {
'name': name,
'redirect_uri': redirect_uri,
'website': website,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'created': Date.new_utc(),
'accessed': Date.new_utc()
}
with self.insert('apps', params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to insert app: {name}')
return row
def put_app_login(self, user: schema.User) -> schema.App:
params = {
'name': 'Web',
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob',
'website': None,
'user': user.username,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'auth_code': None,
'token': secrets.token_hex(20),
'created': Date.new_utc(),
'accessed': Date.new_utc()
}
with self.insert('apps', params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to create app for "{user.username}"')
return row
def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App:
data: dict[str, str | None] = {}
if user is not None:
data['user'] = user.username
if set_auth:
data['auth_code'] = secrets.token_hex(20)
else:
data['token'] = secrets.token_hex(20)
data['auth_code'] = None
params = {
'client_id': app.client_id,
'client_secret': app.client_secret
}
with self.update('apps', data, **params) as cur: # type: ignore[arg-type]
if (row := cur.one(schema.App)) is None:
raise RuntimeError('Failed to update row')
return row
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
params = {
'id': client_id,
'secret': client_secret
}
if token is not None:
command = 'del-app-with-token'
params['token'] = token
else:
command = 'del-app'
with self.run(command, params) as cur:
if cur.row_count > 1:
raise RuntimeError('More than 1 row was deleted')
return cur.row_count == 0
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).netloc domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur: with self.run('get-domain-ban', {'domain': domain}) as cur:
return cur.one() # type: ignore return cur.one(schema.DomainBan)
def get_domain_bans(self) -> Iterator[schema.DomainBan]:
return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan)
def put_domain_ban(self, def put_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> Row: note: str | None = None) -> schema.DomainBan:
params = { params = {
'domain': domain, 'domain': domain,
@ -271,13 +417,16 @@ class Connection(SqlConnection):
} }
with self.run('put-domain-ban', params) as cur: with self.run('put-domain-ban', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to insert domain ban: {domain}")
return row
def update_domain_ban(self, def update_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> Row: note: str | None = None) -> schema.DomainBan:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -297,7 +446,10 @@ class Connection(SqlConnection):
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return self.get_domain_ban(domain) if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to update domain ban: {domain}")
return row
def del_domain_ban(self, domain: str) -> bool: def del_domain_ban(self, domain: str) -> bool:
@ -308,15 +460,19 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_software_ban(self, name: str) -> Row: def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
with self.run('get-software-ban', {'name': name}) as cur: with self.run('get-software-ban', {'name': name}) as cur:
return cur.one() # type: ignore return cur.one(schema.SoftwareBan)
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
def put_software_ban(self, def put_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> Row: note: str | None = None) -> schema.SoftwareBan:
params = { params = {
'name': name, 'name': name,
@ -326,13 +482,16 @@ class Connection(SqlConnection):
} }
with self.run('put-software-ban', params) as cur: with self.run('put-software-ban', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to insert software ban: {name}')
return row
def update_software_ban(self, def update_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> Row: note: str | None = None) -> schema.SoftwareBan:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -352,7 +511,10 @@ class Connection(SqlConnection):
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return self.get_software_ban(name) if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to update software ban: {name}')
return row
def del_software_ban(self, name: str) -> bool: def del_software_ban(self, name: str) -> bool:
@ -363,19 +525,26 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> Row: def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
with self.run('get-domain-whitelist', {'domain': domain}) as cur: with self.run('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one() # type: ignore return cur.one()
def put_domain_whitelist(self, domain: str) -> Row: def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]:
return self.execute("SELECT * FROM whitelist").all(schema.Whitelist)
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
params = { params = {
'domain': domain, 'domain': domain,
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.run('put-domain-whitelist', params) as cur: with self.run('put-domain-whitelist', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Whitelist)) is None:
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
return row
def del_domain_whitelist(self, domain: str) -> bool: def del_domain_whitelist(self, domain: str) -> bool:

View file

@ -1,61 +1,133 @@
from bsql import Column, Table, Tables from __future__ import annotations
from blib import Date
from bsql import Column, Row, Tables
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy
from datetime import timezone
from typing import TYPE_CHECKING, Any
from .config import ConfigData from .config import ConfigData
from .connection import Connection
if TYPE_CHECKING:
from .connection import Connection
VERSIONS: dict[int, Callable[[Connection], None]] = {} VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES: Tables = Tables( TABLES = Tables()
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False), def deserialize_timestamp(value: Any) -> Date:
Column('value', 'text'), try:
Column('type', 'text', default = 'str') date = Date.parse(value)
),
Table( except ValueError:
'inboxes', date = Date.fromisoformat(value)
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True), if date.tzinfo is None:
Column('inbox', 'text', unique = True, nullable = False), date = date.replace(tzinfo = timezone.utc)
Column('followid', 'text'),
Column('software', 'text'), return date
Column('accepted', 'boolean'),
Column('created', 'timestamp', nullable = False)
), @TABLES.add_row
Table( class Config(Row):
'whitelist', key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
Column('domain', 'text', primary_key = True, unique = True, nullable = True), value: Column[str] = Column('value', 'text')
Column('created', 'timestamp') type: Column[str] = Column('type', 'text', default = 'str')
),
Table(
'domain_bans', @TABLES.add_row
Column('domain', 'text', primary_key = True, unique = True, nullable = True), class Instance(Row):
Column('reason', 'text'), table_name: str = 'inboxes'
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
), domain: Column[str] = Column(
Table( 'domain', 'text', primary_key = True, unique = True, nullable = False)
'software_bans', actor: Column[str] = Column('actor', 'text', unique = True)
Column('name', 'text', primary_key = True, unique = True, nullable = True), inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
Column('reason', 'text'), followid: Column[str] = Column('followid', 'text')
Column('note', 'text'), software: Column[str] = Column('software', 'text')
Column('created', 'timestamp', nullable = False) accepted: Column[Date] = Column('accepted', 'boolean')
), created: Column[Date] = Column(
Table( 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
'users',
Column('username', 'text', primary_key = True, unique = True, nullable = False),
Column('hash', 'text', nullable = False), @TABLES.add_row
Column('handle', 'text'), class Whitelist(Row):
Column('created', 'timestamp', nullable = False) domain: Column[str] = Column(
), 'domain', 'text', primary_key = True, unique = True, nullable = True)
Table( created: Column[Date] = Column(
'tokens', 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
Column('code', 'text', primary_key = True, unique = True, nullable = False),
Column('user', 'text', nullable = False),
Column('created', 'timestmap', nullable = False) @TABLES.add_row
) class DomainBan(Row):
) table_name: str = 'domain_bans'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class SoftwareBan(Row):
table_name: str = 'software_bans'
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class User(Row):
table_name: str = 'users'
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class App(Row):
table_name: str = 'apps'
client_id: Column[str] = Column(
'client_id', 'text', primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
name: Column[str] = Column('name', 'text')
website: Column[str] = Column('website', 'text')
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
data = deepcopy(self)
data.pop('user')
data.pop('auth_code')
if not include_token:
data.pop('token')
return data
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
@ -76,5 +148,11 @@ def migrate_20240206(conn: Connection) -> None:
@migration @migration
def migrate_20240310(conn: Connection) -> None: def migrate_20240310(conn: Connection) -> None:
conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN") conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close()
conn.execute("UPDATE inboxes SET accepted = 1") conn.execute('UPDATE "inboxes" SET "accepted" = true').close()
@migration
def migrate_20240625(conn: Connection) -> None:
conn.create_tables()
conn.execute('DROP TABLE "tokens"').close()

2
relay/errors.py Normal file
View file

@ -0,0 +1,2 @@
class EmptyBodyError(Exception):
pass

View file

@ -1,5 +1,5 @@
-macro menu_item(name, path) -macro menu_item(name, path)
-if view.request.path == path or (path != "/" and view.request.path.startswith(path)) -if request.path == path or (path != "/" and request.path.startswith(path))
%a.button(href="{{path}}" active="true") -> =name %a.button(href="{{path}}" active="true") -> =name
-else -else
@ -11,11 +11,11 @@
%title << {{config.name}}: {{page}} %title << {{config.name}}: {{page}}
%meta(charset="UTF-8") %meta(charset="UTF-8")
%meta(name="viewport" content="width=device-width, initial-scale=1") %meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme") %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css" nonce="{{view.request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}")
%link(rel="manifest" href="/manifest.json") %link(rel="manifest" href="/manifest.json?{{version}}")
%script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer) %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer)
-block head -block head
%body %body
@ -26,7 +26,7 @@
{{menu_item("Home", "/")}} {{menu_item("Home", "/")}}
-if view.request["user"] -if request["user"]
{{menu_item("Instances", "/admin/instances")}} {{menu_item("Instances", "/admin/instances")}}
{{menu_item("Whitelist", "/admin/whitelist")}} {{menu_item("Whitelist", "/admin/whitelist")}}
{{menu_item("Domain Bans", "/admin/domain_bans")}} {{menu_item("Domain Bans", "/admin/domain_bans")}}
@ -61,11 +61,11 @@
#footer.section #footer.section
.col1 .col1
-if not view.request["user"] -if not request["user"]
%a(href="/login") << Login %a(href="/login") << Login
-else -else
=view.request["user"]["username"] =request["user"]["username"]
( (
%a(href="/logout") << Logout %a(href="/logout") << Logout
) )

View file

@ -1,29 +1,32 @@
-extends "base.haml" -extends "base.haml"
-set page="Config" -set page="Config"
-block head
%script(type="application/javascript" src="/static/config.js" nonce="{{view.request['hash']}}" defer)
-import "functions.haml" as func -import "functions.haml" as func
-block content -block content
%fieldset.section %fieldset.section
%legend << Config %legend << Config
.grid-2col .grid-2col
%label(for="name") << Name %label(for="name") << Name
%i(class="bi bi-question-circle-fill" title="{{desc.name}}")
%input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}") %input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}")
%label(for="note") << Description %label(for="note") << Description
%i(class="bi bi-question-circle-fill" title="{{desc.note}}")
%textarea(id="note" value="{{config.note or ''}}") << {{config.note}} %textarea(id="note" value="{{config.note or ''}}") << {{config.note}}
%label(for="theme") << Color Theme %label(for="theme") << Color Theme
%i(class="bi bi-question-circle-fill" title="{{desc.theme}}")
=func.new_select("theme", config.theme, themes) =func.new_select("theme", config.theme, themes)
%label(for="log-level") << Log Level %label(for="log-level") << Log Level
%i(class="bi bi-question-circle-fill" title="{{desc.log_level}}")
=func.new_select("log-level", config.log_level.name, levels) =func.new_select("log-level", config.log_level.name, levels)
%label(for="whitelist-enabled") << Whitelist %label(for="whitelist-enabled") << Whitelist
%i(class="bi bi-question-circle-fill" title="{{desc.whitelist_enabled}}")
=func.new_checkbox("whitelist-enabled", config.whitelist_enabled) =func.new_checkbox("whitelist-enabled", config.whitelist_enabled)
%label(for="approval-required") << Approval Required %label(for="approval-required") << Approval Required
%i(class="bi bi-question-circle-fill" title="{{desc.approval_required}}")
=func.new_checkbox("approval-required", config.approval_required) =func.new_checkbox("approval-required", config.approval_required)

View file

@ -1,9 +1,6 @@
-extends "base.haml" -extends "base.haml"
-set page="Domain Bans" -set page="Domain Bans"
-block head
%script(type="application/javascript" src="/static/domain_ban.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Ban Domain %summary << Ban Domain

View file

@ -1,9 +1,6 @@
-extends "base.haml" -extends "base.haml"
-set page="Instances" -set page="Instances"
-block head
%script(type="application/javascript" src="/static/instance.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Add Instance %summary << Add Instance

View file

@ -1,9 +1,6 @@
-extends "base.haml" -extends "base.haml"
-set page="Software Bans" -set page="Software Bans"
-block head
%script(type="application/javascript" src="/static/software_ban.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Ban Software %summary << Ban Software

View file

@ -1,9 +1,6 @@
-extends "base.haml" -extends "base.haml"
-set page="Users" -set page="Users"
-block head
%script(type="application/javascript" src="/static/user.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Add User %summary << Add User

View file

@ -1,9 +1,6 @@
-extends "base.haml" -extends "base.haml"
-set page="Whitelist" -set page="Whitelist"
-block head
%script(type="application/javascript" src="/static/whitelist.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Add Domain %summary << Add Domain

View file

@ -0,0 +1,31 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization
-if application.website
#title << Application "<a href="{{application.website}}" target="_new">{{application.name}}</a>" wants full API access
-else
#title << Application "{{application.name}}" wants full API access
#buttons
.spacer
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="true")
%input.button(type="submit" value="Allow")
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="false")
%input.button(type="submit" value="Deny")
.spacer

View file

@ -0,0 +1,18 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization Code
-if application.website
%p
Copy the following code into
%a(href="{{application.website}}" target="_main") -> %code -> =application.name
-else
%p
Copy the following code info
%code -> =application.name
%pre#code -> =application.auth_code

View file

@ -0,0 +1,7 @@
-extends "base.haml"
-set page="Error"
-block content
.section.error
.title << HTTP Error {{e.status}}
.body -> =e.message

View file

@ -1,5 +1,6 @@
-extends "base.haml" -extends "base.haml"
-set page = "Home" -set page = "Home"
-block content -block content
-if config.note -if config.note
.section .section
@ -14,9 +15,7 @@
%a(href="https://{{domain}}/actor") << https://{{domain}}/actor</a> %a(href="https://{{domain}}/actor") << https://{{domain}}/actor</a>
-if config.approval_required -if config.approval_required
%fieldset.section.message %div.section.message
%legend << Require Approval
Follow requests require approval. You will need to wait for an admin to accept or deny Follow requests require approval. You will need to wait for an admin to accept or deny
your request. your request.

View file

@ -1,9 +1,6 @@
-extends "base.haml" -extends "base.haml"
-set page="Login" -set page="Login"
-block head
%script(type="application/javascript" src="/static/login.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%fieldset.section %fieldset.section
%legend << Login %legend << Login
@ -15,4 +12,6 @@
%label(for="password") << Password %label(for="password") << Password
%input(id="password" name="password" placeholder="Password" type="password") %input(id="password" name="password" placeholder="Password" type="password")
%input#redir(type="hidden" name="redir" value="{{redir}}")
%input.submit(type="button" value="Login") %input.submit(type="button" value="Login")

View file

@ -1,132 +0,0 @@
// toast notifications
const notifications = document.querySelector("#notifications")
function remove_toast(toast) {
toast.classList.add("hide");
if (toast.timeoutId) {
clearTimeout(toast.timeoutId);
}
setTimeout(() => toast.remove(), 300);
}
function toast(text, type="error", timeout=5) {
const toast = document.createElement("li");
toast.className = `section ${type}`
toast.innerHTML = `<span class=".text">${text}</span><a href="#">&#10006;</span>`
toast.querySelector("a").addEventListener("click", async (event) => {
event.preventDefault();
await remove_toast(toast);
});
notifications.appendChild(toast);
toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000);
}
// menu
const body = document.getElementById("container")
const menu = document.getElementById("menu");
const menu_open = document.querySelector("#menu-open i");
const menu_close = document.getElementById("menu-close");
function toggle_menu() {
let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
menu.attributes.visible.nodeValue = new_value;
}
menu_open.addEventListener("click", toggle_menu);
menu_close.addEventListener("click", toggle_menu);
body.addEventListener("click", (event) => {
if (event.target === menu_open) {
return;
}
menu.attributes.visible.nodeValue = "false";
});
for (const elem of document.querySelectorAll("#menu-open div")) {
elem.addEventListener("click", toggle_menu);
}
// misc
function get_date_string(date) {
var year = date.getUTCFullYear().toString();
var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
var day = date.getUTCDate().toString().padStart(2, "0");
return `${year}-${month}-${day}`;
}
function append_table_row(table, row_name, row) {
var table_row = table.insertRow(-1);
table_row.id = row_name;
index = 0;
for (var prop in row) {
if (Object.prototype.hasOwnProperty.call(row, prop)) {
var cell = table_row.insertCell(index);
cell.className = prop;
cell.innerHTML = row[prop];
index += 1;
}
}
return table_row;
}
async function request(method, path, body = null) {
var headers = {
"Accept": "application/json"
}
if (body !== null) {
headers["Content-Type"] = "application/json"
body = JSON.stringify(body)
}
const response = await fetch("/api/" + path, {
method: method,
mode: "cors",
cache: "no-store",
redirect: "follow",
body: body,
headers: headers
});
const message = await response.json();
if (Object.hasOwn(message, "error")) {
throw new Error(message.error);
}
if (Array.isArray(message)) {
message.forEach((msg) => {
if (Object.hasOwn(msg, "created")) {
msg.created = new Date(msg.created);
}
});
} else {
if (Object.hasOwn(message, "created")) {
console.log(message.created)
message.created = new Date(message.created);
}
}
return message;
}

View file

@ -1,40 +0,0 @@
const elems = [
document.querySelector("#name"),
document.querySelector("#note"),
document.querySelector("#theme"),
document.querySelector("#log-level"),
document.querySelector("#whitelist-enabled"),
document.querySelector("#approval-required")
]
async function handle_config_change(event) {
params = {
key: event.target.id,
value: event.target.type === "checkbox" ? event.target.checked : event.target.value
}
try {
await request("POST", "v1/config", params);
} catch (error) {
toast(error);
return;
}
if (params.key === "name") {
document.querySelector("#header .title").innerHTML = params.value;
document.querySelector("title").innerHTML = params.value;
}
if (params.key === "theme") {
document.querySelector("link.theme").href = `/theme/${params.value}.css`;
}
toast("Updated config", "message");
}
for (const elem of elems) {
elem.addEventListener("change", handle_config_change);
}

View file

@ -1,123 +0,0 @@
function create_ban_object(domain, reason, note) {
var text = '<details>\n';
text += `<summary>${domain}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${domain}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${domain}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${domain}-note" class="note">Note</label>\n`;
text += `<textarea id="${domain}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var table = document.querySelector("table");
var elems = {
domain: document.getElementById("new-domain"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
domain: elems.domain.value.trim(),
reason: elems.reason.value.trim(),
note: elems.note.value.trim()
}
if (values.domain === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/domain_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("table"), ban.domain, {
domain: create_ban_object(ban.domain, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban domain">&#10006;</a>`
});
add_row_listeners(row);
elems.domain.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned domain", "message");
}
async function update_ban(domain) {
var row = document.getElementById(domain);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"domain": domain,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/domain_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated baned domain", "message");
}
async function unban(domain) {
try {
await request("DELETE", "v1/domain_ban", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Unbanned domain", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}

View file

@ -0,0 +1,864 @@
// toast notifications
const notifications = document.querySelector("#notifications")
function remove_toast(toast) {
toast.classList.add("hide");
if (toast.timeoutId) {
clearTimeout(toast.timeoutId);
}
setTimeout(() => toast.remove(), 300);
}
function toast(text, type="error", timeout=5) {
const toast = document.createElement("li");
toast.className = `section ${type}`
toast.innerHTML = `<span class=".text">${text}</span><a href="#">&#10006;</span>`
toast.querySelector("a").addEventListener("click", async (event) => {
event.preventDefault();
await remove_toast(toast);
});
notifications.appendChild(toast);
toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000);
}
// menu
const body = document.getElementById("container")
const menu = document.getElementById("menu");
const menu_open = document.querySelector("#menu-open i");
const menu_close = document.getElementById("menu-close");
function toggle_menu() {
let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
menu.attributes.visible.nodeValue = new_value;
}
menu_open.addEventListener("click", toggle_menu);
menu_close.addEventListener("click", toggle_menu);
body.addEventListener("click", (event) => {
if (event.target === menu_open) {
return;
}
menu.attributes.visible.nodeValue = "false";
});
for (const elem of document.querySelectorAll("#menu-open div")) {
elem.addEventListener("click", toggle_menu);
}
// misc
function get_date_string(date) {
var year = date.getUTCFullYear().toString();
var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
var day = date.getUTCDate().toString().padStart(2, "0");
return `${year}-${month}-${day}`;
}
function append_table_row(table, row_name, row) {
var table_row = table.insertRow(-1);
table_row.id = row_name;
index = 0;
for (var prop in row) {
if (Object.prototype.hasOwnProperty.call(row, prop)) {
var cell = table_row.insertCell(index);
cell.className = prop;
cell.innerHTML = row[prop];
index += 1;
}
}
return table_row;
}
async function request(method, path, body = null) {
var headers = {
"Accept": "application/json"
}
if (body !== null) {
headers["Content-Type"] = "application/json"
body = JSON.stringify(body)
}
const response = await fetch("/api/" + path, {
method: method,
mode: "cors",
cache: "no-store",
redirect: "follow",
body: body,
headers: headers
});
const message = await response.json();
if (Object.hasOwn(message, "error")) {
throw new Error(message.error);
}
if (Array.isArray(message)) {
message.forEach((msg) => {
if (Object.hasOwn(msg, "created")) {
msg.created = new Date(msg.created);
}
});
} else {
if (Object.hasOwn(message, "created")) {
message.created = new Date(message.created);
}
}
return message;
}
// page functions
function page_config() {
const elems = [
document.querySelector("#name"),
document.querySelector("#note"),
document.querySelector("#theme"),
document.querySelector("#log-level"),
document.querySelector("#whitelist-enabled"),
document.querySelector("#approval-required")
]
async function handle_config_change(event) {
params = {
key: event.target.id,
value: event.target.type === "checkbox" ? event.target.checked : event.target.value
}
try {
await request("POST", "v1/config", params);
} catch (error) {
toast(error);
return;
}
if (params.key === "name") {
document.querySelector("#header .title").innerHTML = params.value;
document.querySelector("title").innerHTML = params.value;
}
if (params.key === "theme") {
document.querySelector("link.theme").href = `/theme/${params.value}.css`;
}
toast("Updated config", "message");
}
document.querySelector("#name").addEventListener("keydown", async (event) => {
if (event.which === 13) {
await handle_config_change(event);
}
});
document.querySelector("#note").addEventListener("keydown", async (event) => {
if (event.which === 13 && event.ctrlKey) {
await handle_config_change(event);
}
});
for (const elem of elems) {
elem.addEventListener("change", handle_config_change);
}
}
function page_domain_ban() {
function create_ban_object(domain, reason, note) {
var text = '<details>\n';
text += `<summary>${domain}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${domain}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${domain}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${domain}-note" class="note">Note</label>\n`;
text += `<textarea id="${domain}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var table = document.querySelector("table");
var elems = {
domain: document.getElementById("new-domain"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
domain: elems.domain.value.trim(),
reason: elems.reason.value.trim(),
note: elems.note.value.trim()
}
if (values.domain === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/domain_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("table"), ban.domain, {
domain: create_ban_object(ban.domain, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban domain">&#10006;</a>`
});
add_row_listeners(row);
elems.domain.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned domain", "message");
}
async function update_ban(domain) {
var row = document.getElementById(domain);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"domain": domain,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/domain_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated baned domain", "message");
}
async function unban(domain) {
try {
await request("DELETE", "v1/domain_ban", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Unbanned domain", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var elem of document.querySelectorAll("#add-item input")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13) {
await ban();
}
});
}
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}
}
function page_instance() {
function add_instance_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_instance(row.id);
});
}
function add_request_listeners(row) {
row.querySelector(".approve a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, true);
});
row.querySelector(".deny a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, false);
});
}
async function add_instance() {
var elems = {
actor: document.getElementById("new-actor"),
inbox: document.getElementById("new-inbox"),
followid: document.getElementById("new-followid"),
software: document.getElementById("new-software")
}
var values = {
actor: elems.actor.value.trim(),
inbox: elems.inbox.value.trim(),
followid: elems.followid.value.trim(),
software: elems.software.value.trim()
}
if (values.actor === "") {
toast("Actor is required");
return;
}
try {
var instance = await request("POST", "v1/instance", values);
} catch (err) {
toast(err);
return
}
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
elems.actor.value = null;
elems.inbox.value = null;
elems.followid.value = null;
elems.software.value = null;
document.querySelector("details.section").open = false;
toast("Added instance", "message");
}
async function del_instance(domain) {
try {
await request("DELETE", "v1/instance", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
}
async function req_response(domain, accept) {
params = {
"domain": domain,
"accept": accept
}
try {
await request("POST", "v1/request", params);
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
if (document.getElementById("requests").rows.length < 2) {
document.querySelector("fieldset.requests").remove()
}
if (!accept) {
toast("Denied instance request", "message");
return;
}
instances = await request("GET", `v1/instance`, null);
instances.forEach((instance) => {
if (instance.domain === domain) {
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
}
});
toast("Accepted instance request", "message");
}
document.querySelector("#add-instance").addEventListener("click", async (event) => {
await add_instance();
})
for (var elem of document.querySelectorAll("#add-item input")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13) {
await add_instance();
}
});
}
for (var row of document.querySelector("#instances").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_instance_listeners(row);
}
if (document.querySelector("#requests")) {
for (var row of document.querySelector("#requests").rows) {
if (!row.querySelector(".approve a")) {
continue;
}
add_request_listeners(row);
}
}
}
function page_login() {
const fields = {
username: document.querySelector("#username"),
password: document.querySelector("#password"),
redir: document.querySelector("#redir")
};
async function login(event) {
const values = {
username: fields.username.value.trim(),
password: fields.password.value.trim(),
redir: fields.redir.value.trim()
}
if (values.username === "" | values.password === "") {
toast("Username and/or password field is blank");
return;
}
try {
await request("POST", "v1/login", values);
} catch (error) {
toast(error);
return;
}
document.location = values.redir;
}
document.querySelector("#username").addEventListener("keydown", async (event) => {
if (event.which === 13) {
fields.password.focus();
fields.password.select();
}
});
document.querySelector("#password").addEventListener("keydown", async (event) => {
if (event.which === 13) {
await login(event);
}
});
document.querySelector(".submit").addEventListener("click", login);
}
function page_software_ban() {
function create_ban_object(name, reason, note) {
var text = '<details>\n';
text += `<summary>${name}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${name}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${name}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${name}-note" class="note">Note</label>\n`;
text += `<textarea id="${name}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var elems = {
name: document.getElementById("new-name"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
name: elems.name.value.trim(),
reason: elems.reason.value,
note: elems.note.value
}
if (values.name === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/software_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.getElementById("bans"), ban.name, {
name: create_ban_object(ban.name, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban software">&#10006;</a>`
});
add_row_listeners(row);
elems.name.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned software", "message");
}
async function update_ban(name) {
var row = document.getElementById(name);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"name": name,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/software_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated software ban", "message");
}
async function unban(name) {
try {
await request("DELETE", "v1/software_ban", {"name": name});
} catch (error) {
toast(error);
return;
}
document.getElementById(name).remove();
toast("Unbanned software", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var elem of document.querySelectorAll("#add-item input")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13) {
await ban();
}
});
}
for (var elem of document.querySelectorAll("#add-item textarea")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13 && event.ctrlKey) {
await ban();
}
});
}
for (var row of document.querySelector("#bans").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}
}
function page_user() {
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_user(row.id);
});
}
async function add_user() {
var elems = {
username: document.getElementById("new-username"),
password: document.getElementById("new-password"),
password2: document.getElementById("new-password2"),
handle: document.getElementById("new-handle")
}
var values = {
username: elems.username.value.trim(),
password: elems.password.value.trim(),
password2: elems.password2.value.trim(),
handle: elems.handle.value.trim()
}
if (values.username === "" | values.password === "" | values.password2 === "") {
toast("Username, password, and password2 are required");
return;
}
if (values.password !== values.password2) {
toast("Passwords do not match");
return;
}
try {
var user = await request("POST", "v1/user", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
domain: user.username,
handle: user.handle ? self.handle : "n/a",
date: get_date_string(user.created),
remove: `<a href="#" title="Delete User">&#10006;</a>`
});
add_row_listeners(row);
elems.username.value = null;
elems.password.value = null;
elems.password2.value = null;
elems.handle.value = null;
document.querySelector("details.section").open = false;
toast("Created user", "message");
}
async function del_user(username) {
try {
await request("DELETE", "v1/user", {"username": username});
} catch (error) {
toast(error);
return;
}
document.getElementById(username).remove();
toast("Deleted user", "message");
}
document.querySelector("#new-user").addEventListener("click", async (event) => {
await add_user();
});
for (var elem of document.querySelectorAll("#add-item input")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13) {
await add_user();
}
});
}
for (var row of document.querySelector("#users").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}
}
function page_whitelist() {
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_whitelist(row.id);
});
}
async function add_whitelist() {
var domain_elem = document.getElementById("new-domain");
var domain = domain_elem.value.trim();
if (domain === "") {
toast("Domain is required");
return;
}
try {
var item = await request("POST", "v1/whitelist", {"domain": domain});
} catch (err) {
toast(err);
return;
}
var row = append_table_row(document.getElementById("whitelist"), item.domain, {
domain: item.domain,
date: get_date_string(item.created),
remove: `<a href="#" title="Remove whitelisted domain">&#10006;</a>`
});
add_row_listeners(row);
domain_elem.value = null;
document.querySelector("details.section").open = false;
toast("Added domain", "message");
}
async function del_whitelist(domain) {
try {
await request("DELETE", "v1/whitelist", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Removed domain", "message");
}
document.querySelector("#new-item").addEventListener("click", async (event) => {
await add_whitelist();
});
document.querySelector("#add-item").addEventListener("keydown", async (event) => {
if (event.which === 13) {
await add_whitelist();
}
});
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}
}
if (location.pathname.startsWith("/admin/config")) {
page_config();
} else if (location.pathname.startsWith("/admin/domain_bans")) {
page_domain_ban();
} else if (location.pathname.startsWith("/admin/instances")) {
page_instance();
} else if (location.pathname.startsWith("/admin/software_bans")) {
page_software_ban();
} else if (location.pathname.startsWith("/admin/users")) {
page_user();
} else if (location.pathname.startsWith("/admin/whitelist")) {
page_whitelist();
} else if (location.pathname.startsWith("/login")) {
page_login();
}

View file

@ -1,145 +0,0 @@
function add_instance_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_instance(row.id);
});
}
function add_request_listeners(row) {
row.querySelector(".approve a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, true);
});
row.querySelector(".deny a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, false);
});
}
async function add_instance() {
var elems = {
actor: document.getElementById("new-actor"),
inbox: document.getElementById("new-inbox"),
followid: document.getElementById("new-followid"),
software: document.getElementById("new-software")
}
var values = {
actor: elems.actor.value.trim(),
inbox: elems.inbox.value.trim(),
followid: elems.followid.value.trim(),
software: elems.software.value.trim()
}
if (values.actor === "") {
toast("Actor is required");
return;
}
try {
var instance = await request("POST", "v1/instance", values);
} catch (err) {
toast(err);
return
}
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
elems.actor.value = null;
elems.inbox.value = null;
elems.followid.value = null;
elems.software.value = null;
document.querySelector("details.section").open = false;
toast("Added instance", "message");
}
async function del_instance(domain) {
try {
await request("DELETE", "v1/instance", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
}
async function req_response(domain, accept) {
params = {
"domain": domain,
"accept": accept
}
try {
await request("POST", "v1/request", params);
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
if (document.getElementById("requests").rows.length < 2) {
document.querySelector("fieldset.requests").remove()
}
if (!accept) {
toast("Denied instance request", "message");
return;
}
instances = await request("GET", `v1/instance`, null);
instances.forEach((instance) => {
if (instance.domain === domain) {
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
}
});
toast("Accepted instance request", "message");
}
document.querySelector("#add-instance").addEventListener("click", async (event) => {
await add_instance();
})
for (var row of document.querySelector("#instances").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_instance_listeners(row);
}
if (document.querySelector("#requests")) {
for (var row of document.querySelector("#requests").rows) {
if (!row.querySelector(".approve a")) {
continue;
}
add_request_listeners(row);
}
}

View file

@ -1,29 +0,0 @@
async function login(event) {
fields = {
username: document.querySelector("#username"),
password: document.querySelector("#password")
}
values = {
username: fields.username.value.trim(),
password: fields.password.value.trim()
}
if (values.username === "" | values.password === "") {
toast("Username and/or password field is blank");
return;
}
try {
await request("POST", "v1/token", values);
} catch (error) {
toast(error);
return;
}
document.location = "/";
}
document.querySelector(".submit").addEventListener("click", login);

View file

@ -1,122 +0,0 @@
function create_ban_object(name, reason, note) {
var text = '<details>\n';
text += `<summary>${name}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${name}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${name}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${name}-note" class="note">Note</label>\n`;
text += `<textarea id="${name}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var elems = {
name: document.getElementById("new-name"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
name: elems.name.value.trim(),
reason: elems.reason.value,
note: elems.note.value
}
if (values.name === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/software_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.getElementById("bans"), ban.name, {
name: create_ban_object(ban.name, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban software">&#10006;</a>`
});
add_row_listeners(row);
elems.name.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned software", "message");
}
async function update_ban(name) {
var row = document.getElementById(name);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"name": name,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/software_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated software ban", "message");
}
async function unban(name) {
try {
await request("DELETE", "v1/software_ban", {"name": name});
} catch (error) {
toast(error);
return;
}
document.getElementById(name).remove();
toast("Unbanned software", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var row of document.querySelector("#bans").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}

View file

@ -12,7 +12,7 @@ body {
color: var(--text); color: var(--text);
background-color: #222; background-color: #222;
margin: var(--spacing); margin: var(--spacing);
font-family: sans serif; font-family: sans-serif;
} }
details *:nth-child(2) { details *:nth-child(2) {
@ -88,6 +88,7 @@ tbody tr:last-child td:last-child {
table td { table td {
padding: 5px; padding: 5px;
white-space: nowrap;
} }
table thead td { table thead td {
@ -282,8 +283,11 @@ textarea {
width: 100%; width: 100%;
} }
.data-table .date { .data-table td:not(:first-child) {
width: max-content; width: max-content;
}
.data-table .date {
text-align: right; text-align: right;
} }
@ -297,13 +301,13 @@ textarea {
border: 1px solid var(--error-border) !important; border: 1px solid var(--error-border) !important;
} }
/* create .grid base class and .2col and 3col classes */
.grid-2col { .grid-2col {
display: grid; display: grid;
grid-template-columns: max-content auto; grid-template-columns: max-content auto;
grid-gap: var(--spacing); grid-gap: var(--spacing);
margin-bottom: var(--spacing); margin-bottom: var(--spacing);
align-items: center; align-items: center;
} }
.message { .message {
@ -333,6 +337,48 @@ textarea {
justify-self: left; justify-self: left;
} }
#content.page-config .grid-2col {
grid-template-columns: max-content max-content auto;
}
/* error */
#content.page-error {
text-align: center;
}
#content.page-error .title {
font-size: 24px;
font-weight: bold;
}
/* auth */
#content.page-app_authorization {
text-align: center;
}
#content.page-app_authorization #code {
background: var(--background);
border: 1px solid var(--border);
font-size: 18px;
margin: 0 auto;
width: max-content;
padding: 5px;
}
#content.page-app_authorization #title {
font-size: 24px;
}
#content.page-app_authorization #buttons {
display: grid;
grid-template-columns: auto max-content max-content auto;
grid-gap: var(--spacing);
justify-items: center;
margin: var(--spacing) 0;
}
@keyframes show_toast { @keyframes show_toast {
0% { 0% {

View file

@ -1,85 +0,0 @@
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_user(row.id);
});
}
async function add_user() {
var elems = {
username: document.getElementById("new-username"),
password: document.getElementById("new-password"),
password2: document.getElementById("new-password2"),
handle: document.getElementById("new-handle")
}
var values = {
username: elems.username.value.trim(),
password: elems.password.value.trim(),
password2: elems.password2.value.trim(),
handle: elems.handle.value.trim()
}
if (values.username === "" | values.password === "" | values.password2 === "") {
toast("Username, password, and password2 are required");
return;
}
if (values.password !== values.password2) {
toast("Passwords do not match");
return;
}
try {
var user = await request("POST", "v1/user", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
domain: user.username,
handle: user.handle ? self.handle : "n/a",
date: get_date_string(user.created),
remove: `<a href="#" title="Delete User">&#10006;</a>`
});
add_row_listeners(row);
elems.username.value = null;
elems.password.value = null;
elems.password2.value = null;
elems.handle.value = null;
document.querySelector("details.section").open = false;
toast("Created user", "message");
}
async function del_user(username) {
try {
await request("DELETE", "v1/user", {"username": username});
} catch (error) {
toast(error);
return;
}
document.getElementById(username).remove();
toast("Deleted user", "message");
}
document.querySelector("#new-user").addEventListener("click", async (event) => {
await add_user();
});
for (var row of document.querySelector("#users").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}

View file

@ -1,64 +0,0 @@
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_whitelist(row.id);
});
}
async function add_whitelist() {
var domain_elem = document.getElementById("new-domain");
var domain = domain_elem.value.trim();
if (domain === "") {
toast("Domain is required");
return;
}
try {
var item = await request("POST", "v1/whitelist", {"domain": domain});
} catch (err) {
toast(err);
return;
}
var row = append_table_row(document.getElementById("whitelist"), item.domain, {
domain: item.domain,
date: get_date_string(item.created),
remove: `<a href="#" title="Remove whitelisted domain">&#10006;</a>`
});
add_row_listeners(row);
domain_elem.value = null;
document.querySelector("details.section").open = false;
toast("Added domain", "message");
}
async function del_whitelist(domain) {
try {
await request("DELETE", "v1/whitelist", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Removed domain", "message");
}
document.querySelector("#new-item").addEventListener("click", async (event) => {
await add_whitelist();
});
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}

View file

@ -1,20 +1,16 @@
from __future__ import annotations from __future__ import annotations
import json import json
import traceback
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from blib import HttpError, JsonBase
from blib import JsonBase from typing import TYPE_CHECKING, Any, TypeVar, overload
from bsql import Row
from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import urlparse
from . import __version__, logger as logging from . import __version__, logger as logging
from .cache import Cache from .cache import Cache
from .database.schema import Instance
from .errors import EmptyBodyError
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
@ -36,7 +32,7 @@ SUPPORTS_HS2019 = {
'sharkey' 'sharkey'
} }
T = TypeVar('T', bound = JsonBase) T = TypeVar('T', bound = JsonBase[Any])
HEADERS = { HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9', 'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}' 'User-Agent': f'ActivityRelay/{__version__}'
@ -107,21 +103,17 @@ class HttpClient:
url: str, url: str,
sign_headers: bool, sign_headers: bool,
force: bool, force: bool,
old_algo: bool) -> dict[str, Any] | None: old_algo: bool) -> str | None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
try: url = url.split("#", 1)[0]
url, _ = url.split('#', 1)
except ValueError:
pass
if not force: if not force:
try: try:
if not (item := self.cache.get('request', url)).older_than(48): if not (item := self.cache.get('request', url)).older_than(48):
return json.loads(item.value) # type: ignore[no-any-return] return item.value # type: ignore [no-any-return]
except KeyError: except KeyError:
logging.verbose('No cached data for url: %s', url) logging.verbose('No cached data for url: %s', url)
@ -132,7 +124,6 @@ class HttpClient:
algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019 algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
headers = self.signer.sign_headers('GET', url, algorithm = algo) headers = self.signer.sign_headers('GET', url, algorithm = algo)
try:
logging.debug('Fetching resource: %s', url) logging.debug('Fetching resource: %s', url)
async with self._session.get(url, headers = headers) as resp: async with self._session.get(url, headers = headers) as resp:
@ -142,57 +133,65 @@ class HttpClient:
data = await resp.text() data = await resp.text()
if resp.status != 200: if resp.status not in (200, 202):
logging.verbose('Received error when requesting %s: %i', url, resp.status) try:
logging.debug(data) error = json.loads(data)["error"]
return None
self.cache.set('request', url, data, 'str')
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
return json.loads(data) # type: ignore [no-any-return]
except JSONDecodeError:
logging.verbose('Failed to parse JSON')
logging.debug(data)
return None
except ClientSSLError as e:
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
logging.warning(str(e))
except (AsyncTimeoutError, ClientConnectionError) as e:
logging.verbose('Failed to connect to %s', urlparse(url).netloc)
logging.warning(str(e))
except Exception: except Exception:
traceback.print_exc() error = data
return None raise HttpError(resp.status, error)
self.cache.set('request', url, data, 'str')
return data
@overload
async def get(self,
url: str,
sign_headers: bool,
cls: None = None,
force: bool = False,
old_algo: bool = True) -> str | None: ...
@overload
async def get(self,
url: str,
sign_headers: bool,
cls: type[T] = JsonBase, # type: ignore[assignment]
force: bool = False,
old_algo: bool = True) -> T: ...
async def get(self, async def get(self,
url: str, url: str,
sign_headers: bool, sign_headers: bool,
cls: type[T], cls: type[T] | None = None,
force: bool = False, force: bool = False,
old_algo: bool = True) -> T | None: old_algo: bool = True) -> T | str | None:
if not issubclass(cls, JsonBase): if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"') raise TypeError('cls must be a sub-class of "blib.JsonBase"')
if (data := (await self._get(url, sign_headers, force, old_algo))) is None: data = await self._get(url, sign_headers, force, old_algo)
return None
if cls is not None:
if data is None:
# this shouldn't actually get raised, but keeping just in case
raise EmptyBodyError(f"GET {url}")
return cls.parse(data) return cls.parse(data)
return data
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
# akkoma and pleroma do not support HS2019 and other software still needs to be tested # akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance and instance['software'] in SUPPORTS_HS2019: if instance is not None and instance.software in SUPPORTS_HS2019:
algorithm = AlgorithmType.HS2019 algorithm = AlgorithmType.HS2019
else: else:
@ -218,45 +217,22 @@ class HttpClient:
algorithm = algorithm algorithm = algorithm
) )
try:
logging.verbose('Sending "%s" to %s', mtype, url) logging.verbose('Sending "%s" to %s', mtype, url)
async with self._session.post(url, headers = headers, data = body) as resp: async with self._session.post(url, headers = headers, data = body) as resp:
# Not expecting a response, so just return if resp.status not in (200, 202):
if resp.status in {200, 202}: raise HttpError(
logging.verbose('Successfully sent "%s" to %s', mtype, url) resp.status,
return await resp.text(),
headers = {k: v for k, v in resp.headers.items()}
logging.verbose('Received error when pushing to %s: %i', url, resp.status)
logging.debug(await resp.read())
logging.debug("message: %s", body.decode("utf-8"))
logging.debug("headers: %s", json.dumps(headers, indent = 4))
return
except ClientSSLError as e:
logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
logging.warning(str(e))
except (AsyncTimeoutError, ClientConnectionError) as e:
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
logging.warning(str(e))
# prevent workers from being brought down
except Exception:
traceback.print_exc()
async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None:
nodeinfo_url = None
wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo',
False,
WellKnownNodeinfo
) )
if wk_nodeinfo is None:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:
return None nodeinfo_url = None
wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force
)
for version in ('20', '21'): for version in ('20', '21'):
try: try:
@ -266,10 +242,9 @@ class HttpClient:
pass pass
if nodeinfo_url is None: if nodeinfo_url is None:
logging.verbose('Failed to fetch nodeinfo url for %s', domain) raise ValueError(f'Failed to fetch nodeinfo url for {domain}')
return None
return await self.get(nodeinfo_url, False, Nodeinfo) return await self.get(nodeinfo_url, False, Nodeinfo, force)
async def get(*args: Any, **kwargs: Any) -> Any: async def get(*args: Any, **kwargs: Any) -> Any:

View file

@ -1,16 +1,15 @@
from __future__ import annotations
import logging import logging
import os import os
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from typing import Any, Protocol from typing import TYPE_CHECKING, Any, Protocol
try: if TYPE_CHECKING:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
class LoggingMethod(Protocol): class LoggingMethod(Protocol):
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ... def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...

View file

@ -6,7 +6,6 @@ import click
import json import json
import os import os
from bsql import Row
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any from typing import Any
@ -17,7 +16,8 @@ from . import http_client as http
from . import logger as logging from . import logger as logging
from .application import Application from .application import Application
from .compat import RelayConfig, RelayDatabase from .compat import RelayConfig, RelayDatabase
from .database import RELAY_SOFTWARE, get_database from .config import Config
from .database import RELAY_SOFTWARE, get_database, schema
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
@ -213,6 +213,24 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
os._exit(0) os._exit(0)
@cli.command('db-maintenance')
@click.option('--fix-timestamps', '-t', is_flag = True,
help = 'Make sure timestamps in the database are float values')
@click.pass_context
def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
'Perform maintenance tasks on the database'
if fix_timestamps:
with ctx.obj.database.session(True) as conn:
conn.fix_timestamps()
if ctx.obj.config.db_type == "postgres":
return
with ctx.obj.database.session(False) as conn:
with conn.execute("VACUUM"):
pass
@cli.command('convert') @cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from') @click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@ -240,18 +258,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.set('domain', config['host']) ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save() ctx.obj.config.save()
# fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7
with get_database(ctx.obj.config) as db: with get_database(ctx.obj.config) as db:
with db.session(True) as conn: with db.session(True) as conn:
conn.put_config('private-key', database['private-key']) conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note']) conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled']) conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar( # type: ignore with click.progressbar(
database['relay-list'].values(), database['relay-list'].values(),
label = 'Inboxes'.ljust(15), label = 'Inboxes'.ljust(15),
width = 0 width = 0
) as inboxes: ) as inboxes:
for inbox in inboxes: for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}: if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay' actor = f'https://{inbox["domain"]}/relay'
@ -270,7 +288,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software'] software = inbox['software']
) )
with click.progressbar( # type: ignore with click.progressbar(
config['blocked_software'], config['blocked_software'],
label = 'Banned software'.ljust(15), label = 'Banned software'.ljust(15),
width = 0 width = 0
@ -282,7 +300,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None reason = 'relay' if software in RELAY_SOFTWARE else None
) )
with click.progressbar( # type: ignore with click.progressbar(
config['blocked_instances'], config['blocked_instances'],
label = 'Banned domains'.ljust(15), label = 'Banned domains'.ljust(15),
width = 0 width = 0
@ -291,7 +309,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software: for domain in banned_software:
conn.put_domain_ban(domain) conn.put_domain_ban(domain)
with click.progressbar( # type: ignore with click.progressbar(
config['whitelist'], config['whitelist'],
label = 'Whitelist'.ljust(15), label = 'Whitelist'.ljust(15),
width = 0 width = 0
@ -315,6 +333,33 @@ def cli_editconfig(ctx: click.Context, editor: str) -> None:
) )
@cli.command('switch-backend')
@click.pass_context
def cli_switchbackend(ctx: click.Context) -> None:
"""
Copy the database from one backend to the other
Be sure to set the database type to the backend you want to convert from. For instance, set
the database type to `sqlite`, fill out the connection details for postgresql, and the
data from the sqlite database will be copied to the postgresql database. This only works if
the database in postgresql already exists.
"""
config = Config(ctx.obj.config.path, load = True)
config.db_type = "sqlite" if config.db_type == "postgres" else "postgres"
database = get_database(config, migrate = False)
with database.session(True) as new, ctx.obj.database.session(False) as old:
new.create_tables()
for table in schema.TABLES.keys():
for row in old.execute(f"SELECT * FROM {table}"):
new.insert(table, row).close()
config.save()
click.echo(f"Converted database to {repr(config.db_type)}")
@cli.group('config') @cli.group('config')
def cli_config() -> None: def cli_config() -> None:
'Manage the relay settings stored in the database' 'Manage the relay settings stored in the database'
@ -348,10 +393,15 @@ def cli_config_list(ctx: click.Context) -> None:
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value' 'Set a config value'
try:
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
new_value = conn.put_config(key, value) new_value = conn.put_config(key, value)
print(f'{key}: {repr(new_value)}') except Exception:
click.echo(f'Invalid config name: {key}')
return
click.echo(f'{key}: {repr(new_value)}')
@cli.group('user') @cli.group('user')
@ -367,8 +417,8 @@ def cli_user_list(ctx: click.Context) -> None:
click.echo('Users:') click.echo('Users:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for user in conn.execute('SELECT * FROM users'): for row in conn.get_users():
click.echo(f'- {user["username"]}') click.echo(f'- {row.username}')
@cli_user.command('create') @cli_user.command('create')
@ -379,7 +429,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
'Create a new local user' 'Create a new local user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_user(username): if conn.get_user(username) is not None:
click.echo(f'User already exists: {username}') click.echo(f'User already exists: {username}')
return return
@ -406,7 +456,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
'Delete a local user' 'Delete a local user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if not conn.get_user(username): if conn.get_user(username) is None:
click.echo(f'User does not exist: {username}') click.echo(f'User does not exist: {username}')
return return
@ -424,8 +474,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
click.echo(f'Tokens for "{username}":') click.echo(f'Tokens for "{username}":')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): for row in conn.get_tokens(username):
click.echo(f'- {token["code"]}') click.echo(f'- {row.code}')
@cli_user.command('create-token') @cli_user.command('create-token')
@ -435,13 +485,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
'Create a new API token for a user' 'Create a new API token for a user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if not (user := conn.get_user(username)): if (user := conn.get_user(username)) is None:
click.echo(f'User does not exist: {username}') click.echo(f'User does not exist: {username}')
return return
token = conn.put_token(user['username']) token = conn.put_token(user.username)
click.echo(f'New token for "{username}": {token["code"]}') click.echo(f'New token for "{username}": {token.code}')
@cli_user.command('delete-token') @cli_user.command('delete-token')
@ -451,7 +501,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
'Delete an API token' 'Delete an API token'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if not conn.get_token(code): if conn.get_token(code) is None:
click.echo('Token does not exist') click.echo('Token does not exist')
return return
@ -473,8 +523,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:') click.echo('Connected to the following instances or relays:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for inbox in conn.get_inboxes(): for row in conn.get_inboxes():
click.echo(f'- {inbox["inbox"]}') click.echo(f'- {row.inbox}')
@cli_inbox.command('follow') @cli_inbox.command('follow')
@ -483,19 +533,21 @@ def cli_inbox_list(ctx: click.Context) -> None:
def cli_inbox_follow(ctx: click.Context, actor: str) -> None: def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)' 'Follow an actor (Relay must be running)'
instance: schema.Instance | None = None
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
if (inbox_data := conn.get_inbox(actor)): if (instance := conn.get_inbox(actor)) is not None:
inbox = inbox_data['inbox'] inbox = instance.inbox
else: else:
if not actor.startswith('http'): if not actor.startswith('http'):
actor = f'https://{actor}/actor' actor = f'https://{actor}/actor'
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))): if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None:
click.echo(f'Failed to fetch actor: {actor}') click.echo(f'Failed to fetch actor: {actor}')
return return
@ -506,7 +558,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
actor = actor actor = actor
) )
asyncio.run(http.post(inbox, message, inbox_data)) asyncio.run(http.post(inbox, message, instance))
click.echo(f'Sent follow message to actor: {actor}') click.echo(f'Sent follow message to actor: {actor}')
@ -516,19 +568,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)' 'Unfollow an actor (Relay must be running)'
inbox_data: Row | None = None instance: schema.Instance | None = None
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
if (inbox_data := conn.get_inbox(actor)): if (instance := conn.get_inbox(actor)):
inbox = inbox_data['inbox'] inbox = instance.inbox
message = Message.new_unfollow( message = Message.new_unfollow(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = actor, actor = actor,
follow = inbox_data['followid'] follow = instance.followid
) )
else: else:
@ -552,7 +604,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
} }
) )
asyncio.run(http.post(inbox, message, inbox_data)) asyncio.run(http.post(inbox, message, instance))
click.echo(f'Sent unfollow message to: {actor}') click.echo(f'Sent unfollow message to: {actor}')
@ -632,9 +684,9 @@ def cli_request_list(ctx: click.Context) -> None:
click.echo('Follow requests:') click.echo('Follow requests:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for instance in conn.get_requests(): for row in conn.get_requests():
date = instance['created'].strftime('%Y-%m-%d') date = row.created.strftime('%Y-%m-%d')
click.echo(f'- [{date}] {instance["domain"]}') click.echo(f'- [{date}] {row.domain}')
@cli_request.command('accept') @cli_request.command('accept')
@ -653,20 +705,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
message = Message.new_response( message = Message.new_response(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance['actor'], actor = instance.actor,
followid = instance['followid'], followid = instance.followid,
accept = True accept = True
) )
asyncio.run(http.post(instance['inbox'], message, instance)) asyncio.run(http.post(instance.inbox, message, instance))
if instance['software'] != 'mastodon': if instance.software != 'mastodon':
message = Message.new_follow( message = Message.new_follow(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance['actor'] actor = instance.actor
) )
asyncio.run(http.post(instance['inbox'], message, instance)) asyncio.run(http.post(instance.inbox, message, instance))
@cli_request.command('deny') @cli_request.command('deny')
@ -685,12 +737,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
response = Message.new_response( response = Message.new_response(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance['actor'], actor = instance.actor,
followid = instance['followid'], followid = instance.followid,
accept = False accept = False
) )
asyncio.run(http.post(instance['inbox'], response, instance)) asyncio.run(http.post(instance.inbox, response, instance))
@cli.group('instance') @cli.group('instance')
@ -706,12 +758,12 @@ def cli_instance_list(ctx: click.Context) -> None:
click.echo('Banned domains:') click.echo('Banned domains:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for instance in conn.execute('SELECT * FROM domain_bans'): for row in conn.get_domain_bans():
if instance['reason']: if row.reason is not None:
click.echo(f'- {instance["domain"]} ({instance["reason"]})') click.echo(f'- {row.domain} ({row.reason})')
else: else:
click.echo(f'- {instance["domain"]}') click.echo(f'- {row.domain}')
@cli_instance.command('ban') @cli_instance.command('ban')
@ -723,7 +775,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
'Ban an instance and remove the associated inbox if it exists' 'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(domain): if conn.get_domain_ban(domain) is not None:
click.echo(f'Domain already banned: {domain}') click.echo(f'Domain already banned: {domain}')
return return
@ -739,7 +791,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance' 'Unban an instance'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if not conn.del_domain_ban(domain): if conn.del_domain_ban(domain) is None:
click.echo(f'Instance wasn\'t banned: {domain}') click.echo(f'Instance wasn\'t banned: {domain}')
return return
@ -764,11 +816,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
click.echo(f'Updated domain ban: {domain}') click.echo(f'Updated domain ban: {domain}')
if row['reason']: if row.reason:
click.echo(f'- {row["domain"]} ({row["reason"]})') click.echo(f'- {row.domain} ({row.reason})')
else: else:
click.echo(f'- {row["domain"]}') click.echo(f'- {row.domain}')
@cli.group('software') @cli.group('software')
@ -784,12 +836,12 @@ def cli_software_list(ctx: click.Context) -> None:
click.echo('Banned software:') click.echo('Banned software:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for software in conn.execute('SELECT * FROM software_bans'): for row in conn.get_software_bans():
if software['reason']: if row.reason:
click.echo(f'- {software["name"]} ({software["reason"]})') click.echo(f'- {row.name} ({row.reason})')
else: else:
click.echo(f'- {software["name"]}') click.echo(f'- {row.name}')
@cli_software.command('ban') @cli_software.command('ban')
@ -811,12 +863,12 @@ def cli_software_ban(ctx: click.Context,
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if name == 'RELAYS': if name == 'RELAYS':
for software in RELAY_SOFTWARE: for item in RELAY_SOFTWARE:
if conn.get_software_ban(software): if conn.get_software_ban(item):
click.echo(f'Relay already banned: {software}') click.echo(f'Relay already banned: {item}')
continue continue
conn.put_software_ban(software, reason or 'relay', note) conn.put_software_ban(item, reason or 'relay', note)
click.echo('Banned all relay software') click.echo('Banned all relay software')
return return
@ -893,11 +945,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
click.echo(f'Updated software ban: {name}') click.echo(f'Updated software ban: {name}')
if row['reason']: if row.reason:
click.echo(f'- {row["name"]} ({row["reason"]})') click.echo(f'- {row.name} ({row.reason})')
else: else:
click.echo(f'- {row["name"]}') click.echo(f'- {row.name}')
@cli.group('whitelist') @cli.group('whitelist')
@ -913,8 +965,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
click.echo('Current whitelisted domains:') click.echo('Current whitelisted domains:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for domain in conn.execute('SELECT * FROM whitelist'): for row in conn.get_domain_whitelist():
click.echo(f'- {domain["domain"]}') click.echo(f'- {row.domain}')
@cli_whitelist.command('add') @cli_whitelist.command('add')
@ -953,23 +1005,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
@cli_whitelist.command('import') @cli_whitelist.command('import')
@click.pass_context @click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None: def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist' 'Add all current instances to the whitelist'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes').all(): for row in conn.get_inboxes():
if conn.get_domain_whitelist(inbox['domain']): if conn.get_domain_whitelist(row.domain) is not None:
click.echo(f'Domain already in whitelist: {inbox["domain"]}') click.echo(f'Domain already in whitelist: {row.domain}')
continue continue
conn.put_domain_whitelist(inbox['domain']) conn.put_domain_whitelist(row.domain)
click.echo('Imported whitelist from inboxes') click.echo('Imported whitelist from inboxes')
def main() -> None: def main() -> None:
cli(prog_name='relay') cli(prog_name='activityrelay')
if __name__ == '__main__':
click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')

View file

@ -4,28 +4,15 @@ import aputils
import json import json
import os import os
import platform import platform
import socket
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4 from uuid import uuid4
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles # type: ignore
try:
from typing import Self
except ImportError:
from typing_extensions import Self
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self
from .application import Application from .application import Application
@ -50,11 +37,6 @@ MIMETYPES = {
'webmanifest': 'application/manifest+json' 'webmanifest': 'application/manifest+json'
} }
NODEINFO_NS = {
'20': 'http://nodeinfo.diaspora.software/ns/schema/2.0',
'21': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
}
ACTOR_FORMATS = { ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor', 'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay', 'akkoma': 'https://{domain}/relay',
@ -72,42 +54,26 @@ SOFTWARE = (
'gotosocial' 'gotosocial'
) )
JSON_PATHS: tuple[str, ...] = (
'/api/v1',
'/actor',
'/inbox',
'/outbox',
'/following',
'/followers',
'/.well-known',
'/nodeinfo',
'/oauth/token',
'/oauth/revoke'
)
def boolean(value: Any) -> bool: TOKEN_PATHS: tuple[str, ...] = (
if isinstance(value, str): '/logout',
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}: '/admin',
return True '/api',
'/oauth/authorize',
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}: '/oauth/revoke'
return False )
raise TypeError(f'Cannot parse string "{value}" as a boolean')
if isinstance(value, int):
if value == 1:
return True
if value == 0:
return False
raise ValueError('Integer value must be 1 or 0')
if value is None:
return False
return bool(value)
def check_open_port(host: str, port: int) -> bool:
if host == '0.0.0.0':
host = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
return s.connect_ex((host, port)) != 0
except socket.error:
return False
def get_app() -> Application: def get_app() -> Application:
@ -119,10 +85,6 @@ def get_app() -> Application:
return Application.DEFAULT return Application.DEFAULT
def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path)
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str: def default(self, o: Any) -> str:
if isinstance(o, datetime): if isinstance(o, datetime):
@ -240,21 +202,9 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_error(cls: type[Self], def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
status: int,
body: str | bytes | dict[str, Any],
ctype: str = 'text') -> Self:
if ctype == 'json':
body = {'error': body}
return cls.new(body=body, status=status, ctype=ctype)
@classmethod
def new_redir(cls: type[Self], path: str) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>' body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, 302, {'Location': path}) return cls.new(body, status, {'Location': path}, ctype = 'html')
@property @property

View file

@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
for instance in conn.distill_inboxes(view.message): for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], message, instance) view.app.push_message(instance.inbox, message, instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str') view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
@ -52,13 +52,13 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
for instance in conn.distill_inboxes(view.message): for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], view.message, instance) view.app.push_message(instance.inbox, view.message, instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str') view.cache.set('handle-relay', view.message.id, message.id, 'str')
async def handle_follow(view: ActorView, conn: Connection) -> None: async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain, force = True)
software = nodeinfo.sw_name if nodeinfo else None software = nodeinfo.sw_name if nodeinfo else None
config = conn.get_config_all() config = conn.get_config_all()
@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
return return
# prevent past unfollows from removing an instance # prevent past unfollows from removing an instance
if view.instance['followid'] and view.instance['followid'] != view.message.object_id: if view.instance.followid and view.instance.followid != view.message.object_id:
return return
with conn.transaction(): with conn.transaction():
@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None:
with view.database.session() as conn: with view.database.session() as conn:
if view.instance: if view.instance:
if not view.instance['software']: if not view.instance.software:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)):
with conn.transaction(): with conn.transaction():
view.instance = conn.put_inbox( view.instance = conn.put_inbox(
domain = view.instance['domain'], domain = view.instance.domain,
software = nodeinfo.sw_name software = nodeinfo.sw_name
) )
if not view.instance['actor']: if not view.instance.actor:
with conn.transaction(): with conn.transaction():
view.instance = conn.put_inbox( view.instance = conn.put_inbox(
domain = view.instance['domain'], domain = view.instance.domain,
actor = view.actor.id actor = view.actor.id
) )

View file

@ -2,6 +2,8 @@ from __future__ import annotations
import textwrap import textwrap
from aiohttp.web import Request
from blib import File
from collections.abc import Callable from collections.abc import Callable
from hamlish_jinja import HamlishExtension from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
@ -12,14 +14,15 @@ from markdown import Markdown
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from . import __version__ from . import __version__
from .misc import get_resource
from .views.base import View
if TYPE_CHECKING: if TYPE_CHECKING:
from .application import Application from .application import Application
class Template(Environment): class Template(Environment):
_render_markdown: Callable[[str], str]
def __init__(self, app: Application): def __init__(self, app: Application):
Environment.__init__(self, Environment.__init__(self,
autoescape = True, autoescape = True,
@ -30,7 +33,7 @@ class Template(Environment):
MarkdownExtension MarkdownExtension
], ],
loader = FileSystemLoader([ loader = FileSystemLoader([
get_resource('frontend'), File.from_resource('relay', 'frontend'),
app.config.path.parent.joinpath('template') app.config.path.parent.joinpath('template')
]) ])
) )
@ -40,12 +43,12 @@ class Template(Environment):
self.hamlish_mode = 'indented' self.hamlish_mode = 'indented'
def render(self, path: str, view: View | None = None, **context: Any) -> str: def render(self, path: str, request: Request, **context: Any) -> str:
with self.app.database.session(False) as conn: with self.app.database.session(False) as conn:
config = conn.get_config_all() config = conn.get_config_all()
new_context = { new_context = {
'view': view, 'request': request,
'domain': self.app.config.domain, 'domain': self.app.config.domain,
'version': __version__, 'version': __version__,
'config': config, 'config': config,
@ -56,7 +59,7 @@ class Template(Environment):
def render_markdown(self, text: str) -> str: def render_markdown(self, text: str) -> str:
return self._render_markdown(text) # type: ignore return self._render_markdown(text)
class MarkdownExtension(Extension): class MarkdownExtension(Extension):

View file

@ -1,26 +1,23 @@
from __future__ import annotations
import aputils import aputils
import traceback import traceback
import typing
from aiohttp.web import Request
from blib import HttpError
from .base import View, register_route from .base import View, register_route
from .. import logger as logging from .. import logger as logging
from ..database import schema
from ..misc import Message, Response from ..misc import Message, Response
from ..processors import run_processor from ..processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from bsql import Row
@register_route('/actor', '/inbox') @register_route('/actor', '/inbox')
class ActorView(View): class ActorView(View):
signature: aputils.Signature signature: aputils.Signature
message: Message message: Message
actor: Message actor: Message
instancce: Row instance: schema.Instance
signer: aputils.Signer signer: aputils.Signer
@ -43,16 +40,15 @@ class ActorView(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
if response := await self.get_post_data(): await self.get_post_data()
return response
with self.database.session() as conn: with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
# reject if actor is banned # reject if actor is banned
if conn.get_domain_ban(self.actor.domain): if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id) logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json') raise HttpError(403, 'access denied')
# reject if activity type isn't 'Follow' and the actor isn't following # reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance: if self.message.type != 'Follow' and not self.instance:
@ -61,7 +57,7 @@ class ActorView(View):
self.actor.id self.actor.id
) )
return Response.new_error(401, 'access denied', 'json') raise HttpError(401, 'access denied')
logging.debug('>> payload %s', self.message.to_json(4)) logging.debug('>> payload %s', self.message.to_json(4))
@ -69,60 +65,62 @@ class ActorView(View):
return Response.new(status = 202) return Response.new(status = 202)
async def get_post_data(self) -> Response | None: async def get_post_data(self) -> None:
try: try:
self.signature = aputils.Signature.parse(self.request.headers['signature']) self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError: except KeyError:
logging.verbose('Missing signature header') logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json') raise HttpError(400, 'missing signature header')
try: try:
message: Message | None = await self.request.json(loads = Message.parse) message: Message | None = await self.request.json(loads = Message.parse)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
logging.verbose('Failed to parse inbox message') logging.verbose('Failed to parse message from actor: %s', self.signature.keyid)
return Response.new_error(400, 'failed to parse message', 'json') raise HttpError(400, 'failed to parse message')
if message is None: if message is None:
logging.verbose('empty message') logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json') raise HttpError(400, 'missing message')
self.message = message self.message = message
if 'actor' not in self.message: if 'actor' not in self.message:
logging.verbose('actor not in message') logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json') raise HttpError(400, 'no actor in message')
actor: Message | None = await self.client.get(self.signature.keyid, True, Message) try:
self.actor = await self.client.get(self.signature.keyid, True, Message)
if actor is None: except HttpError as e:
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete': if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202) raise HttpError(202, '')
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') logging.verbose('Failed to fetch actor: %s', self.signature.keyid)
return Response.new_error(400, 'failed to fetch actor', 'json') logging.debug('HTTP Status %i: %s', e.status, e.message)
raise HttpError(400, 'failed to fetch actor')
self.actor = actor except Exception:
traceback.print_exc()
raise HttpError(500, 'unexpected error when fetching actor')
try: try:
self.signer = self.actor.signer self.signer = self.actor.signer
except KeyError: except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid) logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json') raise HttpError(400, 'actor missing public key')
try: try:
await self.signer.validate_request_async(self.request) await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e: except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json') raise HttpError(401, str(e))
return None
@register_route('/outbox') @register_route('/outbox')
@ -165,10 +163,10 @@ class WebfingerView(View):
subject = request.query['resource'] subject = request.query['resource']
except KeyError: except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json') raise HttpError(400, 'missing "resource" query key')
if subject != f'acct:relay@{self.config.domain}': if subject != f'acct:relay@{self.config.domain}':
return Response.new_error(404, 'user not found', 'json') raise HttpError(404, 'user not found')
data = aputils.Webfinger.new( data = aputils.Webfinger.new(
handle = 'relay', handle = 'relay',

View file

@ -1,17 +1,20 @@
import traceback
from aiohttp.web import Request, middleware from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from blib import HttpError, convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
from .. import __version__ from .. import __version__
from ..database import ConfigData from ..database import ConfigData, schema
from ..misc import Message, Response, boolean, get_app from ..misc import Message, Response
ALLOWED_HEADERS = { DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
ALLOWED_HEADERS: set[str] = {
'accept', 'accept',
'authorization', 'authorization',
'content-type' 'content-type'
@ -19,6 +22,8 @@ ALLOWED_HEADERS = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'), ('GET', '/api/v1/relay'),
('POST', '/api/v1/app'),
('POST', '/api/v1/login'),
('POST', '/api/v1/token') ('POST', '/api/v1/token')
) )
@ -34,64 +39,184 @@ def check_api_path(method: str, path: str) -> bool:
async def handle_api_path( async def handle_api_path(
request: Request, request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response: handler: Callable[[Request], Awaitable[Response]]) -> Response:
try:
if (token := request.cookies.get('user-token')):
request['token'] = token
else: if not request.path.startswith('/api') or request.path == '/api/doc':
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() return await handler(request)
with get_app().database.session() as conn:
request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError):
request['token'] = None
request['user'] = None
if request.method != "OPTIONS" and check_api_path(request.method, request.path): if request.method != "OPTIONS" and check_api_path(request.method, request.path):
if not request['token']: if request['token'] is None:
return Response.new_error(401, 'Missing token', 'json') raise HttpError(401, 'Missing token')
if not request['user']: if request['user'] is None:
return Response.new_error(401, 'Invalid token', 'json') raise HttpError(401, 'Invalid token')
response = await handler(request) response = await handler(request)
if request.path.startswith('/api'):
response.headers['Access-Control-Allow-Origin'] = '*' response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
return response return response
@register_route('/api/v1/token') @register_route('/oauth/authorize')
class Login(View): @register_route('/api/oauth/authorize')
class OauthAuthorize(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
return Response.new({'message': 'Token valid'}, ctype = 'json') data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], [])
if data['response_type'] != 'code':
raise HttpError(400, 'Response type is not "code"')
with self.database.session(True) as conn:
with conn.select('apps', client_id = data['client_id']) as cur:
if (app := cur.one(schema.App)) is None:
raise HttpError(404, 'Could not find app')
if app.token is not None:
raise HttpError(400, 'Application has already been authorized')
if app.auth_code is not None:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
if data['redirect_uri'] != app.redirect_uri:
raise HttpError(400, 'redirect_uri does not match application')
context = {'application': app}
html = self.template.render('page/authorize_new.haml', self.request, **context)
return Response.new(html, ctype = 'html')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], []) data = await self.get_api_data(
['client_id', 'client_secret', 'redirect_uri', 'response'], []
)
if isinstance(data, Response): with self.database.session(True) as conn:
return data if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Could not find app')
if convert_to_boolean(data['response']):
if app.token is not None:
raise HttpError(400, 'Application has already been authorized')
if app.auth_code is None:
app = conn.update_app(app, request['user'], True)
if app.redirect_uri == DEFAULT_REDIRECT:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}')
if not conn.del_app(app.client_id, app.client_secret):
raise HttpError(404, 'App not found')
return Response.new_redir('/')
@register_route('/oauth/token')
@register_route('/api/oauth/token')
class OauthToken(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(
['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], []
)
if data['grant_type'] != 'authorization_code':
raise HttpError(400, 'Invalid grant type')
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application not found')
if app.auth_code != data['code']:
raise HttpError(400, 'Invalid authentication code')
if app.redirect_uri != data['redirect_uri']:
raise HttpError(400, 'Invalid redirect uri')
app = conn.update_app(app, request['user'], False)
return Response.new(app.get_api_data(True), ctype = 'json')
@register_route('/oauth/revoke')
@register_route('/api/oauth/revoke')
class OauthRevoke(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret', 'token'], [])
with self.database.session(True) as conn:
if (app := conn.get_app(**data)) is None:
raise HttpError(404, 'Could not find token')
if app.user != request['token'].username:
raise HttpError(403, 'Invalid token')
if not conn.del_app(**data):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/app')
class App(View):
async def get(self, request: Request) -> Response:
return Response.new(request['token'].get_api_data(), ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name', 'redirect_uri'], ['website'])
with self.database.session(True) as conn:
app = conn.put_app(
name = data['name'],
redirect_uri = data['redirect_uri'],
website = data.get('website')
)
return Response.new(app.get_api_data(), ctype = 'json')
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(True) as conn:
if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/login')
class Login(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], [])
with self.database.session(True) as conn: with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])): if not (user := conn.get_user(data['username'])):
return Response.new_error(401, 'User not found', 'json') raise HttpError(401, 'User not found')
try: try:
conn.hasher.verify(user['hash'], data['password']) conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError: except VerifyMismatchError:
return Response.new_error(401, 'Invalid password', 'json') raise HttpError(401, 'Invalid password')
token = conn.put_token(data['username']) app = conn.put_app_login(user)
resp = Response.new({'token': token['code']}, ctype = 'json') resp = Response.new(app.get_api_data(True), ctype = 'json')
resp.set_cookie( resp.set_cookie(
'user-token', 'user-token',
token['code'], app.token, # type: ignore[arg-type]
max_age = 60 * 60 * 24 * 365, max_age = 60 * 60 * 24 * 365,
domain = self.config.domain, domain = self.config.domain,
path = '/', path = '/',
@ -103,19 +228,12 @@ class Login(View):
return resp return resp
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
conn.del_token(request['token'])
return Response.new({'message': 'Token revoked'}, ctype = 'json')
@register_route('/api/v1/relay') @register_route('/api/v1/relay')
class RelayInfo(View): class RelayInfo(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
config = conn.get_config_all() config = conn.get_config_all()
inboxes = [row['domain'] for row in conn.get_inboxes()] inboxes = [row.domain for row in conn.get_inboxes()]
data = { data = {
'domain': self.config.domain, 'domain': self.config.domain,
@ -152,17 +270,16 @@ class Config(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['key', 'value'], []) data = await self.get_api_data(['key', 'value'], [])
if isinstance(data, Response):
return data
data['key'] = data['key'].replace('-', '_') data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json') raise HttpError(400, 'Invalid key')
with self.database.session() as conn: with self.database.session() as conn:
conn.put_config(data['key'], data['value']) value = conn.put_config(data['key'], data['value'])
if data['key'] == 'log-level':
self.app.workers.set_log_level(value)
return Response.new({'message': 'Updated config'}, ctype = 'json') return Response.new({'message': 'Updated config'}, ctype = 'json')
@ -170,14 +287,14 @@ class Config(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['key'], []) data = await self.get_api_data(['key'], [])
if isinstance(data, Response):
return data
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json') raise HttpError(400, 'Invalid key')
with self.database.session() as conn: with self.database.session() as conn:
conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) value = conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
if data['key'] == 'log-level':
self.app.workers.set_log_level(value)
return Response.new({'message': 'Updated config'}, ctype = 'json') return Response.new({'message': 'Updated config'}, ctype = 'json')
@ -186,40 +303,46 @@ class Config(View):
class Inbox(View): class Inbox(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = conn.get_inboxes() data = tuple(conn.get_inboxes())
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = urlparse(data["actor"]).netloc data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_inbox(data['domain']): if conn.get_inbox(data['domain']) is not None:
return Response.new_error(404, 'Instance already in database', 'json') raise HttpError(404, 'Instance already in database')
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not data.get('inbox'): if not data.get('inbox'):
actor_data: Message | None = await self.client.get(data['actor'], True, Message) try:
actor_data = await self.client.get(data['actor'], True, Message)
if actor_data is None: except Exception:
return Response.new_error(500, 'Failed to fetch actor', 'json') traceback.print_exc()
raise HttpError(500, 'Failed to fetch actor') from None
data['inbox'] = actor_data.shared_inbox data['inbox'] = actor_data.shared_inbox
if not data.get('software'): if not data.get('software'):
try:
nodeinfo = await self.client.fetch_nodeinfo(data['domain']) nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
if nodeinfo is not None:
data['software'] = nodeinfo.sw_name data['software'] = nodeinfo.sw_name
row = conn.put_inbox(**data) # type: ignore[arg-type] except Exception:
pass
row = conn.put_inbox(
domain = data['domain'],
actor = data['actor'],
inbox = data.get('inbox'),
software = data.get('software'),
followid = data.get('followid')
)
return Response.new(row, ctype = 'json') return Response.new(row, ctype = 'json')
@ -227,16 +350,17 @@ class Inbox(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not (instance := conn.get_inbox(data['domain'])): if (instance := conn.get_inbox(data['domain'])) is None:
return Response.new_error(404, 'Instance with domain not found', 'json') raise HttpError(404, 'Instance with domain not found')
instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type] instance = conn.put_inbox(
instance.domain,
actor = data.get('actor'),
software = data.get('software'),
followid = data.get('followid')
)
return Response.new(instance, ctype = 'json') return Response.new(instance, ctype = 'json')
@ -244,14 +368,10 @@ class Inbox(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']): if not conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance with domain not found', 'json') raise HttpError(404, 'Instance with domain not found')
conn.del_inbox(data['domain']) conn.del_inbox(data['domain'])
@ -262,43 +382,41 @@ class Inbox(View):
class RequestView(View): class RequestView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
instances = conn.get_requests() instances = tuple(conn.get_requests())
return Response.new(instances, ctype = 'json') return Response.new(instances, ctype = 'json')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], []) data = await self.get_api_data(['domain', 'accept'], [])
if isinstance(data, Response):
return data
data['accept'] = boolean(data['accept'])
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
try: try:
with self.database.session(True) as conn: with self.database.session(True) as conn:
instance = conn.put_request_response(data['domain'], data['accept']) instance = conn.put_request_response(
data['domain'],
convert_to_boolean(data['accept'])
)
except KeyError: except KeyError:
return Response.new_error(404, 'Request not found', 'json') raise HttpError(404, 'Request not found') from None
message = Message.new_response( message = Message.new_response(
host = self.config.domain, host = self.config.domain,
actor = instance['actor'], actor = instance.actor,
followid = instance['followid'], followid = instance.followid,
accept = data['accept'] accept = convert_to_boolean(data['accept'])
) )
self.app.push_message(instance['inbox'], message, instance) self.app.push_message(instance.inbox, message, instance)
if data['accept'] and instance['software'] != 'mastodon': if data['accept'] and instance.software != 'mastodon':
message = Message.new_follow( message = Message.new_follow(
host = self.config.domain, host = self.config.domain,
actor = instance['actor'] actor = instance.actor
) )
self.app.push_message(instance['inbox'], message, instance) self.app.push_message(instance.inbox, message, instance)
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'} resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
return Response.new(resp_message, ctype = 'json') return Response.new(resp_message, ctype = 'json')
@ -308,24 +426,24 @@ class RequestView(View):
class DomainBan(View): class DomainBan(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
bans = tuple(conn.execute('SELECT * FROM domain_bans').all()) bans = tuple(conn.get_domain_bans())
return Response.new(bans, ctype = 'json') return Response.new(bans, ctype = 'json')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_ban(data['domain']): if conn.get_domain_ban(data['domain']) is not None:
return Response.new_error(400, 'Domain already banned', 'json') raise HttpError(400, 'Domain already banned')
ban = conn.put_domain_ban(**data) ban = conn.put_domain_ban(
domain = data['domain'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -334,18 +452,19 @@ class DomainBan(View):
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response): if not any([data.get('note'), data.get('reason')]):
return data raise HttpError(400, 'Must include note and/or reason parameters')
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_domain_ban(data['domain']): if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json') raise HttpError(404, 'Domain not banned')
if not any([data.get('note'), data.get('reason')]): ban = conn.update_domain_ban(
return Response.new_error(400, 'Must include note and/or reason parameters', 'json') domain = data['domain'],
reason = data.get('reason'),
ban = conn.update_domain_ban(**data) note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -353,14 +472,10 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_domain_ban(data['domain']): if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json') raise HttpError(404, 'Domain not banned')
conn.del_domain_ban(data['domain']) conn.del_domain_ban(data['domain'])
@ -371,7 +486,7 @@ class DomainBan(View):
class SoftwareBan(View): class SoftwareBan(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
bans = tuple(conn.execute('SELECT * FROM software_bans').all()) bans = tuple(conn.get_software_bans())
return Response.new(bans, ctype = 'json') return Response.new(bans, ctype = 'json')
@ -379,14 +494,15 @@ class SoftwareBan(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason']) data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']): if conn.get_software_ban(data['name']) is not None:
return Response.new_error(400, 'Domain already banned', 'json') raise HttpError(400, 'Domain already banned')
ban = conn.put_software_ban(**data) ban = conn.put_software_ban(
name = data['name'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -394,17 +510,18 @@ class SoftwareBan(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason']) data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response): if not any([data.get('note'), data.get('reason')]):
return data raise HttpError(400, 'Must include note and/or reason parameters')
with self.database.session() as conn: with self.database.session() as conn:
if not conn.get_software_ban(data['name']): if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json') raise HttpError(404, 'Software not banned')
if not any([data.get('note'), data.get('reason')]): ban = conn.update_software_ban(
return Response.new_error(400, 'Must include note and/or reason parameters', 'json') name = data['name'],
reason = data.get('reason'),
ban = conn.update_software_ban(**data) note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -412,12 +529,9 @@ class SoftwareBan(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['name'], []) data = await self.get_api_data(['name'], [])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if not conn.get_software_ban(data['name']): if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json') raise HttpError(404, 'Software not banned')
conn.del_software_ban(data['name']) conn.del_software_ban(data['name'])
@ -430,7 +544,7 @@ class User(View):
with self.database.session() as conn: with self.database.session() as conn:
items = [] items = []
for row in conn.execute('SELECT * FROM users'): for row in conn.get_users():
del row['hash'] del row['hash']
items.append(row) items.append(row)
@ -440,41 +554,40 @@ class User(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], ['handle']) data = await self.get_api_data(['username', 'password'], ['handle'])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_user(data['username']): if conn.get_user(data['username']) is not None:
return Response.new_error(404, 'User already exists', 'json') raise HttpError(404, 'User already exists')
user = conn.put_user(
username = data['username'],
password = data['password'],
handle = data.get('handle')
)
user = conn.put_user(**data)
del user['hash'] del user['hash']
return Response.new(user, ctype = 'json') return Response.new(user, ctype = 'json')
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', 'handle']) data = await self.get_api_data(['username'], ['password', 'handle'])
if isinstance(data, Response):
return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
user = conn.put_user(**data) user = conn.put_user(
del user['hash'] username = data['username'],
password = data['password'],
handle = data.get('handle')
)
del user['hash']
return Response.new(user, ctype = 'json') return Response.new(user, ctype = 'json')
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['username'], []) data = await self.get_api_data(['username'], [])
if isinstance(data, Response):
return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
if not conn.get_user(data['username']): if conn.get_user(data['username']) is None:
return Response.new_error(404, 'User does not exist', 'json') raise HttpError(404, 'User does not exist')
conn.del_user(data['username']) conn.del_user(data['username'])
@ -485,7 +598,7 @@ class User(View):
class Whitelist(View): class Whitelist(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
items = tuple(conn.execute('SELECT * FROM whitelist').all()) items = tuple(conn.get_domains_whitelist())
return Response.new(items, ctype = 'json') return Response.new(items, ctype = 'json')
@ -493,16 +606,13 @@ class Whitelist(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response): domain = data['domain'].encode('idna').decode()
return data
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(data['domain']): if conn.get_domain_whitelist(domain) is not None:
return Response.new_error(400, 'Domain already added to whitelist', 'json') raise HttpError(400, 'Domain already added to whitelist')
item = conn.put_domain_whitelist(**data) item = conn.put_domain_whitelist(domain)
return Response.new(item, ctype = 'json') return Response.new(item, ctype = 'json')
@ -510,15 +620,12 @@ class Whitelist(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response): domain = data['domain'].encode('idna').decode()
return data
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if not conn.get_domain_whitelist(data['domain']): if conn.get_domain_whitelist(domain) is None:
return Response.new_error(404, 'Domain not in whitelist', 'json') raise HttpError(404, 'Domain not in whitelist')
conn.del_domain_whitelist(data['domain']) conn.del_domain_whitelist(domain)
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')

View file

@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import HTTPMethodNotAllowed, Request from aiohttp.web import Request
from base64 import b64encode from blib import HttpError
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property from functools import cached_property
@ -18,18 +17,12 @@ from ..http_client import HttpClient
from ..misc import Response, get_app from ..misc import Response, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self
from ..application import Application from ..application import Application
from ..template import Template from ..template import Template
try:
from typing import Self
except ImportError:
from typing_extensions import Self
HandlerCallback = Callable[[Request], Awaitable[Response]] HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = [] VIEWS: list[tuple[str, type[View]]] = []
@ -49,10 +42,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
class View(AbstractView): class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]: def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS: if self.request.method not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) raise HttpError(405, f'"{self.request.method}" method not allowed')
if not (handler := self.handlers.get(self.request.method)): if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) raise HttpError(405, f'"{self.request.method}" method not allowed')
return self._run_handler(handler).__await__() return self._run_handler(handler).__await__()
@ -64,7 +57,6 @@ class View(AbstractView):
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs) return await handler(self.request, **self.request.match_info, **kwargs)
@ -123,17 +115,18 @@ class View(AbstractView):
async def get_api_data(self, async def get_api_data(self,
required: list[str], required: list[str],
optional: list[str]) -> dict[str, str] | Response: optional: list[str]) -> dict[str, str]:
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post()) post_data = convert_data(await self.request.post())
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json': elif self.request.content_type == 'application/json':
try: try:
post_data = convert_data(await self.request.json()) post_data = convert_data(await self.request.json())
except JSONDecodeError: except JSONDecodeError:
return Response.new_error(400, 'Invalid JSON data', 'json') raise HttpError(400, 'Invalid JSON data')
else: else:
post_data = convert_data(self.request.query) post_data = convert_data(self.request.query)
@ -145,9 +138,9 @@ class View(AbstractView):
data[key] = post_data[key] data[key] = post_data[key]
except KeyError as e: except KeyError as e:
return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') raise HttpError(400, f'Missing {str(e)} pararmeter') from None
for key in optional: for key in optional:
data[key] = post_data.get(key, '') data[key] = post_data.get(key) # type: ignore[assignment]
return data return data

View file

@ -1,18 +1,13 @@
from aiohttp import web from aiohttp import web
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
from urllib.parse import unquote
from .base import View, register_route from .base import View, register_route
from ..database import THEMES from ..database import THEMES
from ..logger import LogLevel from ..logger import LogLevel
from ..misc import Response, get_app from ..misc import TOKEN_PATHS, Response
UNAUTH_ROUTES = {
'/',
'/login'
}
@web.middleware @web.middleware
@ -20,27 +15,24 @@ async def handle_frontend_path(
request: web.Request, request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
app = get_app() if request['user'] is not None and request.path == '/login':
return Response.new_redir('/')
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None:
request['token'] = request.cookies.get('user-token') if request.path == '/logout':
request['user'] = None return Response.new_redir('/')
if request['token']: response = Response.new_redir(f'/login?redir={request.path}')
with app.database.session(False) as conn:
request['user'] = conn.get_user_by_token(request['token'])
if request['user'] and request.path == '/login': if request['token'] is not None:
return Response.new('', 302, {'Location': '/'})
if not request['user'] and request.path.startswith('/admin'):
response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
response.del_cookie('user-token') response.del_cookie('user-token')
return response return response
response = await handler(request) response = await handler(request)
if not request.path.startswith('/api') and not request['user'] and request['token']: if not request.path.startswith('/api'):
if request['user'] is None and request['token'] is not None:
response.del_cookie('user-token') response.del_cookie('user-token')
return response return response
@ -54,14 +46,15 @@ class HomeView(View):
'instances': tuple(conn.get_inboxes()) 'instances': tuple(conn.get_inboxes())
} }
data = self.template.render('page/home.haml', self, **context) data = self.template.render('page/home.haml', self.request, **context)
return Response.new(data, ctype='html') return Response.new(data, ctype='html')
@register_route('/login') @register_route('/login')
class Login(View): class Login(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
data = self.template.render('page/login.haml', self) redir = unquote(request.query.get('redir', '/'))
data = self.template.render('page/login.haml', self.request, redir = redir)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -69,7 +62,7 @@ class Login(View):
class Logout(View): class Logout(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn: with self.database.session(True) as conn:
conn.del_token(request['token']) conn.del_app(request['token'].client_id, request['token'].client_secret)
resp = Response.new_redir('/') resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = self.config.domain, path = '/') resp.del_cookie('user-token', domain = self.config.domain, path = '/')
@ -79,7 +72,7 @@ class Logout(View):
@register_route('/admin') @register_route('/admin')
class Admin(View): class Admin(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
return Response.new('', 302, {'Location': '/admin/instances'}) return Response.new_redir(f'/login?redir={request.path}', 301)
@register_route('/admin/instances') @register_route('/admin/instances')
@ -101,7 +94,7 @@ class AdminInstances(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-instances.haml', self, **context) data = self.template.render('page/admin-instances.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -123,7 +116,7 @@ class AdminWhitelist(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-whitelist.haml', self, **context) data = self.template.render('page/admin-whitelist.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -145,7 +138,7 @@ class AdminDomainBans(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-domain_bans.haml', self, **context) data = self.template.render('page/admin-domain_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -167,7 +160,7 @@ class AdminSoftwareBans(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-software_bans.haml', self, **context) data = self.template.render('page/admin-software_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -189,7 +182,7 @@ class AdminUsers(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-users.haml', self, **context) data = self.template.render('page/admin-users.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -199,10 +192,21 @@ class AdminConfig(View):
context: dict[str, Any] = { context: dict[str, Any] = {
'themes': tuple(THEMES.keys()), 'themes': tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel), 'levels': tuple(level.name for level in LogLevel),
'message': message 'message': message,
'desc': {
"name": "Name of the relay to be displayed in the header of the pages and in " +
"the actor endpoint.", # noqa: E131
"note": "Description of the relay to be displayed on the front page and as the " +
"bio in the actor endpoint.",
"theme": "Color theme to use on the web pages.",
"log_level": "Minimum level of logging messages to print to the console.",
"whitelist_enabled": "Only allow instances in the whitelist to be able to follow.",
"approval_required": "Require instances not on the whitelist to be approved by " +
"and admin. The `whitelist-enabled` setting is ignored when this is enabled."
}
} }
data = self.template.render('page/admin-config.haml', self, **context) data = self.template.render('page/admin-config.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -240,5 +244,5 @@ class ThemeCss(View):
except KeyError: except KeyError:
return Response.new('Invalid theme', 404) return Response.new('Invalid theme', 404)
data = self.template.render('variables.css', self, **context) data = self.template.render('variables.css', self.request, **context)
return Response.new(data, ctype = 'css') return Response.new(data, ctype = 'css')

143
relay/workers.py Normal file
View file

@ -0,0 +1,143 @@
from __future__ import annotations
import asyncio
import traceback
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from blib import HttpError
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value
from multiprocessing.queues import Queue as QueueType
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.synchronize import Event as EventType
from pathlib import Path
from queue import Empty
from urllib.parse import urlparse
from . import application, logger as logging
from .database.schema import Instance
from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, get_app
@dataclass
class PostItem:
inbox: str
message: Message
instance: Instance | None
@property
def domain(self) -> str:
return urlparse(self.inbox).netloc
class PushWorker(Process):
client: HttpClient
def __init__(self, queue: QueueType[PostItem], log_level: Synchronized[int]) -> None:
Process.__init__(self)
self.queue: QueueType[PostItem] = queue
self.shutdown: EventType = Event()
self.path: Path = get_app().config.path
self.log_level: Synchronized[int] = log_level
self._log_level_changed: EventType = Event()
def stop(self) -> None:
self.shutdown.set()
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
if IS_WINDOWS:
app = application.Application(self.path)
self.client = app.client
self.client.open()
app.database.connect()
app.cache.setup()
else:
self.client = HttpClient()
self.client.open()
logging.verbose("[%i] Starting worker", self.pid)
while not self.shutdown.is_set():
try:
if self._log_level_changed.is_set():
logging.set_level(logging.LogLevel.parse(self.log_level.value))
self._log_level_changed.clear()
item = self.queue.get(block=True, timeout=0.1)
asyncio.create_task(self.handle_post(item))
except Empty:
await asyncio.sleep(0)
except Exception:
traceback.print_exc()
if IS_WINDOWS:
app.database.disconnect()
app.cache.close()
await self.client.close()
async def handle_post(self, item: PostItem) -> None:
try:
await self.client.post(item.inbox, item.message, item.instance)
except HttpError as e:
logging.error('HTTP Error when pushing to %s: %i %s', item.inbox, e.status, e.message)
except AsyncTimeoutError:
logging.error('Timeout when pushing to %s', item.domain)
except ClientConnectionError as e:
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
except ClientSSLError as e:
logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
class PushWorkers(list[PushWorker]):
def __init__(self, count: int) -> None:
self.queue: QueueType[PostItem] = Queue()
self._log_level: Synchronized[int] = Value("i", logging.get_level())
self._count: int = count
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self.queue.put(PostItem(inbox, message, instance))
def set_log_level(self, value: logging.LogLevel) -> None:
self._log_level.value = value
for worker in self:
worker._log_level_changed.set()
def start(self) -> None:
if len(self) > 0:
return
for _ in range(self._count):
worker = PushWorker(self.queue, self._log_level)
worker.start()
self.append(worker)
def stop(self) -> None:
for worker in self:
worker.stop()
self.clear()