Compare commits

..

No commits in common. "5217516c8a5c6b00711a4784eea4caf26801de4d" and "dec7c6a674794866f87bc70724b63f7535d745b7" have entirely different histories.

43 changed files with 1395 additions and 3017 deletions

64
dev.py
View file

@ -5,17 +5,16 @@ import shutil
import subprocess import subprocess
import sys import sys
import time import time
import tomllib
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from relay import __version__, logger as logging from relay import __version__, logger as logging
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Sequence from typing import Sequence
try: try:
from watchdog.observers import Observer from watchdog.observers import Observer
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler from watchdog.events import PatternMatchingEventHandler
except ImportError: except ImportError:
class PatternMatchingEventHandler: # type: ignore class PatternMatchingEventHandler: # type: ignore
@ -30,38 +29,39 @@ IGNORE_EXT = {
@click.group('cli') @click.group('cli')
def cli() -> None: def cli():
'Useful commands for development' 'Useful commands for development'
@cli.command('install') @cli.command('install')
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies') def cli_install():
def cli_install(no_dev: bool) -> None: cmd = [
with open('pyproject.toml', 'rb') as fd: sys.executable, '-m', 'pip', 'install',
data = tomllib.load(fd) '-r', 'requirements.txt',
'-r', 'dev-requirements.txt'
]
deps = data['project']['dependencies'] subprocess.run(cmd, check = False)
if not no_dev:
deps.extend(data['project']['optional-dependencies']['dev'])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
@cli.command('lint') @cli.command('lint')
@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay')) @click.argument('path', required = False, type = Path, default = REPO.joinpath('relay'))
@click.option('--strict', '-s', is_flag = True, help = 'Enable strict mode for mypy')
@click.option('--watch', '-w', is_flag = True, @click.option('--watch', '-w', is_flag = True,
help = 'Automatically, re-run the linters on source change') help = 'Automatically, re-run the linters on source change')
def cli_lint(path: Path, watch: bool) -> None: def cli_lint(path: Path, strict: bool, watch: bool) -> None:
path = path.expanduser().resolve() path = path.expanduser().resolve()
if watch: if watch:
handle_run_watcher([sys.executable, "dev.py", "lint", str(path)], wait = True) handle_run_watcher([sys.executable, "-m", "relay.dev", "lint", str(path)], wait = True)
return return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)] mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)]
if strict:
mypy.append('--strict')
click.echo('----- flake8 -----') click.echo('----- flake8 -----')
subprocess.run(flake8) subprocess.run(flake8)
@ -70,7 +70,7 @@ def cli_lint(path: Path, watch: bool) -> None:
@cli.command('clean') @cli.command('clean')
def cli_clean() -> None: def cli_clean():
dirs = { dirs = {
'dist', 'dist',
'build', 'build',
@ -88,7 +88,7 @@ def cli_clean() -> None:
@cli.command('build') @cli.command('build')
def cli_build() -> None: def cli_build():
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [ cmd = [
@ -118,7 +118,7 @@ def cli_build() -> None:
@cli.command('run') @cli.command('run')
@click.option('--dev', '-d', is_flag = True) @click.option('--dev', '-d', is_flag = True)
def cli_run(dev: bool) -> None: def cli_run(dev: bool):
print('Starting process watcher') print('Starting process watcher')
cmd = [sys.executable, '-m', 'relay', 'run'] cmd = [sys.executable, '-m', 'relay', 'run']
@ -126,20 +126,16 @@ def cli_run(dev: bool) -> None:
if dev: if dev:
cmd.append('-d') cmd.append('-d')
handle_run_watcher(cmd, watch_path = REPO.joinpath("relay")) handle_run_watcher(cmd)
def handle_run_watcher( def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
*commands: Sequence[str],
watch_path: Path | str = REPO,
wait: bool = False) -> None:
handler = WatchHandler(*commands, wait = wait) handler = WatchHandler(*commands, wait = wait)
handler.run_procs() handler.run_procs()
watcher = Observer() watcher = Observer()
watcher.schedule(handler, str(watch_path), recursive=True) # type: ignore watcher.schedule(handler, str(REPO), recursive=True)
watcher.start() # type: ignore watcher.start()
try: try:
while True: while True:
@ -149,7 +145,7 @@ def handle_run_watcher(
pass pass
handler.kill_procs() handler.kill_procs()
watcher.stop() # type: ignore watcher.stop()
watcher.join() watcher.join()
@ -157,16 +153,16 @@ class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py'] patterns = ['*.py']
def __init__(self, *commands: Sequence[str], wait: bool = False) -> None: def __init__(self, *commands: Sequence[str], wait: bool = False):
PatternMatchingEventHandler.__init__(self) # type: ignore PatternMatchingEventHandler.__init__(self)
self.commands: Sequence[Sequence[str]] = commands self.commands: Sequence[Sequence[str]] = commands
self.wait: bool = wait self.wait: bool = wait
self.procs: list[subprocess.Popen[Any]] = [] self.procs: list[subprocess.Popen] = []
self.last_restart: datetime = datetime.now() self.last_restart: datetime = datetime.now()
def kill_procs(self) -> None: def kill_procs(self):
for proc in self.procs: for proc in self.procs:
if proc.poll() is not None: if proc.poll() is not None:
continue continue
@ -187,7 +183,7 @@ class WatchHandler(PatternMatchingEventHandler):
logging.info('Process terminated') logging.info('Process terminated')
def run_procs(self, restart: bool = False) -> None: def run_procs(self, restart: bool = False):
if restart: if restart:
if datetime.now() - timedelta(seconds = 3) < self.last_restart: if datetime.now() - timedelta(seconds = 3) < self.last_restart:
return return
@ -209,7 +205,7 @@ class WatchHandler(PatternMatchingEventHandler):
logging.info('Started processes with PIDs: %s', ', '.join(pids)) logging.info('Started processes with PIDs: %s', ', '.join(pids))
def on_any_event(self, event: FileSystemEvent) -> None: def on_any_event(self, event):
if event.event_type not in ['modified', 'created', 'deleted']: if event.event_type not in ['modified', 'created', 'deleted']:
return return

View file

@ -16,21 +16,19 @@ classifiers = [
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
] ]
dependencies = [ dependencies = [
"activitypub-utils >= 0.3.1, < 0.4.0", "activitypub-utils == 0.2.1",
"aiohttp >= 3.9.5", "aiohttp >= 3.9.1",
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.1.4, < 0.2.0", "barkshark-sql == 0.1.2",
"barkshark-sql >= 0.2.0-rc1, < 0.3.0", "click >= 8.1.2",
"click == 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",
"idna == 3.4",
"jinja2-haml == 0.3.5", "jinja2-haml == 0.3.5",
"markdown == 3.6", "markdown == 3.5.2",
"platformdirs == 4.2.2", "platformdirs == 4.2.0",
"pyyaml == 6.0", "pyyaml >= 6.0",
"redis == 5.0.5", "redis == 5.0.1",
"importlib-resources == 6.4.0; python_version < '3.9'" "importlib_resources == 6.1.1; python_version < '3.9'"
] ]
requires-python = ">=3.8" requires-python = ">=3.8"
dynamic = ["version"] dynamic = ["version"]
@ -50,10 +48,10 @@ activityrelay = "relay.manage:main"
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"flake8 == 7.0.0", "flake8 == 7.0.0",
"mypy == 1.10.0", "mypy == 1.9.0",
"pyinstaller == 6.8.0", "pyinstaller == 6.3.0",
"watchdog == 4.0.1", "watchdog == 4.0.0",
"typing-extensions == 4.12.2; python_version < '3.11.0'" "typing_extensions >= 4.10.0; python_version < '3.11.0'"
] ]
[tool.setuptools] [tool.setuptools]
@ -89,18 +87,4 @@ warn_redundant_casts = true
warn_unreachable = true warn_unreachable = true
warn_unused_ignores = true warn_unused_ignores = true
ignore_missing_imports = true ignore_missing_imports = true
implicit_reexport = true
strict = true
follow_imports = "silent" follow_imports = "silent"
[[tool.mypy.overrides]]
module = "relay.database"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "aputils"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "blib"
implicit_reexport = true

View file

@ -1 +1 @@
__version__ = '0.3.3' __version__ = '0.3.1'

View file

@ -4,35 +4,40 @@ import asyncio
import multiprocessing import multiprocessing
import signal import signal
import time import time
import traceback
import typing
from aiohttp import web from aiohttp import web
from aiohttp.web import StaticResource from aiohttp.web import StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from bsql import Database
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mimetypes import guess_type from mimetypes import guess_type
from pathlib import Path from pathlib import Path
from queue import Empty
from threading import Event, Thread from threading import Event, Thread
from typing import Any
from . import logger as logging, workers from . import logger as logging
from .cache import Cache, get_cache from .cache import get_cache
from .config import Config from .config import Config
from .database import Connection, get_database from .database import get_database
from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import Message, Response, check_open_port, get_resource from .misc import check_open_port, get_resource
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
from .views.frontend import handle_frontend_path from .views.frontend import handle_frontend_path
if typing.TYPE_CHECKING:
from collections.abc import Callable
from bsql import Database, Row
from .cache import Cache
from .misc import Message, Response
def get_csp(request: web.Request) -> str: def get_csp(request: web.Request) -> str:
data = [ data = [
"default-src 'self'", "default-src 'none'",
f"script-src 'nonce-{request['hash']}'", f"script-src 'nonce-{request['hash']}'",
f"style-src 'self' 'nonce-{request['hash']}'", f"style-src 'self' 'nonce-{request['hash']}'",
"form-action 'self'", "form-action 'self'",
@ -53,9 +58,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False): def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self, web.Application.__init__(self,
middlewares = [ middlewares = [
handle_api_path, # type: ignore[list-item] handle_api_path,
handle_frontend_path, # type: ignore[list-item] handle_frontend_path,
handle_response_headers # type: ignore[list-item] handle_response_headers
] ]
) )
@ -74,7 +79,7 @@ class Application(web.Application):
self['cache'].setup() self['cache'].setup()
self['template'] = Template(self) self['template'] = Template(self)
self['push_queue'] = multiprocessing.Queue() self['push_queue'] = multiprocessing.Queue()
self['workers'] = workers.PushWorkers(self.config.workers) self['workers'] = []
self.cache.setup() self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore self.on_cleanup.append(handle_cleanup) # type: ignore
@ -91,27 +96,27 @@ class Application(web.Application):
@property @property
def cache(self) -> Cache: def cache(self) -> Cache:
return self['cache'] # type: ignore[no-any-return] return self['cache']
@property @property
def client(self) -> HttpClient: def client(self) -> HttpClient:
return self['client'] # type: ignore[no-any-return] return self['client']
@property @property
def config(self) -> Config: def config(self) -> Config:
return self['config'] # type: ignore[no-any-return] return self['config']
@property @property
def database(self) -> Database[Connection]: def database(self) -> Database:
return self['database'] # type: ignore[no-any-return] return self['database']
@property @property
def signer(self) -> Signer: def signer(self) -> Signer:
return self['signer'] # type: ignore[no-any-return] return self['signer']
@signer.setter @signer.setter
@ -125,7 +130,7 @@ class Application(web.Application):
@property @property
def template(self) -> Template: def template(self) -> Template:
return self['template'] # type: ignore[no-any-return] return self['template']
@property @property
@ -138,8 +143,8 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message, instance: Instance) -> None: def push_message(self, inbox: str, message: Message | bytes, instance: Row) -> None:
self['workers'].push_message(inbox, message, instance) self['push_queue'].put((inbox, message, instance))
def register_static_routes(self) -> None: def register_static_routes(self) -> None:
@ -180,11 +185,11 @@ class Application(web.Application):
pass pass
def stop(self, *_: Any) -> None: def stop(self, *_):
self['running'] = False self['running'] = False
async def handle_run(self) -> None: async def handle_run(self):
self['running'] = True self['running'] = True
self.set_signal_handler(True) self.set_signal_handler(True)
@ -194,7 +199,12 @@ class Application(web.Application):
self['cache'].setup() self['cache'].setup()
self['cleanup_thread'] = CacheCleanupThread(self) self['cleanup_thread'] = CacheCleanupThread(self)
self['cleanup_thread'].start() self['cleanup_thread'].start()
self['workers'].start()
for _ in range(self.config.workers):
worker = PushWorker(self['push_queue'])
worker.start()
self['workers'].append(worker)
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup() await runner.setup()
@ -214,13 +224,15 @@ class Application(web.Application):
await site.stop() await site.stop()
self['workers'].stop() for worker in self['workers']:
worker.stop()
self.set_signal_handler(False) self.set_signal_handler(False)
self['starttime'] = None self['starttime'] = None
self['running'] = False self['running'] = False
self['cleanup_thread'].stop() self['cleanup_thread'].stop()
self['workers'].clear()
self['database'].disconnect() self['database'].disconnect()
self['cache'].close() self['cache'].close()
@ -282,11 +294,42 @@ class CacheCleanupThread(Thread):
self.running.clear() self.running.clear()
@web.middleware class PushWorker(multiprocessing.Process):
async def handle_response_headers( def __init__(self, queue: multiprocessing.Queue):
request: web.Request, multiprocessing.Process.__init__(self)
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: self.queue = queue
self.shutdown = multiprocessing.Event()
def stop(self) -> None:
self.shutdown.set()
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
client = HttpClient()
client.open()
while not self.shutdown.is_set():
try:
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
asyncio.create_task(client.post(inbox, message, instance))
except Empty:
await asyncio.sleep(0)
# make sure an exception doesn't bring down the worker
except Exception:
traceback.print_exc()
await client.close()
@web.middleware
async def handle_response_headers(request: web.Request, handler: Callable) -> Response:
resp = await handler(request) resp = await handler(request)
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'

View file

@ -2,27 +2,28 @@ from __future__ import annotations
import json import json
import os import os
import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from bsql import Database, Row
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from redis import Redis from redis import Redis
from typing import TYPE_CHECKING, Any
from .database import Connection, get_database from .database import get_database
from .misc import Message, boolean from .misc import Message, boolean
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from blib import Database
from collections.abc import Callable, Iterator
from typing import Any
from .application import Application from .application import Application
SerializerCallback = Callable[[Any], str] # todo: implement more caching backends
DeserializerCallback = Callable[[str], Any]
BACKENDS: dict[str, type[Cache]] = {} BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = { CONVERTERS: dict[str, tuple[Callable, Callable]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
'bool': (str, boolean), 'bool': (str, boolean),
@ -60,13 +61,13 @@ class Item:
updated: datetime updated: datetime
def __post_init__(self) -> None: def __post_init__(self):
if isinstance(self.updated, str): # type: ignore[unreachable] if isinstance(self.updated, str):
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable] self.updated = datetime.fromisoformat(self.updated)
@classmethod @classmethod
def from_data(cls: type[Item], *args: Any) -> Item: def from_data(cls: type[Item], *args) -> Item:
data = cls(*args) data = cls(*args)
data.value = deserialize_value(data.value, data.value_type) data.value = deserialize_value(data.value, data.value_type)
@ -158,13 +159,10 @@ class SqlCache(Cache):
def __init__(self, app: Application): def __init__(self, app: Application):
Cache.__init__(self, app) Cache.__init__(self, app)
self._db: Database[Connection] | None = None self._db: Database = None
def get(self, namespace: str, key: str) -> Item: def get(self, namespace: str, key: str) -> Item:
if self._db is None:
raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, 'namespace': namespace,
'key': key 'key': key
@ -172,7 +170,7 @@ class SqlCache(Cache):
with self._db.session(False) as conn: with self._db.session(False) as conn:
with conn.run('get-cache-item', params) as cur: with conn.run('get-cache-item', params) as cur:
if not (row := cur.one(Row)): if not (row := cur.one()):
raise KeyError(f'{namespace}:{key}') raise KeyError(f'{namespace}:{key}')
row.pop('id', None) row.pop('id', None)
@ -180,27 +178,18 @@ class SqlCache(Cache):
def get_keys(self, namespace: str) -> Iterator[str]: def get_keys(self, namespace: str) -> Iterator[str]:
if self._db is None:
raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn: with self._db.session(False) as conn:
for row in conn.run('get-cache-keys', {'namespace': namespace}): for row in conn.run('get-cache-keys', {'namespace': namespace}):
yield row['key'] yield row['key']
def get_namespaces(self) -> Iterator[str]: def get_namespaces(self) -> Iterator[str]:
if self._db is None:
raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn: with self._db.session(False) as conn:
for row in conn.run('get-cache-namespaces', None): for row in conn.run('get-cache-namespaces', None):
yield row['namespace'] yield row['namespace']
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
if self._db is None:
raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, 'namespace': namespace,
'key': key, 'key': key,
@ -210,18 +199,13 @@ class SqlCache(Cache):
} }
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as cur: with conn.run('set-cache-item', params) as conn:
if (row := cur.one(Row)) is None: row = conn.one()
raise RuntimeError("Cache item not set")
row.pop('id', None) row.pop('id', None)
return Item.from_data(*tuple(row.values())) return Item.from_data(*tuple(row.values()))
def delete(self, namespace: str, key: str) -> None: def delete(self, namespace: str, key: str) -> None:
if self._db is None:
raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, 'namespace': namespace,
'key': key 'key': key
@ -233,9 +217,6 @@ class SqlCache(Cache):
def delete_old(self, days: int = 14) -> None: def delete_old(self, days: int = 14) -> None:
if self._db is None:
raise RuntimeError("Database has not been setup")
limit = datetime.now(tz = timezone.utc) - timedelta(days = days) limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
params = {"limit": limit.timestamp()} params = {"limit": limit.timestamp()}
@ -245,9 +226,6 @@ class SqlCache(Cache):
def clear(self) -> None: def clear(self) -> None:
if self._db is None:
raise RuntimeError("Database has not been setup")
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache"): with conn.execute("DELETE FROM cache"):
pass pass
@ -382,5 +360,5 @@ class RedisCache(Cache):
if not self._rd: if not self._rd:
return return
self._rd.close() # type: ignore self._rd.close()
self._rd = None # type: ignore self._rd = None # type: ignore

View file

@ -1,16 +1,21 @@
from __future__ import annotations
import json import json
import os import os
import typing
import yaml import yaml
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .misc import boolean from .misc import boolean
if typing.TYPE_CHECKING:
from typing import Any
class RelayConfig(dict[str, Any]):
class RelayConfig(dict):
def __init__(self, path: str): def __init__(self, path: str):
dict.__init__(self, {}) dict.__init__(self, {})
@ -117,7 +122,7 @@ class RelayConfig(dict[str, Any]):
self[key] = value self[key] = value
class RelayDatabase(dict[str, Any]): class RelayDatabase(dict):
def __init__(self, config: RelayConfig): def __init__(self, config: RelayConfig):
dict.__init__(self, { dict.__init__(self, {
'relay-list': {}, 'relay-list': {},

View file

@ -3,16 +3,18 @@ from __future__ import annotations
import getpass import getpass
import os import os
import platform import platform
import typing
import yaml import yaml
from dataclasses import asdict, dataclass, fields from dataclasses import asdict, dataclass, fields
from pathlib import Path from pathlib import Path
from platformdirs import user_config_dir from platformdirs import user_config_dir
from typing import TYPE_CHECKING, Any
from .misc import IS_DOCKER from .misc import IS_DOCKER
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any
try: try:
from typing import Self from typing import Self
@ -64,7 +66,7 @@ class Config:
def __init__(self, path: Path | None = None, load: bool = False): def __init__(self, path: Path | None = None, load: bool = False):
self.path: Path = Config.get_config_dir(path) self.path = Config.get_config_dir(path)
self.reset() self.reset()
if load: if load:

View file

@ -1,28 +1,31 @@
from bsql import Database from __future__ import annotations
import bsql
import typing
from .config import THEMES, ConfigData from .config import THEMES, ConfigData
from .connection import RELAY_SOFTWARE, Connection from .connection import RELAY_SOFTWARE, Connection
from .schema import TABLES, VERSIONS, migrate_0 from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
from ..config import Config
from ..misc import get_resource from ..misc import get_resource
if typing.TYPE_CHECKING:
from ..config import Config
def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
def get_database(config: Config, migrate: bool = True) -> bsql.Database:
options = { options = {
'connection_class': Connection, 'connection_class': Connection,
'pool_size': 5, 'pool_size': 5,
'tables': TABLES 'tables': TABLES
} }
db: Database[Connection]
if config.db_type == 'sqlite': if config.db_type == 'sqlite':
db = Database.sqlite(config.sqlite_path, **options) db = bsql.Database.sqlite(config.sqlite_path, **options)
elif config.db_type == 'postgres': elif config.db_type == 'postgres':
db = Database.postgresql( db = bsql.Database.postgresql(
config.pg_name, config.pg_name,
config.pg_host, config.pg_host,
config.pg_port, config.pg_port,

View file

@ -1,16 +1,17 @@
from __future__ import annotations from __future__ import annotations
# removing the above line turns annotations into types instead of str objects which messes with
# `Field.type`
from bsql import Row import typing
from collections.abc import Callable, Sequence
from dataclasses import Field, asdict, dataclass, fields from dataclasses import Field, asdict, dataclass, fields
from typing import TYPE_CHECKING, Any
from .. import logger as logging from .. import logger as logging
from ..misc import boolean from ..misc import boolean
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from bsql import Row
from collections.abc import Callable, Sequence
from typing import Any
try: try:
from typing import Self from typing import Self
@ -119,7 +120,7 @@ class ConfigData:
@classmethod @classmethod
def FIELD(cls: type[Self], key: str) -> Field[Any]: def FIELD(cls: type[Self], key: str) -> Field:
for field in fields(cls): for field in fields(cls):
if field.name == key.replace('-', '_'): if field.name == key.replace('-', '_'):
return field return field

View file

@ -1,24 +1,28 @@
from __future__ import annotations from __future__ import annotations
import typing
from argon2 import PasswordHasher from argon2 import PasswordHasher
from bsql import Connection as SqlConnection, Row, Update from bsql import Connection as SqlConnection, Update
from collections.abc import Iterator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from . import schema
from .config import ( from .config import (
THEMES, THEMES,
ConfigData ConfigData
) )
from .. import logger as logging from .. import logger as logging
from ..misc import Message, boolean, get_app from ..misc import boolean, get_app
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from bsql import Row
from typing import Any
from ..application import Application from ..application import Application
from ..misc import Message
RELAY_SOFTWARE = [ RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay 'activityrelay', # https://git.pleroma.social/pleroma/relay
@ -38,14 +42,14 @@ class Connection(SqlConnection):
return get_app() return get_app()
def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]: def distill_inboxes(self, message: Message) -> Iterator[Row]:
src_domains = { src_domains = {
message.domain, message.domain,
urlparse(message.object_id).netloc urlparse(message.object_id).netloc
} }
for instance in self.get_inboxes(): for instance in self.get_inboxes():
if instance.domain not in src_domains: if instance['domain'] not in src_domains:
yield instance yield instance
@ -53,7 +57,7 @@ class Connection(SqlConnection):
key = key.replace('_', '-') key = key.replace('_', '-')
with self.run('get-config', {'key': key}) as cur: with self.run('get-config', {'key': key}) as cur:
if (row := cur.one(Row)) is None: if not (row := cur.one()):
return ConfigData.DEFAULT(key) return ConfigData.DEFAULT(key)
data = ConfigData() data = ConfigData()
@ -62,8 +66,8 @@ class Connection(SqlConnection):
def get_config_all(self) -> ConfigData: def get_config_all(self) -> ConfigData:
rows = tuple(self.run('get-config-all', None).all(schema.Row)) with self.run('get-config-all', None) as cur:
return ConfigData.from_rows(rows) return ConfigData.from_rows(tuple(cur.all()))
def put_config(self, key: str, value: Any) -> Any: def put_config(self, key: str, value: Any) -> Any:
@ -76,7 +80,6 @@ class Connection(SqlConnection):
elif key == 'log-level': elif key == 'log-level':
value = logging.LogLevel.parse(value) value = logging.LogLevel.parse(value)
logging.set_level(value) logging.set_level(value)
self.app['workers'].set_log_level(value)
elif key in {'approval-required', 'whitelist-enabled'}: elif key in {'approval-required', 'whitelist-enabled'}:
value = boolean(value) value = boolean(value)
@ -91,7 +94,7 @@ class Connection(SqlConnection):
params = { params = {
'key': key, 'key': key,
'value': data.get(key, serialize = True), 'value': data.get(key, serialize = True),
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore 'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type
} }
with self.run('put-config', params): with self.run('put-config', params):
@ -100,13 +103,14 @@ class Connection(SqlConnection):
return data.get(key) return data.get(key)
def get_inbox(self, value: str) -> schema.Instance | None: def get_inbox(self, value: str) -> Row:
with self.run('get-inbox', {'value': value}) as cur: with self.run('get-inbox', {'value': value}) as cur:
return cur.one(schema.Instance) return cur.one() # type: ignore
def get_inboxes(self) -> Iterator[schema.Instance]: def get_inboxes(self) -> Sequence[Row]:
return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance) with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur:
return tuple(cur.all())
def put_inbox(self, def put_inbox(self,
@ -115,7 +119,7 @@ class Connection(SqlConnection):
actor: str | None = None, actor: str | None = None,
followid: str | None = None, followid: str | None = None,
software: str | None = None, software: str | None = None,
accepted: bool = True) -> schema.Instance: accepted: bool = True) -> Row:
params: dict[str, Any] = { params: dict[str, Any] = {
'inbox': inbox, 'inbox': inbox,
@ -125,7 +129,7 @@ class Connection(SqlConnection):
'accepted': accepted 'accepted': accepted
} }
if self.get_inbox(domain) is None: if not self.get_inbox(domain):
if not inbox: if not inbox:
raise ValueError("Missing inbox") raise ValueError("Missing inbox")
@ -133,20 +137,14 @@ class Connection(SqlConnection):
params['created'] = datetime.now(tz = timezone.utc) params['created'] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur: with self.run('put-inbox', params) as cur:
if (row := cur.one(schema.Instance)) is None: return cur.one() # type: ignore
raise RuntimeError(f"Failed to insert instance: {domain}")
return row
for key, value in tuple(params.items()): for key, value in tuple(params.items()):
if value is None: if value is None:
del params[key] del params[key]
with self.update('inboxes', params, domain = domain) as cur: with self.update('inboxes', params, domain = domain) as cur:
if (row := cur.one(schema.Instance)) is None: return cur.one() # type: ignore
raise RuntimeError(f"Failed to update instance: {domain}")
return row
def del_inbox(self, value: str) -> bool: def del_inbox(self, value: str) -> bool:
@ -157,23 +155,24 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_request(self, domain: str) -> schema.Instance | None: def get_request(self, domain: str) -> Row:
with self.run('get-request', {'domain': domain}) as cur: with self.run('get-request', {'domain': domain}) as cur:
return cur.one(schema.Instance) if not (row := cur.one()):
raise KeyError(domain)
return row
def get_requests(self) -> Iterator[schema.Instance]: def get_requests(self) -> Sequence[Row]:
return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance) with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
return tuple(cur.all())
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: def put_request_response(self, domain: str, accepted: bool) -> Row:
if (instance := self.get_request(domain)) is None: instance = self.get_request(domain)
raise KeyError(domain)
if not accepted: if not accepted:
if not self.del_inbox(domain): self.del_inbox(domain)
raise RuntimeError(f'Failed to delete request: {domain}')
return instance return instance
params = { params = {
@ -182,28 +181,21 @@ class Connection(SqlConnection):
} }
with self.run('put-inbox-accept', params) as cur: with self.run('put-inbox-accept', params) as cur:
if (row := cur.one(schema.Instance)) is None: return cur.one() # type: ignore
raise RuntimeError(f"Failed to insert response for domain: {domain}")
return row
def get_user(self, value: str) -> schema.User | None: def get_user(self, value: str) -> Row:
with self.run('get-user', {'value': value}) as cur: with self.run('get-user', {'value': value}) as cur:
return cur.one(schema.User) return cur.one() # type: ignore
def get_user_by_token(self, code: str) -> schema.User | None: def get_user_by_token(self, code: str) -> Row:
with self.run('get-user-by-token', {'code': code}) as cur: with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one(schema.User) return cur.one() # type: ignore
def get_users(self) -> Iterator[schema.User]: def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row:
return self.execute("SELECT * FROM users").all(schema.User) if self.get_user(username):
def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User:
if self.get_user(username) is not None:
data: dict[str, str | datetime | None] = {} data: dict[str, str | datetime | None] = {}
if password: if password:
@ -216,10 +208,7 @@ class Connection(SqlConnection):
stmt.set_where("username", username) stmt.set_where("username", username)
with self.query(stmt) as cur: with self.query(stmt) as cur:
if (row := cur.one(schema.User)) is None: return cur.one() # type: ignore
raise RuntimeError(f"Failed to update user: {username}")
return row
if password is None: if password is None:
raise ValueError('Password cannot be empty') raise ValueError('Password cannot be empty')
@ -232,36 +221,25 @@ class Connection(SqlConnection):
} }
with self.run('put-user', data) as cur: with self.run('put-user', data) as cur:
if (row := cur.one(schema.User)) is None: return cur.one() # type: ignore
raise RuntimeError(f"Failed to insert user: {username}")
return row
def del_user(self, username: str) -> None: def del_user(self, username: str) -> None:
if (user := self.get_user(username)) is None: user = self.get_user(username)
raise KeyError(username)
with self.run('del-user', {'value': user.username}): with self.run('del-user', {'value': user['username']}):
pass pass
with self.run('del-token-user', {'username': user.username}): with self.run('del-token-user', {'username': user['username']}):
pass pass
def get_token(self, code: str) -> schema.Token | None: def get_token(self, code: str) -> Row:
with self.run('get-token', {'code': code}) as cur: with self.run('get-token', {'code': code}) as cur:
return cur.one(schema.Token) return cur.one() # type: ignore
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]: def put_token(self, username: str) -> Row:
if username is not None:
return self.select('tokens').all(schema.Token)
return self.select('tokens', username = username).all(schema.Token)
def put_token(self, username: str) -> schema.Token:
data = { data = {
'code': uuid4().hex, 'code': uuid4().hex,
'user': username, 'user': username,
@ -269,10 +247,7 @@ class Connection(SqlConnection):
} }
with self.run('put-token', data) as cur: with self.run('put-token', data) as cur:
if (row := cur.one(schema.Token)) is None: return cur.one() # type: ignore
raise RuntimeError(f"Failed to insert token for user: {username}")
return row
def del_token(self, code: str) -> None: def del_token(self, code: str) -> None:
@ -280,22 +255,18 @@ class Connection(SqlConnection):
pass pass
def get_domain_ban(self, domain: str) -> schema.DomainBan | None: def get_domain_ban(self, domain: str) -> Row:
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).netloc domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur: with self.run('get-domain-ban', {'domain': domain}) as cur:
return cur.one(schema.DomainBan) return cur.one() # type: ignore
def get_domain_bans(self) -> Iterator[schema.DomainBan]:
return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan)
def put_domain_ban(self, def put_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> schema.DomainBan: note: str | None = None) -> Row:
params = { params = {
'domain': domain, 'domain': domain,
@ -305,16 +276,13 @@ class Connection(SqlConnection):
} }
with self.run('put-domain-ban', params) as cur: with self.run('put-domain-ban', params) as cur:
if (row := cur.one(schema.DomainBan)) is None: return cur.one() # type: ignore
raise RuntimeError(f"Failed to insert domain ban: {domain}")
return row
def update_domain_ban(self, def update_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> schema.DomainBan: note: str | None = None) -> Row:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -334,10 +302,7 @@ class Connection(SqlConnection):
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
if (row := cur.one(schema.DomainBan)) is None: return self.get_domain_ban(domain)
raise RuntimeError(f"Failed to update domain ban: {domain}")
return row
def del_domain_ban(self, domain: str) -> bool: def del_domain_ban(self, domain: str) -> bool:
@ -348,19 +313,15 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_software_ban(self, name: str) -> schema.SoftwareBan | None: def get_software_ban(self, name: str) -> Row:
with self.run('get-software-ban', {'name': name}) as cur: with self.run('get-software-ban', {'name': name}) as cur:
return cur.one(schema.SoftwareBan) return cur.one() # type: ignore
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
def put_software_ban(self, def put_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> schema.SoftwareBan: note: str | None = None) -> Row:
params = { params = {
'name': name, 'name': name,
@ -370,16 +331,13 @@ class Connection(SqlConnection):
} }
with self.run('put-software-ban', params) as cur: with self.run('put-software-ban', params) as cur:
if (row := cur.one(schema.SoftwareBan)) is None: return cur.one() # type: ignore
raise RuntimeError(f'Failed to insert software ban: {name}')
return row
def update_software_ban(self, def update_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> schema.SoftwareBan: note: str | None = None) -> Row:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -399,10 +357,7 @@ class Connection(SqlConnection):
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
if (row := cur.one(schema.SoftwareBan)) is None: return self.get_software_ban(name)
raise RuntimeError(f'Failed to update software ban: {name}')
return row
def del_software_ban(self, name: str) -> bool: def del_software_ban(self, name: str) -> bool:
@ -413,26 +368,19 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None: def get_domain_whitelist(self, domain: str) -> Row:
with self.run('get-domain-whitelist', {'domain': domain}) as cur: with self.run('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one() return cur.one() # type: ignore
def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]: def put_domain_whitelist(self, domain: str) -> Row:
return self.execute("SELECT * FROM whitelist").all(schema.Whitelist)
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
params = { params = {
'domain': domain, 'domain': domain,
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.run('put-domain-whitelist', params) as cur: with self.run('put-domain-whitelist', params) as cur:
if (row := cur.one(schema.Whitelist)) is None: return cur.one() # type: ignore
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
return row
def del_domain_whitelist(self, domain: str) -> bool: def del_domain_whitelist(self, domain: str) -> bool:

View file

@ -2,90 +2,69 @@ from __future__ import annotations
import typing import typing
from bsql import Column, Row, Tables from bsql import Column, Table, Tables
from collections.abc import Callable
from datetime import datetime
from .config import ConfigData from .config import ConfigData
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable
from .connection import Connection from .connection import Connection
VERSIONS: dict[int, Callable[[Connection], None]] = {} VERSIONS: dict[int, Callable] = {}
TABLES = Tables() TABLES: Tables = Tables(
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False),
Column('value', 'text'),
Column('type', 'text', default = 'str')
),
Table(
'inboxes',
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True),
Column('inbox', 'text', unique = True, nullable = False),
Column('followid', 'text'),
Column('software', 'text'),
Column('accepted', 'boolean'),
Column('created', 'timestamp', nullable = False)
),
Table(
'whitelist',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('created', 'timestamp')
),
Table(
'domain_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'software_bans',
Column('name', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'users',
Column('username', 'text', primary_key = True, unique = True, nullable = False),
Column('hash', 'text', nullable = False),
Column('handle', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'tokens',
Column('code', 'text', primary_key = True, unique = True, nullable = False),
Column('user', 'text', nullable = False),
Column('created', 'timestmap', nullable = False)
)
)
@TABLES.add_row def migration(func: Callable) -> Callable:
class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
value: Column[str] = Column('value', 'text')
type: Column[str] = Column('type', 'text', default = 'str')
@TABLES.add_row
class Instance(Row):
table_name: str = 'inboxes'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
accepted: Column[datetime] = Column('accepted', 'boolean')
created: Column[datetime] = Column('created', 'timestamp', nullable = False)
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class DomainBan(Row):
table_name: str = 'domain_bans'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class SoftwareBan(Row):
table_name: str = 'software_bans'
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class User(Row):
table_name: str = 'users'
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class Token(Row):
table_name: str = 'tokens'
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
user: Column[str] = Column('user', 'text', nullable = False)
created: Column[datetime] = Column('created', 'timestamp')
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
ver = int(func.__name__.replace('migrate_', '')) ver = int(func.__name__.replace('migrate_', ''))
VERSIONS[ver] = func VERSIONS[ver] = func
return func return func

View file

@ -11,11 +11,10 @@
%title << {{config.name}}: {{page}} %title << {{config.name}}: {{page}}
%meta(charset="UTF-8") %meta(charset="UTF-8")
%meta(name="viewport" content="width=device-width, initial-scale=1") %meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme") %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}") %link(rel="manifest" href="/manifest.json")
%link(rel="manifest" href="/manifest.json?{{version}}") %script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer)
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer)
-block head -block head
%body %body
@ -42,7 +41,7 @@
#container #container
#header.section #header.section
%span#menu-open -> %i(class="bi bi-list") %span#menu-open << &#8286;
%a.title(href="/") -> =config.name %a.title(href="/") -> =config.name
.empty .empty

View file

@ -1,32 +1,29 @@
-extends "base.haml" -extends "base.haml"
-set page="Config" -set page="Config"
-import "functions.haml" as func
-block head
%script(type="application/javascript" src="/static/config.js" nonce="{{view.request['hash']}}" defer)
-import "functions.haml" as func
-block content -block content
%fieldset.section %fieldset.section
%legend << Config %legend << Config
.grid-2col .grid-2col
%label(for="name") << Name %label(for="name") << Name
%i(class="bi bi-question-circle-fill" title="{{desc.name}}")
%input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}") %input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}")
%label(for="note") << Description %label(for="note") << Description
%i(class="bi bi-question-circle-fill" title="{{desc.note}}")
%textarea(id="note" value="{{config.note or ''}}") << {{config.note}} %textarea(id="note" value="{{config.note or ''}}") << {{config.note}}
%label(for="theme") << Color Theme %label(for="theme") << Color Theme
%i(class="bi bi-question-circle-fill" title="{{desc.theme}}")
=func.new_select("theme", config.theme, themes) =func.new_select("theme", config.theme, themes)
%label(for="log-level") << Log Level %label(for="log-level") << Log Level
%i(class="bi bi-question-circle-fill" title="{{desc.log_level}}")
=func.new_select("log-level", config.log_level.name, levels) =func.new_select("log-level", config.log_level.name, levels)
%label(for="whitelist-enabled") << Whitelist %label(for="whitelist-enabled") << Whitelist
%i(class="bi bi-question-circle-fill" title="{{desc.whitelist_enabled}}")
=func.new_checkbox("whitelist-enabled", config.whitelist_enabled) =func.new_checkbox("whitelist-enabled", config.whitelist_enabled)
%label(for="approval-required") << Approval Required %label(for="approval-required") << Approval Required
%i(class="bi bi-question-circle-fill" title="{{desc.approval_required}}")
=func.new_checkbox("approval-required", config.approval_required) =func.new_checkbox("approval-required", config.approval_required)

View file

@ -1,6 +1,9 @@
-extends "base.haml" -extends "base.haml"
-set page="Domain Bans" -set page="Domain Bans"
-block head
%script(type="application/javascript" src="/static/domain_ban.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Ban Domain %summary << Ban Domain
@ -32,7 +35,7 @@
%tr(id="{{ban.domain}}") %tr(id="{{ban.domain}}")
%td.domain %td.domain
%details %details
%summary -> =ban.domain.encode().decode("idna") %summary -> =ban.domain
.grid-2col .grid-2col
%label.reason(for="{{ban.domain}}-reason") << Reason %label.reason(for="{{ban.domain}}-reason") << Reason

View file

@ -1,6 +1,9 @@
-extends "base.haml" -extends "base.haml"
-set page="Instances" -set page="Instances"
-block head
%script(type="application/javascript" src="/static/instance.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Add Instance %summary << Add Instance
@ -36,7 +39,7 @@
-for request in requests -for request in requests
%tr(id="{{request.domain}}") %tr(id="{{request.domain}}")
%td.instance %td.instance
%a(href="https://{{request.domain}}" target="_new") -> =request.domain.encode().decode("idna") %a(href="https://{{request.domain}}" target="_new") -> =request.domain
%td.software %td.software
=request.software or "n/a" =request.software or "n/a"
@ -66,7 +69,7 @@
-for instance in instances -for instance in instances
%tr(id="{{instance.domain}}") %tr(id="{{instance.domain}}")
%td.instance %td.instance
%a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain.encode().decode("idna") %a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain
%td.software %td.software
=instance.software or "n/a" =instance.software or "n/a"

View file

@ -1,6 +1,9 @@
-extends "base.haml" -extends "base.haml"
-set page="Software Bans" -set page="Software Bans"
-block head
%script(type="application/javascript" src="/static/software_ban.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Ban Software %summary << Ban Software

View file

@ -1,6 +1,9 @@
-extends "base.haml" -extends "base.haml"
-set page="Users" -set page="Users"
-block head
%script(type="application/javascript" src="/static/user.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Add User %summary << Add User

View file

@ -1,6 +1,9 @@
-extends "base.haml" -extends "base.haml"
-set page="Whitelist" -set page="Whitelist"
-block head
%script(type="application/javascript" src="/static/whitelist.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%details.section %details.section
%summary << Add Domain %summary << Add Domain
@ -24,7 +27,7 @@
-for item in whitelist -for item in whitelist
%tr(id="{{item.domain}}") %tr(id="{{item.domain}}")
%td.domain %td.domain
=item.domain.encode().decode("idna") =item.domain
%td.date %td.date
=item.created.strftime("%Y-%m-%d") =item.created.strftime("%Y-%m-%d")

View file

@ -1,6 +1,5 @@
-extends "base.haml" -extends "base.haml"
-set page = "Home" -set page = "Home"
-block content -block content
-if config.note -if config.note
.section .section
@ -42,7 +41,7 @@
-for instance in instances -for instance in instances
%tr %tr
%td.instance -> %a(href="https://{{instance.domain}}/" target="_new") %td.instance -> %a(href="https://{{instance.domain}}/" target="_new")
=instance.domain.encode().decode("idna") =instance.domain
%td.date %td.date
=instance.created.strftime("%Y-%m-%d") =instance.created.strftime("%Y-%m-%d")

View file

@ -1,6 +1,9 @@
-extends "base.haml" -extends "base.haml"
-set page="Login" -set page="Login"
-block head
%script(type="application/javascript" src="/static/login.js" nonce="{{view.request['hash']}}" defer)
-block content -block content
%fieldset.section %fieldset.section
%legend << Login %legend << Login

View file

@ -0,0 +1,135 @@
// toast notifications
const notifications = document.querySelector("#notifications")
function remove_toast(toast) {
toast.classList.add("hide");
if (toast.timeoutId) {
clearTimeout(toast.timeoutId);
}
setTimeout(() => toast.remove(), 300);
}
function toast(text, type="error", timeout=5) {
const toast = document.createElement("li");
toast.className = `section ${type}`
toast.innerHTML = `<span class=".text">${text}</span><a href="#">&#10006;</span>`
toast.querySelector("a").addEventListener("click", async (event) => {
event.preventDefault();
await remove_toast(toast);
});
notifications.appendChild(toast);
toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000);
}
// menu
const body = document.getElementById("container")
const menu = document.getElementById("menu");
const menu_open = document.getElementById("menu-open");
const menu_close = document.getElementById("menu-close");
menu_open.addEventListener("click", (event) => {
var new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
menu.attributes.visible.nodeValue = new_value;
});
menu_close.addEventListener("click", (event) => {
menu.attributes.visible.nodeValue = "false"
});
body.addEventListener("click", (event) => {
if (event.target === menu_open) {
return;
}
menu.attributes.visible.nodeValue = "false";
});
// misc
function get_date_string(date) {
var year = date.getFullYear().toString();
var month = date.getMonth().toString();
var day = date.getDay().toString();
if (month.length === 1) {
month = "0" + month;
}
if (day.length === 1) {
day = "0" + day
}
return `${year}-${month}-${day}`;
}
function append_table_row(table, row_name, row) {
var table_row = table.insertRow(-1);
table_row.id = row_name;
index = 0;
for (var prop in row) {
if (Object.prototype.hasOwnProperty.call(row, prop)) {
var cell = table_row.insertCell(index);
cell.className = prop;
cell.innerHTML = row[prop];
index += 1;
}
}
return table_row;
}
async function request(method, path, body = null) {
var headers = {
"Accept": "application/json"
}
if (body !== null) {
headers["Content-Type"] = "application/json"
body = JSON.stringify(body)
}
const response = await fetch("/api/" + path, {
method: method,
mode: "cors",
cache: "no-store",
redirect: "follow",
body: body,
headers: headers
});
const message = await response.json();
if (Object.hasOwn(message, "error")) {
throw new Error(message.error);
}
if (Array.isArray(message)) {
message.forEach((msg) => {
if (Object.hasOwn(msg, "created")) {
msg.created = new Date(msg.created);
}
});
} else {
if (Object.hasOwn(message, "created")) {
message.created = new Date(message.created);
}
}
return message;
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,40 @@
const elems = [
document.querySelector("#name"),
document.querySelector("#note"),
document.querySelector("#theme"),
document.querySelector("#log-level"),
document.querySelector("#whitelist-enabled"),
document.querySelector("#approval-required")
]
async function handle_config_change(event) {
params = {
key: event.target.id,
value: event.target.type === "checkbox" ? event.target.checked : event.target.value
}
try {
await request("POST", "v1/config", params);
} catch (error) {
toast(error);
return;
}
if (params.key === "name") {
document.querySelector("#header .title").innerHTML = params.value;
document.querySelector("title").innerHTML = params.value;
}
if (params.key === "theme") {
document.querySelector("link.theme").href = `/theme/${params.value}.css`;
}
toast("Updated config", "message");
}
for (const elem of elems) {
elem.addEventListener("change", handle_config_change);
}

View file

@ -0,0 +1,123 @@
function create_ban_object(domain, reason, note) {
var text = '<details>\n';
text += `<summary>${domain}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${domain}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${domain}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${domain}-note" class="note">Note</label>\n`;
text += `<textarea id="${domain}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var table = document.querySelector("table");
var elems = {
domain: document.getElementById("new-domain"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
domain: elems.domain.value.trim(),
reason: elems.reason.value.trim(),
note: elems.note.value.trim()
}
if (values.domain === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/domain_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("table"), ban.domain, {
domain: create_ban_object(ban.domain, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban domain">&#10006;</a>`
});
add_row_listeners(row);
elems.domain.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned domain", "message");
}
async function update_ban(domain) {
var row = document.getElementById(domain);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"domain": domain,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/domain_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated baned domain", "message");
}
async function unban(domain) {
try {
await request("DELETE", "v1/domain_ban", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Unbanned domain", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}

View file

@ -0,0 +1,145 @@
function add_instance_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_instance(row.id);
});
}
function add_request_listeners(row) {
row.querySelector(".approve a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, true);
});
row.querySelector(".deny a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, false);
});
}
async function add_instance() {
var elems = {
actor: document.getElementById("new-actor"),
inbox: document.getElementById("new-inbox"),
followid: document.getElementById("new-followid"),
software: document.getElementById("new-software")
}
var values = {
actor: elems.actor.value.trim(),
inbox: elems.inbox.value.trim(),
followid: elems.followid.value.trim(),
software: elems.software.value.trim()
}
if (values.actor === "") {
toast("Actor is required");
return;
}
try {
var instance = await request("POST", "v1/instance", values);
} catch (err) {
toast(err);
return
}
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
elems.actor.value = null;
elems.inbox.value = null;
elems.followid.value = null;
elems.software.value = null;
document.querySelector("details.section").open = false;
toast("Added instance", "message");
}
async function del_instance(domain) {
try {
await request("DELETE", "v1/instance", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
}
async function req_response(domain, accept) {
params = {
"domain": domain,
"accept": accept
}
try {
await request("POST", "v1/request", params);
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
if (document.getElementById("requests").rows.length < 2) {
document.querySelector("fieldset.requests").remove()
}
if (!accept) {
toast("Denied instance request", "message");
return;
}
instances = await request("GET", `v1/instance`, null);
instances.forEach((instance) => {
if (instance.domain === domain) {
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
}
});
toast("Accepted instance request", "message");
}
document.querySelector("#add-instance").addEventListener("click", async (event) => {
await add_instance();
})
for (var row of document.querySelector("#instances").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_instance_listeners(row);
}
if (document.querySelector("#requests")) {
for (var row of document.querySelector("#requests").rows) {
if (!row.querySelector(".approve a")) {
continue;
}
add_request_listeners(row);
}
}

View file

@ -0,0 +1,29 @@
async function login(event) {
fields = {
username: document.querySelector("#username"),
password: document.querySelector("#password")
}
values = {
username: fields.username.value.trim(),
password: fields.password.value.trim()
}
if (values.username === "" | values.password === "") {
toast("Username and/or password field is blank");
return;
}
try {
await request("POST", "v1/token", values);
} catch (error) {
toast(error);
return;
}
document.location = "/";
}
document.querySelector(".submit").addEventListener("click", login);

View file

@ -0,0 +1,122 @@
function create_ban_object(name, reason, note) {
var text = '<details>\n';
text += `<summary>${name}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${name}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${name}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${name}-note" class="note">Note</label>\n`;
text += `<textarea id="${name}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var elems = {
name: document.getElementById("new-name"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
name: elems.name.value.trim(),
reason: elems.reason.value,
note: elems.note.value
}
if (values.name === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/software_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.getElementById("bans"), ban.name, {
name: create_ban_object(ban.name, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban software">&#10006;</a>`
});
add_row_listeners(row);
elems.name.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned software", "message");
}
async function update_ban(name) {
var row = document.getElementById(name);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"name": name,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/software_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated software ban", "message");
}
async function unban(name) {
try {
await request("DELETE", "v1/software_ban", {"name": name});
} catch (error) {
toast(error);
return;
}
document.getElementById(name).remove();
toast("Unbanned software", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var row of document.querySelector("#bans").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}

View file

@ -155,7 +155,6 @@ textarea {
z-index: 1; z-index: 1;
font-size: 1.5em; font-size: 1.5em;
min-width: 300px; min-width: 300px;
overflow-x: auto;
} }
#menu[visible="false"] { #menu[visible="false"] {
@ -189,17 +188,11 @@ textarea {
} }
#menu-open { #menu-open {
color: var(--background); color: var(--primary);
background: var(--primary);
font-size: 38px;
line-height: 38px;
border: 1px solid var(--primary);
border-radius: 5px;
} }
#menu-open:hover { #menu-open:hover {
color: var(--primary); color: var(--primary-hover);
background: var(--background);
} }
#menu-open, #menu-close { #menu-open, #menu-close {
@ -297,13 +290,13 @@ textarea {
border: 1px solid var(--error-border) !important; border: 1px solid var(--error-border) !important;
} }
/* create .grid base class and .2col and 3col classes */
.grid-2col { .grid-2col {
display: grid; display: grid;
grid-template-columns: max-content auto; grid-template-columns: max-content auto;
grid-gap: var(--spacing); grid-gap: var(--spacing);
margin-bottom: var(--spacing); margin-bottom: var(--spacing);
align-items: center; align-items: center;
} }
.message { .message {
@ -333,10 +326,6 @@ textarea {
justify-self: left; justify-self: left;
} }
#content.page-config .grid-2col {
grid-template-columns: max-content max-content auto;
}
@keyframes show_toast { @keyframes show_toast {
0% { 0% {

View file

@ -0,0 +1,85 @@
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_user(row.id);
});
}
async function add_user() {
var elems = {
username: document.getElementById("new-username"),
password: document.getElementById("new-password"),
password2: document.getElementById("new-password2"),
handle: document.getElementById("new-handle")
}
var values = {
username: elems.username.value.trim(),
password: elems.password.value.trim(),
password2: elems.password2.value.trim(),
handle: elems.handle.value.trim()
}
if (values.username === "" | values.password === "" | values.password2 === "") {
toast("Username, password, and password2 are required");
return;
}
if (values.password !== values.password2) {
toast("Passwords do not match");
return;
}
try {
var user = await request("POST", "v1/user", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
domain: user.username,
handle: user.handle ? self.handle : "n/a",
date: get_date_string(user.created),
remove: `<a href="#" title="Delete User">&#10006;</a>`
});
add_row_listeners(row);
elems.username.value = null;
elems.password.value = null;
elems.password2.value = null;
elems.handle.value = null;
document.querySelector("details.section").open = false;
toast("Created user", "message");
}
async function del_user(username) {
try {
await request("DELETE", "v1/user", {"username": username});
} catch (error) {
toast(error);
return;
}
document.getElementById(username).remove();
toast("Deleted user", "message");
}
document.querySelector("#new-user").addEventListener("click", async (event) => {
await add_user();
});
for (var row of document.querySelector("#users").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}

View file

@ -0,0 +1,64 @@
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_whitelist(row.id);
});
}
async function add_whitelist() {
var domain_elem = document.getElementById("new-domain");
var domain = domain_elem.value.trim();
if (domain === "") {
toast("Domain is required");
return;
}
try {
var item = await request("POST", "v1/whitelist", {"domain": domain});
} catch (err) {
toast(err);
return;
}
var row = append_table_row(document.getElementById("whitelist"), item.domain, {
domain: item.domain,
date: get_date_string(item.created),
remove: `<a href="#" title="Remove whitelisted domain">&#10006;</a>`
});
add_row_listeners(row);
domain_elem.value = null;
document.querySelector("details.section").open = false;
toast("Added domain", "message");
}
async function del_whitelist(domain) {
try {
await request("DELETE", "v1/whitelist", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Removed domain", "message");
}
document.querySelector("#new-item").addEventListener("click", async (event) => {
await add_whitelist();
});
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}

View file

@ -1,37 +1,29 @@
from __future__ import annotations from __future__ import annotations
import json import json
import traceback
import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from blib import JsonBase from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from typing import TYPE_CHECKING, Any, TypeVar, overload from aputils import AlgorithmType, JsonBase, Nodeinfo, ObjectType, WellKnownNodeinfo
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
from . import __version__, logger as logging from . import __version__
from .cache import Cache from . import logger as logging
from .database.schema import Instance
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from aputils import Signer
from bsql import Row
from typing import Any
from .application import Application from .application import Application
from .cache import Cache
SUPPORTS_HS2019 = { T = typing.TypeVar('T', bound = JsonBase)
'friendica',
'gotosocial',
'hubzilla'
'mastodon',
'socialhome',
'misskey',
'catodon',
'cherrypick',
'firefish',
'foundkey',
'iceshrimp',
'sharkey'
}
T = TypeVar('T', bound = JsonBase)
HEADERS = { HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9', 'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}' 'User-Agent': f'ActivityRelay/{__version__}'
@ -98,12 +90,7 @@ class HttpClient:
self._session = None self._session = None
async def _get(self, async def _get(self, url: str, sign_headers: bool, force: bool) -> dict[str, Any] | None:
url: str,
sign_headers: bool,
force: bool,
old_algo: bool) -> str | None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
@ -116,7 +103,7 @@ class HttpClient:
if not force: if not force:
try: try:
if not (item := self.cache.get('request', url)).older_than(48): if not (item := self.cache.get('request', url)).older_than(48):
return item.value # type: ignore [no-any-return] return json.loads(item.value)
except KeyError: except KeyError:
logging.verbose('No cached data for url: %s', url) logging.verbose('No cached data for url: %s', url)
@ -124,72 +111,67 @@ class HttpClient:
headers = {} headers = {}
if sign_headers: if sign_headers:
algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019 headers = self.signer.sign_headers('GET', url, algorithm = AlgorithmType.HS2019)
headers = self.signer.sign_headers('GET', url, algorithm = algo)
logging.debug('Fetching resource: %s', url) try:
logging.debug('Fetching resource: %s', url)
async with self._session.get(url, headers = headers) as resp: async with self._session.get(url, headers = headers) as resp:
# Not expecting a response with 202s, so just return # Not expecting a response with 202s, so just return
if resp.status == 202: if resp.status == 202:
return None
data = await resp.text()
if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(data)
return None return None
data = await resp.text() self.cache.set('request', url, data, 'str')
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
if resp.status != 200: return json.loads(data)
logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(data) except JSONDecodeError:
logging.verbose('Failed to parse JSON')
return None return None
self.cache.set('request', url, data, 'str') except ClientSSLError as e:
return data logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
logging.warning(str(e))
except (AsyncTimeoutError, ClientConnectionError) as e:
logging.verbose('Failed to connect to %s', urlparse(url).netloc)
logging.warning(str(e))
@overload except Exception:
async def get(self, # type: ignore[overload-overlap] traceback.print_exc()
url: str,
sign_headers: bool,
cls: None = None,
force: bool = False,
old_algo: bool = True) -> None: ...
@overload
async def get(self,
url: str,
sign_headers: bool,
cls: type[T] = JsonBase, # type: ignore[assignment]
force: bool = False,
old_algo: bool = True) -> T: ...
async def get(self,
url: str,
sign_headers: bool,
cls: type[T] | None = None,
force: bool = False,
old_algo: bool = True) -> T | None:
if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"')
data = await self._get(url, sign_headers, force, old_algo)
if cls is not None:
if data is None:
raise ValueError("Empty response")
return cls.parse(data)
return None return None
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None: async def get(self,
url: str,
sign_headers: bool,
cls: type[T],
force: bool = False) -> T | None:
if not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "aputils.JsonBase"')
if (data := (await self._get(url, sign_headers, force))) is None:
return None
return cls.parse(data)
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
# akkoma and pleroma do not support HS2019 and other software still needs to be tested # akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance is not None and instance.software in SUPPORTS_HS2019: if instance and instance['software'] in {'mastodon'}:
algorithm = AlgorithmType.HS2019 algorithm = AlgorithmType.HS2019
else: else:
@ -215,22 +197,35 @@ class HttpClient:
algorithm = algorithm algorithm = algorithm
) )
logging.verbose('Sending "%s" to %s', mtype, url) try:
logging.verbose('Sending "%s" to %s', mtype, url)
async with self._session.post(url, headers = headers, data = body) as resp: async with self._session.post(url, headers = headers, data = body) as resp:
# Not expecting a response, so just return # Not expecting a response, so just return
if resp.status in {200, 202}: if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', mtype, url) logging.verbose('Successfully sent "%s" to %s', mtype, url)
return
logging.verbose('Received error when pushing to %s: %i', url, resp.status)
logging.debug(await resp.read())
logging.debug("message: %s", body.decode("utf-8"))
logging.debug("headers: %s", json.dumps(headers, indent = 4))
return return
logging.error('Received error when pushing to %s: %i', url, resp.status) except ClientSSLError as e:
logging.debug(await resp.read()) logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
logging.debug("message: %s", body.decode("utf-8")) logging.warning(str(e))
logging.debug("headers: %s", json.dumps(headers, indent = 4))
return except (AsyncTimeoutError, ClientConnectionError) as e:
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
logging.warning(str(e))
# prevent workers from being brought down
except Exception:
traceback.print_exc()
async def fetch_nodeinfo(self, domain: str) -> Nodeinfo: async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None:
nodeinfo_url = None nodeinfo_url = None
wk_nodeinfo = await self.get( wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo', f'https://{domain}/.well-known/nodeinfo',
@ -238,6 +233,10 @@ class HttpClient:
WellKnownNodeinfo WellKnownNodeinfo
) )
if wk_nodeinfo is None:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None
for version in ('20', '21'): for version in ('20', '21'):
try: try:
nodeinfo_url = wk_nodeinfo.get_url(version) nodeinfo_url = wk_nodeinfo.get_url(version)
@ -246,7 +245,8 @@ class HttpClient:
pass pass
if nodeinfo_url is None: if nodeinfo_url is None:
raise ValueError(f'Failed to fetch nodeinfo url for {domain}') logging.verbose('Failed to fetch nodeinfo url for %s', domain)
return None
return await self.get(nodeinfo_url, False, Nodeinfo) return await self.get(nodeinfo_url, False, Nodeinfo)

View file

@ -2,12 +2,15 @@ from __future__ import annotations
import logging import logging
import os import os
import typing
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
try: try:
from typing import Self from typing import Self
@ -15,10 +18,6 @@ if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
class LoggingMethod(Protocol):
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
class LogLevel(IntEnum): class LogLevel(IntEnum):
DEBUG = logging.DEBUG DEBUG = logging.DEBUG
VERBOSE = 15 VERBOSE = 15
@ -76,11 +75,11 @@ def verbose(message: str, *args: Any, **kwargs: Any) -> None:
logging.log(LogLevel.VERBOSE, message, *args, **kwargs) logging.log(LogLevel.VERBOSE, message, *args, **kwargs)
debug: LoggingMethod = logging.debug debug: Callable = logging.debug
info: LoggingMethod = logging.info info: Callable = logging.info
warning: LoggingMethod = logging.warning warning: Callable = logging.warning
error: LoggingMethod = logging.error error: Callable = logging.error
critical: LoggingMethod = logging.critical critical: Callable = logging.critical
try: try:

View file

@ -5,10 +5,10 @@ import asyncio
import click import click
import json import json
import os import os
import typing
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from . import __version__ from . import __version__
@ -16,9 +16,13 @@ from . import http_client as http
from . import logger as logging from . import logger as logging
from .application import Application from .application import Application
from .compat import RelayConfig, RelayDatabase from .compat import RelayConfig, RelayDatabase
from .database import RELAY_SOFTWARE, get_database, schema from .database import RELAY_SOFTWARE, get_database
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
if typing.TYPE_CHECKING:
from bsql import Row
from typing import Any
def check_alphanumeric(text: str) -> str: def check_alphanumeric(text: str) -> str:
if not text.isalnum(): if not text.isalnum():
@ -366,8 +370,8 @@ def cli_user_list(ctx: click.Context) -> None:
click.echo('Users:') click.echo('Users:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for row in conn.get_users(): for user in conn.execute('SELECT * FROM users'):
click.echo(f'- {row.username}') click.echo(f'- {user["username"]}')
@cli_user.command('create') @cli_user.command('create')
@ -378,7 +382,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
'Create a new local user' 'Create a new local user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_user(username) is not None: if conn.get_user(username):
click.echo(f'User already exists: {username}') click.echo(f'User already exists: {username}')
return return
@ -405,7 +409,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
'Delete a local user' 'Delete a local user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_user(username) is None: if not conn.get_user(username):
click.echo(f'User does not exist: {username}') click.echo(f'User does not exist: {username}')
return return
@ -423,8 +427,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
click.echo(f'Tokens for "{username}":') click.echo(f'Tokens for "{username}":')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for row in conn.get_tokens(username): for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
click.echo(f'- {row.code}') click.echo(f'- {token["code"]}')
@cli_user.command('create-token') @cli_user.command('create-token')
@ -434,13 +438,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
'Create a new API token for a user' 'Create a new API token for a user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if (user := conn.get_user(username)) is None: if not (user := conn.get_user(username)):
click.echo(f'User does not exist: {username}') click.echo(f'User does not exist: {username}')
return return
token = conn.put_token(user.username) token = conn.put_token(user['username'])
click.echo(f'New token for "{username}": {token.code}') click.echo(f'New token for "{username}": {token["code"]}')
@cli_user.command('delete-token') @cli_user.command('delete-token')
@ -450,7 +454,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
'Delete an API token' 'Delete an API token'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_token(code) is None: if not conn.get_token(code):
click.echo('Token does not exist') click.echo('Token does not exist')
return return
@ -472,8 +476,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:') click.echo('Connected to the following instances or relays:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for row in conn.get_inboxes(): for inbox in conn.get_inboxes():
click.echo(f'- {row.inbox}') click.echo(f'- {inbox["inbox"]}')
@cli_inbox.command('follow') @cli_inbox.command('follow')
@ -482,21 +486,19 @@ def cli_inbox_list(ctx: click.Context) -> None:
def cli_inbox_follow(ctx: click.Context, actor: str) -> None: def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)' 'Follow an actor (Relay must be running)'
instance: schema.Instance | None = None
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
if (instance := conn.get_inbox(actor)) is not None: if (inbox_data := conn.get_inbox(actor)):
inbox = instance.inbox inbox = inbox_data['inbox']
else: else:
if not actor.startswith('http'): if not actor.startswith('http'):
actor = f'https://{actor}/actor' actor = f'https://{actor}/actor'
if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None: if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
click.echo(f'Failed to fetch actor: {actor}') click.echo(f'Failed to fetch actor: {actor}')
return return
@ -507,7 +509,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
actor = actor actor = actor
) )
asyncio.run(http.post(inbox, message, instance)) asyncio.run(http.post(inbox, message, inbox_data))
click.echo(f'Sent follow message to actor: {actor}') click.echo(f'Sent follow message to actor: {actor}')
@ -517,19 +519,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)' 'Unfollow an actor (Relay must be running)'
instance: schema.Instance | None = None inbox_data: Row | None = None
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
if (instance := conn.get_inbox(actor)): if (inbox_data := conn.get_inbox(actor)):
inbox = instance.inbox inbox = inbox_data['inbox']
message = Message.new_unfollow( message = Message.new_unfollow(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = actor, actor = actor,
follow = instance.followid follow = inbox_data['followid']
) )
else: else:
@ -553,7 +555,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
} }
) )
asyncio.run(http.post(inbox, message, instance)) asyncio.run(http.post(inbox, message, inbox_data))
click.echo(f'Sent unfollow message to: {actor}') click.echo(f'Sent unfollow message to: {actor}')
@ -633,9 +635,9 @@ def cli_request_list(ctx: click.Context) -> None:
click.echo('Follow requests:') click.echo('Follow requests:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for row in conn.get_requests(): for instance in conn.get_requests():
date = row.created.strftime('%Y-%m-%d') date = instance['created'].strftime('%Y-%m-%d')
click.echo(f'- [{date}] {row.domain}') click.echo(f'- [{date}] {instance["domain"]}')
@cli_request.command('accept') @cli_request.command('accept')
@ -654,20 +656,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
message = Message.new_response( message = Message.new_response(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance.actor, actor = instance['actor'],
followid = instance.followid, followid = instance['followid'],
accept = True accept = True
) )
asyncio.run(http.post(instance.inbox, message, instance)) asyncio.run(http.post(instance['inbox'], message, instance))
if instance.software != 'mastodon': if instance['software'] != 'mastodon':
message = Message.new_follow( message = Message.new_follow(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance.actor actor = instance['actor']
) )
asyncio.run(http.post(instance.inbox, message, instance)) asyncio.run(http.post(instance['inbox'], message, instance))
@cli_request.command('deny') @cli_request.command('deny')
@ -686,12 +688,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
response = Message.new_response( response = Message.new_response(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance.actor, actor = instance['actor'],
followid = instance.followid, followid = instance['followid'],
accept = False accept = False
) )
asyncio.run(http.post(instance.inbox, response, instance)) asyncio.run(http.post(instance['inbox'], response, instance))
@cli.group('instance') @cli.group('instance')
@ -707,12 +709,12 @@ def cli_instance_list(ctx: click.Context) -> None:
click.echo('Banned domains:') click.echo('Banned domains:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for row in conn.get_domain_bans(): for instance in conn.execute('SELECT * FROM domain_bans'):
if row.reason is not None: if instance['reason']:
click.echo(f'- {row.domain} ({row.reason})') click.echo(f'- {instance["domain"]} ({instance["reason"]})')
else: else:
click.echo(f'- {row.domain}') click.echo(f'- {instance["domain"]}')
@cli_instance.command('ban') @cli_instance.command('ban')
@ -724,7 +726,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
'Ban an instance and remove the associated inbox if it exists' 'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(domain) is not None: if conn.get_domain_ban(domain):
click.echo(f'Domain already banned: {domain}') click.echo(f'Domain already banned: {domain}')
return return
@ -740,7 +742,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance' 'Unban an instance'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.del_domain_ban(domain) is None: if not conn.del_domain_ban(domain):
click.echo(f'Instance wasn\'t banned: {domain}') click.echo(f'Instance wasn\'t banned: {domain}')
return return
@ -765,11 +767,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
click.echo(f'Updated domain ban: {domain}') click.echo(f'Updated domain ban: {domain}')
if row.reason: if row['reason']:
click.echo(f'- {row.domain} ({row.reason})') click.echo(f'- {row["domain"]} ({row["reason"]})')
else: else:
click.echo(f'- {row.domain}') click.echo(f'- {row["domain"]}')
@cli.group('software') @cli.group('software')
@ -785,12 +787,12 @@ def cli_software_list(ctx: click.Context) -> None:
click.echo('Banned software:') click.echo('Banned software:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for row in conn.get_software_bans(): for software in conn.execute('SELECT * FROM software_bans'):
if row.reason: if software['reason']:
click.echo(f'- {row.name} ({row.reason})') click.echo(f'- {software["name"]} ({software["reason"]})')
else: else:
click.echo(f'- {row.name}') click.echo(f'- {software["name"]}')
@cli_software.command('ban') @cli_software.command('ban')
@ -812,12 +814,12 @@ def cli_software_ban(ctx: click.Context,
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if name == 'RELAYS': if name == 'RELAYS':
for item in RELAY_SOFTWARE: for software in RELAY_SOFTWARE:
if conn.get_software_ban(item): if conn.get_software_ban(software):
click.echo(f'Relay already banned: {item}') click.echo(f'Relay already banned: {software}')
continue continue
conn.put_software_ban(item, reason or 'relay', note) conn.put_software_ban(software, reason or 'relay', note)
click.echo('Banned all relay software') click.echo('Banned all relay software')
return return
@ -894,11 +896,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
click.echo(f'Updated software ban: {name}') click.echo(f'Updated software ban: {name}')
if row.reason: if row['reason']:
click.echo(f'- {row.name} ({row.reason})') click.echo(f'- {row["name"]} ({row["reason"]})')
else: else:
click.echo(f'- {row.name}') click.echo(f'- {row["name"]}')
@cli.group('whitelist') @cli.group('whitelist')
@ -914,8 +916,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
click.echo('Current whitelisted domains:') click.echo('Current whitelisted domains:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for row in conn.get_domain_whitelist(): for domain in conn.execute('SELECT * FROM whitelist'):
click.echo(f'- {row.domain}') click.echo(f'- {domain["domain"]}')
@cli_whitelist.command('add') @cli_whitelist.command('add')
@ -954,19 +956,23 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
@cli_whitelist.command('import') @cli_whitelist.command('import')
@click.pass_context @click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None: def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current instances to the whitelist' 'Add all current inboxes to the whitelist'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for row in conn.get_inboxes(): for inbox in conn.execute('SELECT * FROM inboxes').all():
if conn.get_domain_whitelist(row.domain) is not None: if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {row.domain}') click.echo(f'Domain already in whitelist: {inbox["domain"]}')
continue continue
conn.put_domain_whitelist(row.domain) conn.put_domain_whitelist(inbox['domain'])
click.echo('Imported whitelist from inboxes') click.echo('Imported whitelist from inboxes')
def main() -> None: def main() -> None:
cli(prog_name='activityrelay') cli(prog_name='relay')
if __name__ == '__main__':
click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')

View file

@ -3,14 +3,12 @@ from __future__ import annotations
import aputils import aputils
import json import json
import os import os
import platform
import socket import socket
import typing
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4 from uuid import uuid4
try: try:
@ -19,7 +17,8 @@ try:
except ImportError: except ImportError:
from importlib_resources import files as pkgfiles # type: ignore from importlib_resources import files as pkgfiles # type: ignore
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any
from .application import Application from .application import Application
try: try:
@ -29,17 +28,16 @@ if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
T = TypeVar('T') T = typing.TypeVar('T')
ResponseType = TypedDict('ResponseType', { ResponseType = typing.TypedDict('ResponseType', {
'status': int, 'status': int,
'headers': dict[str, Any] | None, 'headers': dict[str, typing.Any] | None,
'content_type': str, 'content_type': str,
'body': bytes | None, 'body': bytes | None,
'text': str | None 'text': str | None
}) })
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING')) IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
IS_WINDOWS = platform.system() == 'Windows'
MIMETYPES = { MIMETYPES = {
'activity': 'application/activity+json', 'activity': 'application/activity+json',
@ -128,7 +126,7 @@ class JsonEncoder(json.JSONEncoder):
if isinstance(o, datetime): if isinstance(o, datetime):
return o.isoformat() return o.isoformat()
return json.JSONEncoder.default(self, o) # type: ignore[no-any-return] return json.JSONEncoder.default(self, o)
class Message(aputils.Message): class Message(aputils.Message):
@ -148,7 +146,6 @@ class Message(aputils.Message):
'followers': f'https://{host}/followers', 'followers': f'https://{host}/followers',
'following': f'https://{host}/following', 'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox', 'inbox': f'https://{host}/inbox',
'outbox': f'https://{host}/outbox',
'url': f'https://{host}/', 'url': f'https://{host}/',
'endpoints': { 'endpoints': {
'sharedInbox': f'https://{host}/inbox' 'sharedInbox': f'https://{host}/inbox'
@ -214,7 +211,7 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new(cls: type[Self], def new(cls: type[Self],
body: str | bytes | dict[str, Any] | Sequence[Any] = '', body: str | bytes | dict | tuple | list | set = '',
status: int = 200, status: int = 200,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
ctype: str = 'text') -> Self: ctype: str = 'text') -> Self:
@ -227,22 +224,22 @@ class Response(AiohttpResponse):
'text': None 'text': None
} }
if isinstance(body, str): if isinstance(body, bytes):
kwargs['text'] = body
elif isinstance(body, bytes):
kwargs['body'] = body kwargs['body'] = body
elif isinstance(body, (dict, Sequence)): elif isinstance(body, (dict, list, tuple, set)) or ctype in {'json', 'activity'}:
kwargs['text'] = json.dumps(body, cls = JsonEncoder) kwargs['text'] = json.dumps(body, cls = JsonEncoder)
else:
kwargs['text'] = body
return cls(**kwargs) return cls(**kwargs)
@classmethod @classmethod
def new_error(cls: type[Self], def new_error(cls: type[Self],
status: int, status: int,
body: str | bytes | dict[str, Any], body: str | bytes | dict,
ctype: str = 'text') -> Self: ctype: str = 'text') -> Self:
if ctype == 'json': if ctype == 'json':

View file

@ -10,12 +10,14 @@ if typing.TYPE_CHECKING:
from .views.activitypub import ActorView from .views.activitypub import ActorView
def actor_type_check(actor: Message, software: str | None) -> bool: def person_check(actor: Message, software: str | None) -> bool:
if actor.type == 'Application': # pleroma and akkoma may use Person for the actor type for some reason
return True # akkoma changed this in 3.6.0
# akkoma (< 3.6.0) and pleroma use Person for the actor type
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False
# make sure the actor is an application
if actor.type != 'Application':
return True return True
return False return False
@ -34,7 +36,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
for instance in conn.distill_inboxes(view.message): for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance.inbox, message, instance) view.app.push_message(instance["inbox"], message, instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str') view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
@ -52,7 +54,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
for instance in conn.distill_inboxes(view.message): for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance.inbox, view.message, instance) view.app.push_message(instance["inbox"], await view.request.read(), instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str') view.cache.set('handle-relay', view.message.id, message.id, 'str')
@ -86,7 +88,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
return return
# reject if the actor is not an instance actor # reject if the actor is not an instance actor
if actor_type_check(view.actor, software): if person_check(view.actor, software):
logging.verbose('Non-application actor tried to follow: %s', view.actor.id) logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
view.app.push_message( view.app.push_message(
@ -177,7 +179,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
return return
# prevent past unfollows from removing an instance # prevent past unfollows from removing an instance
if view.instance.followid and view.instance.followid != view.message.object_id: if view.instance['followid'] and view.instance['followid'] != view.message.object_id:
return return
with conn.transaction(): with conn.transaction():
@ -221,18 +223,18 @@ async def run_processor(view: ActorView) -> None:
with view.database.session() as conn: with view.database.session() as conn:
if view.instance: if view.instance:
if not view.instance.software: if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)): if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with conn.transaction(): with conn.transaction():
view.instance = conn.put_inbox( view.instance = conn.put_inbox(
domain = view.instance.domain, domain = view.instance['domain'],
software = nodeinfo.sw_name software = nodeinfo.sw_name
) )
if not view.instance.actor: if not view.instance['actor']:
with conn.transaction(): with conn.transaction():
view.instance = conn.put_inbox( view.instance = conn.put_inbox(
domain = view.instance.domain, domain = view.instance['domain'],
actor = view.actor.id actor = view.actor.id
) )

View file

@ -1,22 +1,25 @@
from __future__ import annotations from __future__ import annotations
import textwrap import textwrap
import typing
from collections.abc import Callable from collections.abc import Callable
from hamlish_jinja import HamlishExtension from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension from jinja2.ext import Extension
from jinja2.nodes import CallBlock, Node from jinja2.nodes import CallBlock
from jinja2.parser import Parser
from markdown import Markdown from markdown import Markdown
from typing import TYPE_CHECKING, Any
from . import __version__ from . import __version__
from .misc import get_resource from .misc import get_resource
from .views.base import View
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from jinja2.nodes import Node
from jinja2.parser import Parser
from typing import Any
from .application import Application from .application import Application
from .views.base import View
class Template(Environment): class Template(Environment):

View file

@ -1,22 +1,26 @@
from __future__ import annotations
import aputils import aputils
import traceback import traceback
import typing
from aiohttp.web import Request
from .base import View, register_route from .base import View, register_route
from .. import logger as logging from .. import logger as logging
from ..database import schema
from ..misc import Message, Response from ..misc import Message, Response
from ..processors import run_processor from ..processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from bsql import Row
@register_route('/actor', '/inbox') @register_route('/actor', '/inbox')
class ActorView(View): class ActorView(View):
signature: aputils.Signature signature: aputils.Signature
message: Message message: Message
actor: Message actor: Message
instance: schema.Instance instancce: Row
signer: aputils.Signer signer: aputils.Signer
@ -43,7 +47,7 @@ class ActorView(View):
return response return response
with self.database.session() as conn: with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment] self.instance = conn.get_inbox(self.actor.shared_inbox)
# reject if actor is banned # reject if actor is banned
if conn.get_domain_ban(self.actor.domain): if conn.get_domain_ban(self.actor.domain):
@ -91,10 +95,9 @@ class ActorView(View):
logging.verbose('actor not in message') logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json') return Response.new_error(400, 'no actor in message', 'json')
try: actor: Message | None = await self.client.get(self.signature.keyid, True, Message)
self.actor = await self.client.get(self.signature.keyid, True, Message)
except Exception: if actor is None:
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete': if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose('Instance sent a delete which cannot be handled')
@ -103,6 +106,8 @@ class ActorView(View):
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json') return Response.new_error(400, 'failed to fetch actor', 'json')
self.actor = actor
try: try:
self.signer = self.actor.signer self.signer = self.actor.signer
@ -120,39 +125,6 @@ class ActorView(View):
return None return None
@register_route('/outbox')
class OutboxView(View):
async def get(self, request: Request) -> Response:
msg = aputils.Message.new(
aputils.ObjectType.ORDERED_COLLECTION,
{
"id": f'https://{self.config.domain}/outbox',
"totalItems": 0,
"orderedItems": []
}
)
return Response.new(msg, ctype = 'activity')
@register_route('/following', '/followers')
class RelationshipView(View):
async def get(self, request: Request) -> Response:
with self.database.session(False) as s:
inboxes = [row['actor'] for row in s.get_inboxes()]
msg = aputils.Message.new(
aputils.ObjectType.COLLECTION,
{
"id": f'https://{self.config.domain}{request.path}',
"totalItems": len(inboxes),
"items": inboxes
}
)
return Response.new(msg, ctype = 'activity')
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
class WebfingerView(View): class WebfingerView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:

View file

@ -1,9 +1,9 @@
import traceback from __future__ import annotations
from aiohttp.web import Request, middleware import typing
from aiohttp import web
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
@ -12,6 +12,11 @@ from .. import __version__
from ..database import ConfigData from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app from ..misc import Message, Response, boolean, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable, Sequence
from typing import Any
ALLOWED_HEADERS = { ALLOWED_HEADERS = {
'accept', 'accept',
@ -21,6 +26,7 @@ ALLOWED_HEADERS = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'), ('GET', '/api/v1/relay'),
('GET', '/api/v1/instance'),
('POST', '/api/v1/token') ('POST', '/api/v1/token')
) )
@ -32,10 +38,8 @@ def check_api_path(method: str, path: str) -> bool:
return path.startswith('/api') return path.startswith('/api')
@middleware @web.middleware
async def handle_api_path( async def handle_api_path(request: Request, handler: Callable) -> Response:
request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response:
try: try:
if (token := request.cookies.get('user-token')): if (token := request.cookies.get('user-token')):
request['token'] = token request['token'] = token
@ -90,10 +94,10 @@ class Login(View):
token = conn.put_token(data['username']) token = conn.put_token(data['username'])
resp = Response.new({'token': token.code}, ctype = 'json') resp = Response.new({'token': token['code']}, ctype = 'json')
resp.set_cookie( resp.set_cookie(
'user-token', 'user-token',
token.code, token['code'],
max_age = 60 * 60 * 24 * 365, max_age = 60 * 60 * 24 * 365,
domain = self.config.domain, domain = self.config.domain,
path = '/', path = '/',
@ -117,7 +121,7 @@ class RelayInfo(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
config = conn.get_config_all() config = conn.get_config_all()
inboxes = [row.domain for row in conn.get_inboxes()] inboxes = [row['domain'] for row in conn.get_inboxes()]
data = { data = {
'domain': self.config.domain, 'domain': self.config.domain,
@ -188,7 +192,7 @@ class Config(View):
class Inbox(View): class Inbox(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = tuple(conn.get_inboxes()) data = conn.get_inboxes()
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')
@ -202,36 +206,24 @@ class Inbox(View):
data['domain'] = urlparse(data["actor"]).netloc data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_inbox(data['domain']) is not None: if conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance already in database', 'json') return Response.new_error(404, 'Instance already in database', 'json')
data['domain'] = data['domain'].encode('idna').decode()
if not data.get('inbox'): if not data.get('inbox'):
try: actor_data: Message | None = await self.client.get(data['actor'], True, Message)
actor_data = await self.client.get(data['actor'], True, Message)
except Exception: if actor_data is None:
traceback.print_exc()
return Response.new_error(500, 'Failed to fetch actor', 'json') return Response.new_error(500, 'Failed to fetch actor', 'json')
data['inbox'] = actor_data.shared_inbox data['inbox'] = actor_data.shared_inbox
if not data.get('software'): if not data.get('software'):
try: nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
if nodeinfo is not None:
data['software'] = nodeinfo.sw_name data['software'] = nodeinfo.sw_name
except Exception: row = conn.put_inbox(**data)
pass
row = conn.put_inbox(
domain = data['domain'],
actor = data['actor'],
inbox = data.get('inbox'),
software = data.get('software'),
followid = data.get('followid')
)
return Response.new(row, ctype = 'json') return Response.new(row, ctype = 'json')
@ -243,17 +235,10 @@ class Inbox(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode() if not (instance := conn.get_inbox(data['domain'])):
if (instance := conn.get_inbox(data['domain'])) is None:
return Response.new_error(404, 'Instance with domain not found', 'json') return Response.new_error(404, 'Instance with domain not found', 'json')
instance = conn.put_inbox( instance = conn.put_inbox(instance['domain'], **data)
instance.domain,
actor = data.get('actor'),
software = data.get('software'),
followid = data.get('followid')
)
return Response.new(instance, ctype = 'json') return Response.new(instance, ctype = 'json')
@ -265,8 +250,6 @@ class Inbox(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']): if not conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance with domain not found', 'json') return Response.new_error(404, 'Instance with domain not found', 'json')
@ -279,19 +262,14 @@ class Inbox(View):
class RequestView(View): class RequestView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
instances = tuple(conn.get_requests()) instances = conn.get_requests()
return Response.new(instances, ctype = 'json') return Response.new(instances, ctype = 'json')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], []) data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
if isinstance(data, Response):
return data
data['accept'] = boolean(data['accept']) data['accept'] = boolean(data['accept'])
data['domain'] = data['domain'].encode('idna').decode()
try: try:
with self.database.session(True) as conn: with self.database.session(True) as conn:
@ -302,20 +280,20 @@ class RequestView(View):
message = Message.new_response( message = Message.new_response(
host = self.config.domain, host = self.config.domain,
actor = instance.actor, actor = instance['actor'],
followid = instance.followid, followid = instance['followid'],
accept = data['accept'] accept = data['accept']
) )
self.app.push_message(instance.inbox, message, instance) self.app.push_message(instance['inbox'], message, instance)
if data['accept'] and instance.software != 'mastodon': if data['accept'] and instance['software'] != 'mastodon':
message = Message.new_follow( message = Message.new_follow(
host = self.config.domain, host = self.config.domain,
actor = instance.actor actor = instance['actor']
) )
self.app.push_message(instance.inbox, message, instance) self.app.push_message(instance['inbox'], message, instance)
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'} resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
return Response.new(resp_message, ctype = 'json') return Response.new(resp_message, ctype = 'json')
@ -325,7 +303,7 @@ class RequestView(View):
class DomainBan(View): class DomainBan(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
bans = tuple(conn.get_domain_bans()) bans = tuple(conn.execute('SELECT * FROM domain_bans').all())
return Response.new(bans, ctype = 'json') return Response.new(bans, ctype = 'json')
@ -336,17 +314,11 @@ class DomainBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_ban(data['domain']) is not None: if conn.get_domain_ban(data['domain']):
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_domain_ban( ban = conn.put_domain_ban(**data)
domain = data['domain'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -358,19 +330,13 @@ class DomainBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
if not conn.get_domain_ban(data['domain']):
return Response.new_error(404, 'Domain not banned', 'json')
if not any([data.get('note'), data.get('reason')]): if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json') return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
data['domain'] = data['domain'].encode('idna').decode() ban = conn.update_domain_ban(**data)
if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json')
ban = conn.update_domain_ban(
domain = data['domain'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -382,9 +348,7 @@ class DomainBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode() if not conn.get_domain_ban(data['domain']):
if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json') return Response.new_error(404, 'Domain not banned', 'json')
conn.del_domain_ban(data['domain']) conn.del_domain_ban(data['domain'])
@ -396,7 +360,7 @@ class DomainBan(View):
class SoftwareBan(View): class SoftwareBan(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
bans = tuple(conn.get_software_bans()) bans = tuple(conn.execute('SELECT * FROM software_bans').all())
return Response.new(bans, ctype = 'json') return Response.new(bans, ctype = 'json')
@ -408,14 +372,10 @@ class SoftwareBan(View):
return data return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is not None: if conn.get_software_ban(data['name']):
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_software_ban( ban = conn.put_software_ban(**data)
name = data['name'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -426,18 +386,14 @@ class SoftwareBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None: if not conn.get_software_ban(data['name']):
return Response.new_error(404, 'Software not banned', 'json') return Response.new_error(404, 'Software not banned', 'json')
ban = conn.update_software_ban( if not any([data.get('note'), data.get('reason')]):
name = data['name'], return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
reason = data.get('reason'),
note = data.get('note') ban = conn.update_software_ban(**data)
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -449,7 +405,7 @@ class SoftwareBan(View):
return data return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None: if not conn.get_software_ban(data['name']):
return Response.new_error(404, 'Software not banned', 'json') return Response.new_error(404, 'Software not banned', 'json')
conn.del_software_ban(data['name']) conn.del_software_ban(data['name'])
@ -463,7 +419,7 @@ class User(View):
with self.database.session() as conn: with self.database.session() as conn:
items = [] items = []
for row in conn.get_users(): for row in conn.execute('SELECT * FROM users'):
del row['hash'] del row['hash']
items.append(row) items.append(row)
@ -477,16 +433,12 @@ class User(View):
return data return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_user(data['username']) is not None: if conn.get_user(data['username']):
return Response.new_error(404, 'User already exists', 'json') return Response.new_error(404, 'User already exists', 'json')
user = conn.put_user( user = conn.put_user(**data)
username = data['username'], del user['hash']
password = data['password'],
handle = data.get('handle')
)
del user['hash']
return Response.new(user, ctype = 'json') return Response.new(user, ctype = 'json')
@ -497,13 +449,9 @@ class User(View):
return data return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
user = conn.put_user( user = conn.put_user(**data)
username = data['username'], del user['hash']
password = data['password'],
handle = data.get('handle')
)
del user['hash']
return Response.new(user, ctype = 'json') return Response.new(user, ctype = 'json')
@ -514,7 +462,7 @@ class User(View):
return data return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
if conn.get_user(data['username']) is None: if not conn.get_user(data['username']):
return Response.new_error(404, 'User does not exist', 'json') return Response.new_error(404, 'User does not exist', 'json')
conn.del_user(data['username']) conn.del_user(data['username'])
@ -526,7 +474,7 @@ class User(View):
class Whitelist(View): class Whitelist(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
items = tuple(conn.get_domains_whitelist()) items = tuple(conn.execute('SELECT * FROM whitelist').all())
return Response.new(items, ctype = 'json') return Response.new(items, ctype = 'json')
@ -537,13 +485,11 @@ class Whitelist(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is not None: if conn.get_domain_whitelist(data['domain']):
return Response.new_error(400, 'Domain already added to whitelist', 'json') return Response.new_error(400, 'Domain already added to whitelist', 'json')
item = conn.put_domain_whitelist(domain) item = conn.put_domain_whitelist(**data)
return Response.new(item, ctype = 'json') return Response.new(item, ctype = 'json')
@ -554,12 +500,10 @@ class Whitelist(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is None: if not conn.get_domain_whitelist(data['domain']):
return Response.new_error(404, 'Domain not in whitelist', 'json') return Response.new_error(404, 'Domain not in whitelist', 'json')
conn.del_domain_whitelist(domain) conn.del_domain_whitelist(data['domain'])
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')

View file

@ -1,24 +1,26 @@
from __future__ import annotations from __future__ import annotations
import typing
from Crypto.Random import get_random_bytes from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import HTTPMethodNotAllowed, Request from aiohttp.web import HTTPMethodNotAllowed
from base64 import b64encode from base64 import b64encode
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property from functools import cached_property
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any
from ..cache import Cache
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
from ..misc import Response, get_app from ..misc import Response, get_app
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable, Generator, Sequence, Mapping
from bsql import Database
from typing import Any
from ..application import Application from ..application import Application
from ..cache import Cache
from ..config import Config
from ..http_client import HttpClient
from ..template import Template from ..template import Template
try: try:
@ -27,8 +29,6 @@ if TYPE_CHECKING:
except ImportError: except ImportError:
from typing_extensions import Self from typing_extensions import Self
HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = [] VIEWS: list[tuple[str, type[View]]] = []
@ -37,7 +37,7 @@ def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
return {key: str(value) for key, value in data.items()} return {key: str(value) for key, value in data.items()}
def register_route(*paths: str) -> Callable[[type[View]], type[View]]: def register_route(*paths: str) -> Callable:
def wrapper(view: type[View]) -> type[View]: def wrapper(view: type[View]) -> type[View]:
for path in paths: for path in paths:
VIEWS.append((path, view)) VIEWS.append((path, view))
@ -63,7 +63,7 @@ class View(AbstractView):
return await view.handlers[method](request, **kwargs) return await view.handlers[method](request, **kwargs)
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs) return await handler(self.request, **self.request.match_info, **kwargs)
@ -78,7 +78,7 @@ class View(AbstractView):
@cached_property @cached_property
def handlers(self) -> dict[str, HandlerCallback]: def handlers(self) -> dict[str, Callable[..., Any]]:
data = {} data = {}
for method in METHODS: for method in METHODS:
@ -112,13 +112,13 @@ class View(AbstractView):
@property @property
def database(self) -> Database[Connection]: def database(self) -> Database:
return self.app.database return self.app.database
@property @property
def template(self) -> Template: def template(self) -> Template:
return self.app['template'] # type: ignore[no-any-return] return self.app['template']
async def get_api_data(self, async def get_api_data(self,

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import typing
from aiohttp import web from aiohttp import web
from collections.abc import Awaitable, Callable
from typing import Any
from .base import View, register_route from .base import View, register_route
@ -8,6 +10,11 @@ from ..database import THEMES
from ..logger import LogLevel from ..logger import LogLevel
from ..misc import Response, get_app from ..misc import Response, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable
from typing import Any
UNAUTH_ROUTES = { UNAUTH_ROUTES = {
'/', '/',
@ -16,10 +23,7 @@ UNAUTH_ROUTES = {
@web.middleware @web.middleware
async def handle_frontend_path( async def handle_frontend_path(request: web.Request, handler: Callable) -> Response:
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
app = get_app() app = get_app()
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
@ -48,7 +52,7 @@ async def handle_frontend_path(
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
context: dict[str, Any] = { context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes()) 'instances': tuple(conn.get_inboxes())
@ -60,14 +64,14 @@ class HomeView(View):
@register_route('/login') @register_route('/login')
class Login(View): class Login(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: Request) -> Response:
data = self.template.render('page/login.haml', self) data = self.template.render('page/login.haml', self)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@register_route('/logout') @register_route('/logout')
class Logout(View): class Logout(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session(True) as conn: with self.database.session(True) as conn:
conn.del_token(request['token']) conn.del_token(request['token'])
@ -78,14 +82,14 @@ class Logout(View):
@register_route('/admin') @register_route('/admin')
class Admin(View): class Admin(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: Request) -> Response:
return Response.new('', 302, {'Location': '/admin/instances'}) return Response.new('', 302, {'Location': '/admin/instances'})
@register_route('/admin/instances') @register_route('/admin/instances')
class AdminInstances(View): class AdminInstances(View):
async def get(self, async def get(self,
request: web.Request, request: Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -108,7 +112,7 @@ class AdminInstances(View):
@register_route('/admin/whitelist') @register_route('/admin/whitelist')
class AdminWhitelist(View): class AdminWhitelist(View):
async def get(self, async def get(self,
request: web.Request, request: Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -130,7 +134,7 @@ class AdminWhitelist(View):
@register_route('/admin/domain_bans') @register_route('/admin/domain_bans')
class AdminDomainBans(View): class AdminDomainBans(View):
async def get(self, async def get(self,
request: web.Request, request: Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -152,7 +156,7 @@ class AdminDomainBans(View):
@register_route('/admin/software_bans') @register_route('/admin/software_bans')
class AdminSoftwareBans(View): class AdminSoftwareBans(View):
async def get(self, async def get(self,
request: web.Request, request: Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -174,7 +178,7 @@ class AdminSoftwareBans(View):
@register_route('/admin/users') @register_route('/admin/users')
class AdminUsers(View): class AdminUsers(View):
async def get(self, async def get(self,
request: web.Request, request: Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -195,22 +199,11 @@ class AdminUsers(View):
@register_route('/admin/config') @register_route('/admin/config')
class AdminConfig(View): class AdminConfig(View):
async def get(self, request: web.Request, message: str | None = None) -> Response: async def get(self, request: Request, message: str | None = None) -> Response:
context: dict[str, Any] = { context: dict[str, Any] = {
'themes': tuple(THEMES.keys()), 'themes': tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel), 'levels': tuple(level.name for level in LogLevel),
'message': message, 'message': message
'desc': {
"name": "Name of the relay to be displayed in the header of the pages and in " +
"the actor endpoint.", # noqa: E131
"note": "Description of the relay to be displayed on the front page and as the " +
"bio in the actor endpoint.",
"theme": "Color theme to use on the web pages.",
"log_level": "Minimum level of logging messages to print to the console.",
"whitelist_enabled": "Only allow instances in the whitelist to be able to follow.",
"approval_required": "Require instances not on the whitelist to be approved by " +
"and admin. The `whitelist-enabled` setting is ignored when this is enabled."
}
} }
data = self.template.render('page/admin-config.haml', self, **context) data = self.template.render('page/admin-config.haml', self, **context)
@ -219,7 +212,7 @@ class AdminConfig(View):
@register_route('/manifest.json') @register_route('/manifest.json')
class ManifestJson(View): class ManifestJson(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session(False) as conn: with self.database.session(False) as conn:
config = conn.get_config_all() config = conn.get_config_all()
theme = THEMES[config.theme] theme = THEMES[config.theme]
@ -242,7 +235,7 @@ class ManifestJson(View):
@register_route('/theme/{theme}.css') @register_route('/theme/{theme}.css')
class ThemeCss(View): class ThemeCss(View):
async def get(self, request: web.Request, theme: str) -> Response: async def get(self, request: Request, theme: str) -> Response:
try: try:
context: dict[str, Any] = { context: dict[str, Any] = {
'theme': THEMES[theme] 'theme': THEMES[theme]

View file

@ -1,7 +1,9 @@
from __future__ import annotations
import aputils import aputils
import subprocess import subprocess
import typing
from aiohttp.web import Request
from pathlib import Path from pathlib import Path
from .base import View, register_route from .base import View, register_route
@ -9,6 +11,9 @@ from .base import View, register_route
from .. import __version__ from .. import __version__
from ..misc import Response from ..misc import Response
if typing.TYPE_CHECKING:
from aiohttp.web import Request
VERSION = __version__ VERSION = __version__

View file

@ -1,150 +0,0 @@
import asyncio
import traceback
import typing
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value
from multiprocessing.synchronize import Event as EventType
from pathlib import Path
from queue import Empty, Queue as QueueType
from urllib.parse import urlparse
from . import application, logger as logging
from .database.schema import Instance
from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, get_app
if typing.TYPE_CHECKING:
from .multiprocessing.synchronize import Syncronized
@dataclass
class QueueItem:
pass
@dataclass
class PostItem(QueueItem):
inbox: str
message: Message
instance: Instance | None
@property
def domain(self) -> str:
return urlparse(self.inbox).netloc
class PushWorker(Process):
client: HttpClient
def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None:
Process.__init__(self)
self.queue: QueueType[QueueItem] = queue
self.shutdown: EventType = Event()
self.path: Path = get_app().config.path
self.log_level: "Syncronized[str]" = log_level
self._log_level_changed: EventType = Event()
def stop(self) -> None:
self.shutdown.set()
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
if IS_WINDOWS:
app = application.Application(self.path)
self.client = app.client
self.client.open()
app.database.connect()
app.cache.setup()
else:
self.client = HttpClient()
self.client.open()
logging.verbose("[%i] Starting worker", self.pid)
while not self.shutdown.is_set():
try:
if self._log_level_changed.is_set():
logging.set_level(logging.LogLevel.parse(self.log_level.value))
self._log_level_changed.clear()
item = self.queue.get(block=True, timeout=0.1)
if isinstance(item, PostItem):
asyncio.create_task(self.handle_post(item))
except Empty:
await asyncio.sleep(0)
except Exception:
traceback.print_exc()
if IS_WINDOWS:
app.database.disconnect()
app.cache.close()
await self.client.close()
async def handle_post(self, item: PostItem) -> None:
try:
await self.client.post(item.inbox, item.message, item.instance)
except AsyncTimeoutError:
logging.error('Timeout when pushing to %s', item.domain)
except ClientConnectionError as e:
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
except ClientSSLError as e:
logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
class PushWorkers(list[PushWorker]):
def __init__(self, count: int) -> None:
self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment]
self._log_level: "Syncronized[str]" = Value("i", logging.get_level())
self._count: int = count
def push_item(self, item: QueueItem) -> None:
self.queue.put(item)
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self.queue.put(PostItem(inbox, message, instance))
def set_log_level(self, value: logging.LogLevel) -> None:
self._log_level.value = value
for worker in self:
worker._log_level_changed.set()
def start(self) -> None:
if len(self) > 0:
return
for _ in range(self._count):
worker = PushWorker(self.queue, self._log_level)
worker.start()
self.append(worker)
def stop(self) -> None:
for worker in self:
worker.stop()
self.clear()