Compare commits

..

16 commits

Author SHA1 Message Date
Izalia Mae b308b03546 only watch $REPO/relay directory with dev run command 2024-06-18 23:05:37 -04:00
Izalia Mae 5407027af8 rework person_check and rename it to actor_type_check 2024-06-18 21:52:52 -04:00
Izalia Mae f49bc0ae90 replace menu open icon 2024-06-16 11:27:43 -04:00
Izalia Mae cad7f47e7e properly get date on api requests from frontend 2024-06-16 09:08:33 -04:00
Izalia Mae 058df0ac78 properly handle non-ascii domain names
* ensure domains are stored as ascii in the database
* convert domains to unicode on the frontend
2024-06-16 08:45:14 -04:00
Izalia Mae e825a01795 update activitypub-utils to 0.3.1 2024-06-15 13:10:24 -04:00
Izalia Mae ab9b8abbd2 don't allow bytes for message in push_message 2024-06-14 22:57:08 -04:00
Izalia Mae 15882f3e49 fix linter issues 2024-06-14 15:05:55 -04:00
Izalia Mae a2b96d03dc enable strict mode by default for mypy 2024-06-14 13:37:19 -04:00
Izalia Mae 98a975550a use hs2019 for some servers that support it 2024-06-14 13:34:05 -04:00
Izalia Mae ed03779a11 use correct command for dev install command 2024-06-14 13:29:55 -04:00
Izalia Mae e44108f341 update dependencies 2024-06-12 16:02:59 -04:00
Izalia Mae a0d84b5ae5 fix typing issue in cache 2024-06-12 14:50:40 -04:00
Izalia Mae 478e21fb15 require token for /api/v1/instance 2024-06-12 13:24:44 -04:00
Izalia Mae 0d50215fc1 add missing AP routes
Adds routes for "/outbox", "/following", and "/followers"
2024-06-12 13:23:53 -04:00
Izalia Mae 62555b3591 use old signing algorithm by default 2024-06-12 12:40:32 -04:00
31 changed files with 2460 additions and 312 deletions

64
dev.py
View file

@ -5,16 +5,17 @@ import shutil
import subprocess import subprocess
import sys import sys
import time import time
import tomllib
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from relay import __version__, logger as logging from relay import __version__, logger as logging
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Sequence from typing import Any, Sequence
try: try:
from watchdog.observers import Observer from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
except ImportError: except ImportError:
class PatternMatchingEventHandler: # type: ignore class PatternMatchingEventHandler: # type: ignore
@ -29,39 +30,38 @@ IGNORE_EXT = {
@click.group('cli') @click.group('cli')
def cli(): def cli() -> None:
'Useful commands for development' 'Useful commands for development'
@cli.command('install') @cli.command('install')
def cli_install(): @click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
cmd = [ def cli_install(no_dev: bool) -> None:
sys.executable, '-m', 'pip', 'install', with open('pyproject.toml', 'rb') as fd:
'-r', 'requirements.txt', data = tomllib.load(fd)
'-r', 'dev-requirements.txt'
]
subprocess.run(cmd, check = False) deps = data['project']['dependencies']
if not no_dev:
deps.extend(data['project']['optional-dependencies']['dev'])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
@cli.command('lint') @cli.command('lint')
@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay')) @click.argument('path', required = False, type = Path, default = REPO.joinpath('relay'))
@click.option('--strict', '-s', is_flag = True, help = 'Enable strict mode for mypy')
@click.option('--watch', '-w', is_flag = True, @click.option('--watch', '-w', is_flag = True,
help = 'Automatically, re-run the linters on source change') help = 'Automatically, re-run the linters on source change')
def cli_lint(path: Path, strict: bool, watch: bool) -> None: def cli_lint(path: Path, watch: bool) -> None:
path = path.expanduser().resolve() path = path.expanduser().resolve()
if watch: if watch:
handle_run_watcher([sys.executable, "-m", "relay.dev", "lint", str(path)], wait = True) handle_run_watcher([sys.executable, "dev.py", "lint", str(path)], wait = True)
return return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)] mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)]
if strict:
mypy.append('--strict')
click.echo('----- flake8 -----') click.echo('----- flake8 -----')
subprocess.run(flake8) subprocess.run(flake8)
@ -70,7 +70,7 @@ def cli_lint(path: Path, strict: bool, watch: bool) -> None:
@cli.command('clean') @cli.command('clean')
def cli_clean(): def cli_clean() -> None:
dirs = { dirs = {
'dist', 'dist',
'build', 'build',
@ -88,7 +88,7 @@ def cli_clean():
@cli.command('build') @cli.command('build')
def cli_build(): def cli_build() -> None:
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [ cmd = [
@ -118,7 +118,7 @@ def cli_build():
@cli.command('run') @cli.command('run')
@click.option('--dev', '-d', is_flag = True) @click.option('--dev', '-d', is_flag = True)
def cli_run(dev: bool): def cli_run(dev: bool) -> None:
print('Starting process watcher') print('Starting process watcher')
cmd = [sys.executable, '-m', 'relay', 'run'] cmd = [sys.executable, '-m', 'relay', 'run']
@ -126,16 +126,20 @@ def cli_run(dev: bool):
if dev: if dev:
cmd.append('-d') cmd.append('-d')
handle_run_watcher(cmd) handle_run_watcher(cmd, watch_path = REPO.joinpath("relay"))
def handle_run_watcher(*commands: Sequence[str], wait: bool = False): def handle_run_watcher(
*commands: Sequence[str],
watch_path: Path | str = REPO,
wait: bool = False) -> None:
handler = WatchHandler(*commands, wait = wait) handler = WatchHandler(*commands, wait = wait)
handler.run_procs() handler.run_procs()
watcher = Observer() watcher = Observer()
watcher.schedule(handler, str(REPO), recursive=True) watcher.schedule(handler, str(watch_path), recursive=True) # type: ignore
watcher.start() watcher.start() # type: ignore
try: try:
while True: while True:
@ -145,7 +149,7 @@ def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
pass pass
handler.kill_procs() handler.kill_procs()
watcher.stop() watcher.stop() # type: ignore
watcher.join() watcher.join()
@ -153,16 +157,16 @@ class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py'] patterns = ['*.py']
def __init__(self, *commands: Sequence[str], wait: bool = False): def __init__(self, *commands: Sequence[str], wait: bool = False) -> None:
PatternMatchingEventHandler.__init__(self) PatternMatchingEventHandler.__init__(self) # type: ignore
self.commands: Sequence[Sequence[str]] = commands self.commands: Sequence[Sequence[str]] = commands
self.wait: bool = wait self.wait: bool = wait
self.procs: list[subprocess.Popen] = [] self.procs: list[subprocess.Popen[Any]] = []
self.last_restart: datetime = datetime.now() self.last_restart: datetime = datetime.now()
def kill_procs(self): def kill_procs(self) -> None:
for proc in self.procs: for proc in self.procs:
if proc.poll() is not None: if proc.poll() is not None:
continue continue
@ -183,7 +187,7 @@ class WatchHandler(PatternMatchingEventHandler):
logging.info('Process terminated') logging.info('Process terminated')
def run_procs(self, restart: bool = False): def run_procs(self, restart: bool = False) -> None:
if restart: if restart:
if datetime.now() - timedelta(seconds = 3) < self.last_restart: if datetime.now() - timedelta(seconds = 3) < self.last_restart:
return return
@ -205,7 +209,7 @@ class WatchHandler(PatternMatchingEventHandler):
logging.info('Started processes with PIDs: %s', ', '.join(pids)) logging.info('Started processes with PIDs: %s', ', '.join(pids))
def on_any_event(self, event): def on_any_event(self, event: FileSystemEvent) -> None:
if event.event_type not in ['modified', 'created', 'deleted']: if event.event_type not in ['modified', 'created', 'deleted']:
return return

View file

@ -16,19 +16,21 @@ classifiers = [
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
] ]
dependencies = [ dependencies = [
"activitypub-utils == 0.2.2", "activitypub-utils == 0.3.1",
"aiohttp >= 3.9.1", "aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-sql == 0.1.2", "barkshark-lib >= 0.1.3-1",
"barkshark-sql == 0.1.4-1",
"click >= 8.1.2", "click >= 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",
"idna == 3.4",
"jinja2-haml == 0.3.5", "jinja2-haml == 0.3.5",
"markdown == 3.5.2", "markdown == 3.6",
"platformdirs == 4.2.0", "platformdirs == 4.2.2",
"pyyaml >= 6.0", "pyyaml >= 6.0",
"redis == 5.0.1", "redis == 5.0.5",
"importlib_resources == 6.1.1; python_version < '3.9'" "importlib-resources == 6.4.0; python_version < '3.9'"
] ]
requires-python = ">=3.8" requires-python = ">=3.8"
dynamic = ["version"] dynamic = ["version"]
@ -48,10 +50,10 @@ activityrelay = "relay.manage:main"
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"flake8 == 7.0.0", "flake8 == 7.0.0",
"mypy == 1.9.0", "mypy == 1.10.0",
"pyinstaller == 6.3.0", "pyinstaller == 6.8.0",
"watchdog == 4.0.0", "watchdog == 4.0.1",
"typing_extensions >= 4.10.0; python_version < '3.11.0'" "typing-extensions >= 4.12.2; python_version < '3.11.0'"
] ]
[tool.setuptools] [tool.setuptools]
@ -87,4 +89,22 @@ warn_redundant_casts = true
warn_unreachable = true warn_unreachable = true
warn_unused_ignores = true warn_unused_ignores = true
ignore_missing_imports = true ignore_missing_imports = true
implicit_reexport = true
strict = true
follow_imports = "silent" follow_imports = "silent"
[[tool.mypy.overrides]]
module = "relay.database"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "aputils"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "blib"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "bsql"
implicit_reexport = true

View file

@ -1 +1 @@
__version__ = '0.3.1' __version__ = '0.3.2'

View file

@ -5,39 +5,35 @@ import multiprocessing
import signal import signal
import time import time
import traceback import traceback
import typing
from aiohttp import web from aiohttp import web
from aiohttp.web import StaticResource from aiohttp.web import StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from bsql import Database, Row
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mimetypes import guess_type from mimetypes import guess_type
from pathlib import Path from pathlib import Path
from queue import Empty from queue import Empty
from threading import Event, Thread from threading import Event, Thread
from typing import Any
from . import logger as logging from . import logger as logging
from .cache import get_cache from .cache import Cache, get_cache
from .config import Config from .config import Config
from .database import get_database from .database import Connection, get_database
from .http_client import HttpClient from .http_client import HttpClient
from .misc import IS_WINDOWS, check_open_port, get_resource from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
from .views.frontend import handle_frontend_path from .views.frontend import handle_frontend_path
if typing.TYPE_CHECKING:
from collections.abc import Callable
from bsql import Database, Row
from .cache import Cache
from .misc import Message, Response
def get_csp(request: web.Request) -> str: def get_csp(request: web.Request) -> str:
data = [ data = [
"default-src 'none'", "default-src 'self'",
f"script-src 'nonce-{request['hash']}'", f"script-src 'nonce-{request['hash']}'",
f"style-src 'self' 'nonce-{request['hash']}'", f"style-src 'self' 'nonce-{request['hash']}'",
"form-action 'self'", "form-action 'self'",
@ -58,9 +54,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False): def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self, web.Application.__init__(self,
middlewares = [ middlewares = [
handle_api_path, handle_api_path, # type: ignore[list-item]
handle_frontend_path, handle_frontend_path, # type: ignore[list-item]
handle_response_headers handle_response_headers # type: ignore[list-item]
] ]
) )
@ -96,27 +92,27 @@ class Application(web.Application):
@property @property
def cache(self) -> Cache: def cache(self) -> Cache:
return self['cache'] return self['cache'] # type: ignore[no-any-return]
@property @property
def client(self) -> HttpClient: def client(self) -> HttpClient:
return self['client'] return self['client'] # type: ignore[no-any-return]
@property @property
def config(self) -> Config: def config(self) -> Config:
return self['config'] return self['config'] # type: ignore[no-any-return]
@property @property
def database(self) -> Database: def database(self) -> Database[Connection]:
return self['database'] return self['database'] # type: ignore[no-any-return]
@property @property
def signer(self) -> Signer: def signer(self) -> Signer:
return self['signer'] return self['signer'] # type: ignore[no-any-return]
@signer.setter @signer.setter
@ -130,7 +126,7 @@ class Application(web.Application):
@property @property
def template(self) -> Template: def template(self) -> Template:
return self['template'] return self['template'] # type: ignore[no-any-return]
@property @property
@ -143,7 +139,7 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message | bytes, instance: Row) -> None: def push_message(self, inbox: str, message: Message, instance: Row) -> None:
self['push_queue'].put((inbox, message, instance)) self['push_queue'].put((inbox, message, instance))
@ -185,11 +181,11 @@ class Application(web.Application):
pass pass
def stop(self, *_): def stop(self, *_: Any) -> None:
self['running'] = False self['running'] = False
async def handle_run(self): async def handle_run(self) -> None:
self['running'] = True self['running'] = True
self.set_signal_handler(True) self.set_signal_handler(True)
@ -295,7 +291,7 @@ class CacheCleanupThread(Thread):
class PushWorker(multiprocessing.Process): class PushWorker(multiprocessing.Process):
def __init__(self, queue: multiprocessing.Queue): def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None:
if Application.DEFAULT is None: if Application.DEFAULT is None:
raise RuntimeError('Application not setup yet') raise RuntimeError('Application not setup yet')
@ -347,7 +343,10 @@ class PushWorker(multiprocessing.Process):
@web.middleware @web.middleware
async def handle_response_headers(request: web.Request, handler: Callable) -> Response: async def handle_response_headers(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
resp = await handler(request) resp = await handler(request)
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'

View file

@ -2,28 +2,27 @@ from __future__ import annotations
import json import json
import os import os
import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from bsql import Database
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from redis import Redis from redis import Redis
from typing import TYPE_CHECKING, Any
from .database import get_database from .database import Connection, get_database
from .misc import Message, boolean from .misc import Message, boolean
if typing.TYPE_CHECKING: if TYPE_CHECKING:
from blib import Database
from collections.abc import Callable, Iterator
from typing import Any
from .application import Application from .application import Application
# todo: implement more caching backends SerializerCallback = Callable[[Any], str]
DeserializerCallback = Callable[[str], Any]
BACKENDS: dict[str, type[Cache]] = {} BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[Callable, Callable]] = { CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
'bool': (str, boolean), 'bool': (str, boolean),
@ -61,13 +60,13 @@ class Item:
updated: datetime updated: datetime
def __post_init__(self): def __post_init__(self) -> None:
if isinstance(self.updated, str): if isinstance(self.updated, str): # type: ignore[unreachable]
self.updated = datetime.fromisoformat(self.updated) self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
@classmethod @classmethod
def from_data(cls: type[Item], *args) -> Item: def from_data(cls: type[Item], *args: Any) -> Item:
data = cls(*args) data = cls(*args)
data.value = deserialize_value(data.value, data.value_type) data.value = deserialize_value(data.value, data.value_type)
@ -159,10 +158,13 @@ class SqlCache(Cache):
def __init__(self, app: Application): def __init__(self, app: Application):
Cache.__init__(self, app) Cache.__init__(self, app)
self._db: Database = None self._db: Database[Connection] | None = None
def get(self, namespace: str, key: str) -> Item: def get(self, namespace: str, key: str) -> Item:
if self._db is None:
raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, 'namespace': namespace,
'key': key 'key': key
@ -178,18 +180,27 @@ class SqlCache(Cache):
def get_keys(self, namespace: str) -> Iterator[str]: def get_keys(self, namespace: str) -> Iterator[str]:
if self._db is None:
raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn: with self._db.session(False) as conn:
for row in conn.run('get-cache-keys', {'namespace': namespace}): for row in conn.run('get-cache-keys', {'namespace': namespace}):
yield row['key'] yield row['key']
def get_namespaces(self) -> Iterator[str]: def get_namespaces(self) -> Iterator[str]:
if self._db is None:
raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn: with self._db.session(False) as conn:
for row in conn.run('get-cache-namespaces', None): for row in conn.run('get-cache-namespaces', None):
yield row['namespace'] yield row['namespace']
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
if self._db is None:
raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, 'namespace': namespace,
'key': key, 'key': key,
@ -199,13 +210,16 @@ class SqlCache(Cache):
} }
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as conn: with conn.run('set-cache-item', params) as cur:
row = conn.one() row = cur.one()
row.pop('id', None) row.pop('id', None) # type: ignore[union-attr]
return Item.from_data(*tuple(row.values())) return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
def delete(self, namespace: str, key: str) -> None: def delete(self, namespace: str, key: str) -> None:
if self._db is None:
raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, 'namespace': namespace,
'key': key 'key': key
@ -217,6 +231,9 @@ class SqlCache(Cache):
def delete_old(self, days: int = 14) -> None: def delete_old(self, days: int = 14) -> None:
if self._db is None:
raise RuntimeError("Database has not been setup")
limit = datetime.now(tz = timezone.utc) - timedelta(days = days) limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
params = {"limit": limit.timestamp()} params = {"limit": limit.timestamp()}
@ -226,6 +243,9 @@ class SqlCache(Cache):
def clear(self) -> None: def clear(self) -> None:
if self._db is None:
raise RuntimeError("Database has not been setup")
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache"): with conn.execute("DELETE FROM cache"):
pass pass
@ -360,5 +380,5 @@ class RedisCache(Cache):
if not self._rd: if not self._rd:
return return
self._rd.close() self._rd.close() # type: ignore
self._rd = None # type: ignore self._rd = None # type: ignore

View file

@ -1,21 +1,16 @@
from __future__ import annotations
import json import json
import os import os
import typing
import yaml import yaml
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .misc import boolean from .misc import boolean
if typing.TYPE_CHECKING:
from typing import Any
class RelayConfig(dict[str, Any]):
class RelayConfig(dict):
def __init__(self, path: str): def __init__(self, path: str):
dict.__init__(self, {}) dict.__init__(self, {})
@ -122,7 +117,7 @@ class RelayConfig(dict):
self[key] = value self[key] = value
class RelayDatabase(dict): class RelayDatabase(dict[str, Any]):
def __init__(self, config: RelayConfig): def __init__(self, config: RelayConfig):
dict.__init__(self, { dict.__init__(self, {
'relay-list': {}, 'relay-list': {},

View file

@ -1,24 +1,19 @@
from __future__ import annotations
import getpass import getpass
import os import os
import platform import platform
import typing
import yaml import yaml
from dataclasses import asdict, dataclass, fields from dataclasses import asdict, dataclass, fields
from pathlib import Path from pathlib import Path
from platformdirs import user_config_dir from platformdirs import user_config_dir
from typing import Any
from .misc import IS_DOCKER from .misc import IS_DOCKER
if typing.TYPE_CHECKING: try:
from typing import Any
try:
from typing import Self from typing import Self
except ImportError: except ImportError:
from typing_extensions import Self from typing_extensions import Self

View file

@ -1,20 +1,15 @@
from __future__ import annotations from bsql import Database
import bsql
import typing
from .config import THEMES, ConfigData from .config import THEMES, ConfigData
from .connection import RELAY_SOFTWARE, Connection from .connection import RELAY_SOFTWARE, Connection
from .schema import TABLES, VERSIONS, migrate_0 from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
from ..config import Config
from ..misc import get_resource from ..misc import get_resource
if typing.TYPE_CHECKING:
from ..config import Config
def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
def get_database(config: Config, migrate: bool = True) -> bsql.Database:
options = { options = {
'connection_class': Connection, 'connection_class': Connection,
'pool_size': 5, 'pool_size': 5,
@ -22,10 +17,10 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database:
} }
if config.db_type == 'sqlite': if config.db_type == 'sqlite':
db = bsql.Database.sqlite(config.sqlite_path, **options) db = Database.sqlite(config.sqlite_path, **options)
elif config.db_type == 'postgres': elif config.db_type == 'postgres':
db = bsql.Database.postgresql( db = Database.postgresql(
config.pg_name, config.pg_name,
config.pg_host, config.pg_host,
config.pg_port, config.pg_port,

View file

@ -1,21 +1,19 @@
from __future__ import annotations from __future__ import annotations
# removing the above line turns annotations into types instead of str objects which messes with
# `Field.type`
import typing from bsql import Row
from collections.abc import Callable, Sequence
from dataclasses import Field, asdict, dataclass, fields from dataclasses import Field, asdict, dataclass, fields
from typing import Any
from .. import logger as logging from .. import logger as logging
from ..misc import boolean from ..misc import boolean
if typing.TYPE_CHECKING: try:
from bsql import Row
from collections.abc import Callable, Sequence
from typing import Any
try:
from typing import Self from typing import Self
except ImportError: except ImportError:
from typing_extensions import Self from typing_extensions import Self
@ -120,7 +118,7 @@ class ConfigData:
@classmethod @classmethod
def FIELD(cls: type[Self], key: str) -> Field: def FIELD(cls: type[Self], key: str) -> Field[Any]:
for field in fields(cls): for field in fields(cls):
if field.name == key.replace('-', '_'): if field.name == key.replace('-', '_'):
return field return field

View file

@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import typing
from argon2 import PasswordHasher from argon2 import PasswordHasher
from bsql import Connection as SqlConnection, Update from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
@ -14,15 +14,10 @@ from .config import (
) )
from .. import logger as logging from .. import logger as logging
from ..misc import boolean, get_app from ..misc import Message, boolean, get_app
if typing.TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from bsql import Row
from typing import Any
from ..application import Application from ..application import Application
from ..misc import Message
RELAY_SOFTWARE = [ RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay 'activityrelay', # https://git.pleroma.social/pleroma/relay
@ -94,7 +89,7 @@ class Connection(SqlConnection):
params = { params = {
'key': key, 'key': key,
'value': data.get(key, serialize = True), 'value': data.get(key, serialize = True),
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type 'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore
} }
with self.run('put-config', params): with self.run('put-config', params):

View file

@ -1,17 +1,11 @@
from __future__ import annotations
import typing
from bsql import Column, Table, Tables from bsql import Column, Table, Tables
from collections.abc import Callable
from .config import ConfigData from .config import ConfigData
from .connection import Connection
if typing.TYPE_CHECKING:
from collections.abc import Callable
from .connection import Connection
VERSIONS: dict[int, Callable] = {} VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES: Tables = Tables( TABLES: Tables = Tables(
Table( Table(
'config', 'config',
@ -64,7 +58,7 @@ TABLES: Tables = Tables(
) )
def migration(func: Callable) -> Callable: def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
ver = int(func.__name__.replace('migrate_', '')) ver = int(func.__name__.replace('migrate_', ''))
VERSIONS[ver] = func VERSIONS[ver] = func
return func return func

View file

@ -13,6 +13,7 @@
%meta(name="viewport" content="width=device-width, initial-scale=1") %meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme") %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css" nonce="{{view.request['hash']}}")
%link(rel="manifest" href="/manifest.json") %link(rel="manifest" href="/manifest.json")
%script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer) %script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer)
-block head -block head
@ -41,7 +42,7 @@
#container #container
#header.section #header.section
%span#menu-open << &#8286; %span#menu-open -> %i(class="bi bi-list")
%a.title(href="/") -> =config.name %a.title(href="/") -> =config.name
.empty .empty

View file

@ -35,7 +35,7 @@
%tr(id="{{ban.domain}}") %tr(id="{{ban.domain}}")
%td.domain %td.domain
%details %details
%summary -> =ban.domain %summary -> =ban.domain.encode().decode("idna")
.grid-2col .grid-2col
%label.reason(for="{{ban.domain}}-reason") << Reason %label.reason(for="{{ban.domain}}-reason") << Reason

View file

@ -39,7 +39,7 @@
-for request in requests -for request in requests
%tr(id="{{request.domain}}") %tr(id="{{request.domain}}")
%td.instance %td.instance
%a(href="https://{{request.domain}}" target="_new") -> =request.domain %a(href="https://{{request.domain}}" target="_new") -> =request.domain.encode().decode("idna")
%td.software %td.software
=request.software or "n/a" =request.software or "n/a"
@ -69,7 +69,7 @@
-for instance in instances -for instance in instances
%tr(id="{{instance.domain}}") %tr(id="{{instance.domain}}")
%td.instance %td.instance
%a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain %a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain.encode().decode("idna")
%td.software %td.software
=instance.software or "n/a" =instance.software or "n/a"

View file

@ -27,7 +27,7 @@
-for item in whitelist -for item in whitelist
%tr(id="{{item.domain}}") %tr(id="{{item.domain}}")
%td.domain %td.domain
=item.domain =item.domain.encode().decode("idna")
%td.date %td.date
=item.created.strftime("%Y-%m-%d") =item.created.strftime("%Y-%m-%d")

View file

@ -41,7 +41,7 @@
-for instance in instances -for instance in instances
%tr %tr
%td.instance -> %a(href="https://{{instance.domain}}/" target="_new") %td.instance -> %a(href="https://{{instance.domain}}/" target="_new")
=instance.domain =instance.domain.encode().decode("idna")
%td.date %td.date
=instance.created.strftime("%Y-%m-%d") =instance.created.strftime("%Y-%m-%d")

View file

@ -32,18 +32,18 @@ function toast(text, type="error", timeout=5) {
const body = document.getElementById("container") const body = document.getElementById("container")
const menu = document.getElementById("menu"); const menu = document.getElementById("menu");
const menu_open = document.getElementById("menu-open"); const menu_open = document.querySelector("#menu-open i");
const menu_close = document.getElementById("menu-close"); const menu_close = document.getElementById("menu-close");
menu_open.addEventListener("click", (event) => { function toggle_menu() {
var new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true"; let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
menu.attributes.visible.nodeValue = new_value; menu.attributes.visible.nodeValue = new_value;
}); }
menu_close.addEventListener("click", (event) => {
menu.attributes.visible.nodeValue = "false" menu_open.addEventListener("click", toggle_menu);
}); menu_close.addEventListener("click", toggle_menu);
body.addEventListener("click", (event) => { body.addEventListener("click", (event) => {
if (event.target === menu_open) { if (event.target === menu_open) {
@ -53,21 +53,17 @@ body.addEventListener("click", (event) => {
menu.attributes.visible.nodeValue = "false"; menu.attributes.visible.nodeValue = "false";
}); });
for (const elem of document.querySelectorAll("#menu-open div")) {
elem.addEventListener("click", toggle_menu);
}
// misc // misc
function get_date_string(date) { function get_date_string(date) {
var year = date.getFullYear().toString(); var year = date.getUTCFullYear().toString();
var month = date.getMonth().toString(); var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
var day = date.getDay().toString(); var day = date.getUTCDate().toString().padStart(2, "0");
if (month.length === 1) {
month = "0" + month;
}
if (day.length === 1) {
day = "0" + day
}
return `${year}-${month}-${day}`; return `${year}-${month}-${day}`;
} }
@ -127,6 +123,7 @@ async function request(method, path, body = null) {
} else { } else {
if (Object.hasOwn(message, "created")) { if (Object.hasOwn(message, "created")) {
console.log(message.created)
message.created = new Date(message.created); message.created = new Date(message.created);
} }
} }

2077
relay/frontend/static/bootstrap-icons.css vendored Normal file

File diff suppressed because it is too large Load diff

Binary file not shown.

View file

@ -155,6 +155,7 @@ textarea {
z-index: 1; z-index: 1;
font-size: 1.5em; font-size: 1.5em;
min-width: 300px; min-width: 300px;
overflow-x: auto;
} }
#menu[visible="false"] { #menu[visible="false"] {
@ -188,11 +189,17 @@ textarea {
} }
#menu-open { #menu-open {
color: var(--primary); color: var(--background);
background: var(--primary);
font-size: 38px;
line-height: 38px;
border: 1px solid var(--primary);
border-radius: 5px;
} }
#menu-open:hover { #menu-open:hover {
color: var(--primary-hover); color: var(--primary);
background: var(--background);
} }
#menu-open, #menu-close { #menu-open, #menu-close {

View file

@ -2,28 +2,41 @@ from __future__ import annotations
import json import json
import traceback import traceback
import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils import AlgorithmType, JsonBase, Nodeinfo, ObjectType, WellKnownNodeinfo from blib import JsonBase
from bsql import Row
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import urlparse from urllib.parse import urlparse
from . import __version__ from . import __version__, logger as logging
from . import logger as logging from .cache import Cache
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING: if TYPE_CHECKING:
from aputils import Signer
from bsql import Row
from typing import Any
from .application import Application from .application import Application
from .cache import Cache
T = typing.TypeVar('T', bound = JsonBase) SUPPORTS_HS2019 = {
'friendica',
'gotosocial',
'hubzilla'
'mastodon',
'socialhome',
'misskey',
'catodon',
'cherrypick',
'firefish',
'foundkey',
'iceshrimp',
'sharkey'
}
T = TypeVar('T', bound = JsonBase)
HEADERS = { HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9', 'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}' 'User-Agent': f'ActivityRelay/{__version__}'
@ -90,7 +103,12 @@ class HttpClient:
self._session = None self._session = None
async def _get(self, url: str, sign_headers: bool, force: bool) -> dict[str, Any] | None: async def _get(self,
url: str,
sign_headers: bool,
force: bool,
old_algo: bool) -> dict[str, Any] | None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
@ -103,7 +121,7 @@ class HttpClient:
if not force: if not force:
try: try:
if not (item := self.cache.get('request', url)).older_than(48): if not (item := self.cache.get('request', url)).older_than(48):
return json.loads(item.value) return json.loads(item.value) # type: ignore[no-any-return]
except KeyError: except KeyError:
logging.verbose('No cached data for url: %s', url) logging.verbose('No cached data for url: %s', url)
@ -111,7 +129,8 @@ class HttpClient:
headers = {} headers = {}
if sign_headers: if sign_headers:
headers = self.signer.sign_headers('GET', url, algorithm = AlgorithmType.HS2019) algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
headers = self.signer.sign_headers('GET', url, algorithm = algo)
try: try:
logging.debug('Fetching resource: %s', url) logging.debug('Fetching resource: %s', url)
@ -131,10 +150,11 @@ class HttpClient:
self.cache.set('request', url, data, 'str') self.cache.set('request', url, data, 'str')
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4)) logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
return json.loads(data) return json.loads(data) # type: ignore [no-any-return]
except JSONDecodeError: except JSONDecodeError:
logging.verbose('Failed to parse JSON') logging.verbose('Failed to parse JSON')
logging.debug(data)
return None return None
except ClientSSLError as e: except ClientSSLError as e:
@ -155,12 +175,13 @@ class HttpClient:
url: str, url: str,
sign_headers: bool, sign_headers: bool,
cls: type[T], cls: type[T],
force: bool = False) -> T | None: force: bool = False,
old_algo: bool = True) -> T | None:
if not issubclass(cls, JsonBase): if not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "aputils.JsonBase"') raise TypeError('cls must be a sub-class of "blib.JsonBase"')
if (data := (await self._get(url, sign_headers, force))) is None: if (data := (await self._get(url, sign_headers, force, old_algo))) is None:
return None return None
return cls.parse(data) return cls.parse(data)
@ -171,7 +192,7 @@ class HttpClient:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
# akkoma and pleroma do not support HS2019 and other software still needs to be tested # akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance and instance['software'] in {'mastodon'}: if instance and instance['software'] in SUPPORTS_HS2019:
algorithm = AlgorithmType.HS2019 algorithm = AlgorithmType.HS2019
else: else:

View file

@ -1,23 +1,21 @@
from __future__ import annotations
import logging import logging
import os import os
import typing
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from typing import Any, Protocol
if typing.TYPE_CHECKING: try:
from collections.abc import Callable
from typing import Any
try:
from typing import Self from typing import Self
except ImportError: except ImportError:
from typing_extensions import Self from typing_extensions import Self
class LoggingMethod(Protocol):
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
class LogLevel(IntEnum): class LogLevel(IntEnum):
DEBUG = logging.DEBUG DEBUG = logging.DEBUG
VERBOSE = 15 VERBOSE = 15
@ -75,11 +73,11 @@ def verbose(message: str, *args: Any, **kwargs: Any) -> None:
logging.log(LogLevel.VERBOSE, message, *args, **kwargs) logging.log(LogLevel.VERBOSE, message, *args, **kwargs)
debug: Callable = logging.debug debug: LoggingMethod = logging.debug
info: Callable = logging.info info: LoggingMethod = logging.info
warning: Callable = logging.warning warning: LoggingMethod = logging.warning
error: Callable = logging.error error: LoggingMethod = logging.error
critical: Callable = logging.critical critical: LoggingMethod = logging.critical
try: try:

View file

@ -5,10 +5,11 @@ import asyncio
import click import click
import json import json
import os import os
import typing
from bsql import Row
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from . import __version__ from . import __version__
@ -19,10 +20,6 @@ from .compat import RelayConfig, RelayDatabase
from .database import RELAY_SOFTWARE, get_database from .database import RELAY_SOFTWARE, get_database
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
if typing.TYPE_CHECKING:
from bsql import Row
from typing import Any
def check_alphanumeric(text: str) -> str: def check_alphanumeric(text: str) -> str:
if not text.isalnum(): if not text.isalnum():

View file

@ -5,11 +5,12 @@ import json
import os import os
import platform import platform
import socket import socket
import typing
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4 from uuid import uuid4
try: try:
@ -18,21 +19,20 @@ try:
except ImportError: except ImportError:
from importlib_resources import files as pkgfiles # type: ignore from importlib_resources import files as pkgfiles # type: ignore
if typing.TYPE_CHECKING: try:
from typing import Any
from .application import Application
try:
from typing import Self from typing import Self
except ImportError: except ImportError:
from typing_extensions import Self from typing_extensions import Self
if TYPE_CHECKING:
from .application import Application
T = typing.TypeVar('T')
ResponseType = typing.TypedDict('ResponseType', { T = TypeVar('T')
ResponseType = TypedDict('ResponseType', {
'status': int, 'status': int,
'headers': dict[str, typing.Any] | None, 'headers': dict[str, Any] | None,
'content_type': str, 'content_type': str,
'body': bytes | None, 'body': bytes | None,
'text': str | None 'text': str | None
@ -128,7 +128,7 @@ class JsonEncoder(json.JSONEncoder):
if isinstance(o, datetime): if isinstance(o, datetime):
return o.isoformat() return o.isoformat()
return json.JSONEncoder.default(self, o) return json.JSONEncoder.default(self, o) # type: ignore[no-any-return]
class Message(aputils.Message): class Message(aputils.Message):
@ -148,6 +148,7 @@ class Message(aputils.Message):
'followers': f'https://{host}/followers', 'followers': f'https://{host}/followers',
'following': f'https://{host}/following', 'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox', 'inbox': f'https://{host}/inbox',
'outbox': f'https://{host}/outbox',
'url': f'https://{host}/', 'url': f'https://{host}/',
'endpoints': { 'endpoints': {
'sharedInbox': f'https://{host}/inbox' 'sharedInbox': f'https://{host}/inbox'
@ -213,7 +214,7 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new(cls: type[Self], def new(cls: type[Self],
body: str | bytes | dict | tuple | list | set = '', body: str | bytes | dict[str, Any] | Sequence[Any] = '',
status: int = 200, status: int = 200,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
ctype: str = 'text') -> Self: ctype: str = 'text') -> Self:
@ -226,22 +227,22 @@ class Response(AiohttpResponse):
'text': None 'text': None
} }
if isinstance(body, bytes): if isinstance(body, str):
kwargs['text'] = body
elif isinstance(body, bytes):
kwargs['body'] = body kwargs['body'] = body
elif isinstance(body, (dict, list, tuple, set)) or ctype in {'json', 'activity'}: elif isinstance(body, (dict, Sequence)):
kwargs['text'] = json.dumps(body, cls = JsonEncoder) kwargs['text'] = json.dumps(body, cls = JsonEncoder)
else:
kwargs['text'] = body
return cls(**kwargs) return cls(**kwargs)
@classmethod @classmethod
def new_error(cls: type[Self], def new_error(cls: type[Self],
status: int, status: int,
body: str | bytes | dict, body: str | bytes | dict[str, Any],
ctype: str = 'text') -> Self: ctype: str = 'text') -> Self:
if ctype == 'json': if ctype == 'json':

View file

@ -10,14 +10,12 @@ if typing.TYPE_CHECKING:
from .views.activitypub import ActorView from .views.activitypub import ActorView
def person_check(actor: Message, software: str | None) -> bool: def actor_type_check(actor: Message, software: str | None) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason if actor.type == 'Application':
# akkoma changed this in 3.6.0 return True
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False
# make sure the actor is an application # akkoma (< 3.6.0) and pleroma use Person for the actor type
if actor.type != 'Application': if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return True return True
return False return False
@ -54,7 +52,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
for instance in conn.distill_inboxes(view.message): for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], await view.request.read(), instance) view.app.push_message(instance["inbox"], view.message, instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str') view.cache.set('handle-relay', view.message.id, message.id, 'str')
@ -88,7 +86,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
return return
# reject if the actor is not an instance actor # reject if the actor is not an instance actor
if person_check(view.actor, software): if actor_type_check(view.actor, software):
logging.verbose('Non-application actor tried to follow: %s', view.actor.id) logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
view.app.push_message( view.app.push_message(

View file

@ -1,25 +1,22 @@
from __future__ import annotations from __future__ import annotations
import textwrap import textwrap
import typing
from collections.abc import Callable from collections.abc import Callable
from hamlish_jinja import HamlishExtension from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension from jinja2.ext import Extension
from jinja2.nodes import CallBlock from jinja2.nodes import CallBlock, Node
from jinja2.parser import Parser
from markdown import Markdown from markdown import Markdown
from typing import TYPE_CHECKING, Any
from . import __version__ from . import __version__
from .misc import get_resource from .misc import get_resource
from .views.base import View
if typing.TYPE_CHECKING: if TYPE_CHECKING:
from jinja2.nodes import Node
from jinja2.parser import Parser
from typing import Any
from .application import Application from .application import Application
from .views.base import View
class Template(Environment): class Template(Environment):

View file

@ -125,6 +125,39 @@ class ActorView(View):
return None return None
@register_route('/outbox')
class OutboxView(View):
async def get(self, request: Request) -> Response:
msg = aputils.Message.new(
aputils.ObjectType.ORDERED_COLLECTION,
{
"id": f'https://{self.config.domain}/outbox',
"totalItems": 0,
"orderedItems": []
}
)
return Response.new(msg, ctype = 'activity')
@register_route('/following', '/followers')
class RelationshipView(View):
async def get(self, request: Request) -> Response:
with self.database.session(False) as s:
inboxes = [row['actor'] for row in s.get_inboxes()]
msg = aputils.Message.new(
aputils.ObjectType.COLLECTION,
{
"id": f'https://{self.config.domain}{request.path}',
"totalItems": len(inboxes),
"items": inboxes
}
)
return Response.new(msg, ctype = 'activity')
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
class WebfingerView(View): class WebfingerView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:

View file

@ -1,9 +1,7 @@
from __future__ import annotations from aiohttp.web import Request, middleware
import typing
from aiohttp import web
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
@ -12,11 +10,6 @@ from .. import __version__
from ..database import ConfigData from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app from ..misc import Message, Response, boolean, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable, Sequence
from typing import Any
ALLOWED_HEADERS = { ALLOWED_HEADERS = {
'accept', 'accept',
@ -26,7 +19,6 @@ ALLOWED_HEADERS = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'), ('GET', '/api/v1/relay'),
('GET', '/api/v1/instance'),
('POST', '/api/v1/token') ('POST', '/api/v1/token')
) )
@ -38,8 +30,10 @@ def check_api_path(method: str, path: str) -> bool:
return path.startswith('/api') return path.startswith('/api')
@web.middleware @middleware
async def handle_api_path(request: Request, handler: Callable) -> Response: async def handle_api_path(
request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response:
try: try:
if (token := request.cookies.get('user-token')): if (token := request.cookies.get('user-token')):
request['token'] = token request['token'] = token
@ -209,6 +203,8 @@ class Inbox(View):
if conn.get_inbox(data['domain']): if conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance already in database', 'json') return Response.new_error(404, 'Instance already in database', 'json')
data['domain'] = data['domain'].encode('idna').decode()
if not data.get('inbox'): if not data.get('inbox'):
actor_data: Message | None = await self.client.get(data['actor'], True, Message) actor_data: Message | None = await self.client.get(data['actor'], True, Message)
@ -223,7 +219,7 @@ class Inbox(View):
if nodeinfo is not None: if nodeinfo is not None:
data['software'] = nodeinfo.sw_name data['software'] = nodeinfo.sw_name
row = conn.put_inbox(**data) row = conn.put_inbox(**data) # type: ignore[arg-type]
return Response.new(row, ctype = 'json') return Response.new(row, ctype = 'json')
@ -235,10 +231,12 @@ class Inbox(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
if not (instance := conn.get_inbox(data['domain'])): if not (instance := conn.get_inbox(data['domain'])):
return Response.new_error(404, 'Instance with domain not found', 'json') return Response.new_error(404, 'Instance with domain not found', 'json')
instance = conn.put_inbox(instance['domain'], **data) instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type]
return Response.new(instance, ctype = 'json') return Response.new(instance, ctype = 'json')
@ -250,6 +248,8 @@ class Inbox(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']): if not conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance with domain not found', 'json') return Response.new_error(404, 'Instance with domain not found', 'json')
@ -269,7 +269,12 @@ class RequestView(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], []) data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
if isinstance(data, Response):
return data
data['accept'] = boolean(data['accept']) data['accept'] = boolean(data['accept'])
data['domain'] = data['domain'].encode('idna').decode()
try: try:
with self.database.session(True) as conn: with self.database.session(True) as conn:
@ -314,6 +319,8 @@ class DomainBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_ban(data['domain']): if conn.get_domain_ban(data['domain']):
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
@ -330,6 +337,8 @@ class DomainBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_domain_ban(data['domain']): if not conn.get_domain_ban(data['domain']):
return Response.new_error(404, 'Domain not banned', 'json') return Response.new_error(404, 'Domain not banned', 'json')
@ -348,6 +357,8 @@ class DomainBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_domain_ban(data['domain']): if not conn.get_domain_ban(data['domain']):
return Response.new_error(404, 'Domain not banned', 'json') return Response.new_error(404, 'Domain not banned', 'json')
@ -485,6 +496,8 @@ class Whitelist(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(data['domain']): if conn.get_domain_whitelist(data['domain']):
return Response.new_error(400, 'Domain already added to whitelist', 'json') return Response.new_error(400, 'Domain already added to whitelist', 'json')
@ -500,6 +513,8 @@ class Whitelist(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if not conn.get_domain_whitelist(data['domain']): if not conn.get_domain_whitelist(data['domain']):
return Response.new_error(404, 'Domain not in whitelist', 'json') return Response.new_error(404, 'Domain not in whitelist', 'json')

View file

@ -1,34 +1,34 @@
from __future__ import annotations from __future__ import annotations
import typing
from Crypto.Random import get_random_bytes from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import HTTPMethodNotAllowed from aiohttp.web import HTTPMethodNotAllowed, Request
from base64 import b64encode from base64 import b64encode
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property from functools import cached_property
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any
from ..cache import Cache
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
from ..misc import Response, get_app from ..misc import Response, get_app
if typing.TYPE_CHECKING: if TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable, Generator, Sequence, Mapping
from bsql import Database
from typing import Any
from ..application import Application from ..application import Application
from ..cache import Cache
from ..config import Config
from ..http_client import HttpClient
from ..template import Template from ..template import Template
try: try:
from typing import Self from typing import Self
except ImportError: except ImportError:
from typing_extensions import Self from typing_extensions import Self
HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = [] VIEWS: list[tuple[str, type[View]]] = []
@ -37,7 +37,7 @@ def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
return {key: str(value) for key, value in data.items()} return {key: str(value) for key, value in data.items()}
def register_route(*paths: str) -> Callable: def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
def wrapper(view: type[View]) -> type[View]: def wrapper(view: type[View]) -> type[View]:
for path in paths: for path in paths:
VIEWS.append((path, view)) VIEWS.append((path, view))
@ -63,7 +63,7 @@ class View(AbstractView):
return await view.handlers[method](request, **kwargs) return await view.handlers[method](request, **kwargs)
async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response: async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs) return await handler(self.request, **self.request.match_info, **kwargs)
@ -78,7 +78,7 @@ class View(AbstractView):
@cached_property @cached_property
def handlers(self) -> dict[str, Callable[..., Any]]: def handlers(self) -> dict[str, HandlerCallback]:
data = {} data = {}
for method in METHODS: for method in METHODS:
@ -112,13 +112,13 @@ class View(AbstractView):
@property @property
def database(self) -> Database: def database(self) -> Database[Connection]:
return self.app.database return self.app.database
@property @property
def template(self) -> Template: def template(self) -> Template:
return self.app['template'] return self.app['template'] # type: ignore[no-any-return]
async def get_api_data(self, async def get_api_data(self,

View file

@ -1,8 +1,6 @@
from __future__ import annotations
import typing
from aiohttp import web from aiohttp import web
from collections.abc import Awaitable, Callable
from typing import Any
from .base import View, register_route from .base import View, register_route
@ -10,11 +8,6 @@ from ..database import THEMES
from ..logger import LogLevel from ..logger import LogLevel
from ..misc import Response, get_app from ..misc import Response, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable
from typing import Any
UNAUTH_ROUTES = { UNAUTH_ROUTES = {
'/', '/',
@ -23,7 +16,10 @@ UNAUTH_ROUTES = {
@web.middleware @web.middleware
async def handle_frontend_path(request: web.Request, handler: Callable) -> Response: async def handle_frontend_path(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
app = get_app() app = get_app()
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
@ -52,7 +48,7 @@ async def handle_frontend_path(request: web.Request, handler: Callable) -> Respo
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: web.Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
context: dict[str, Any] = { context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes()) 'instances': tuple(conn.get_inboxes())
@ -64,14 +60,14 @@ class HomeView(View):
@register_route('/login') @register_route('/login')
class Login(View): class Login(View):
async def get(self, request: Request) -> Response: async def get(self, request: web.Request) -> Response:
data = self.template.render('page/login.haml', self) data = self.template.render('page/login.haml', self)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@register_route('/logout') @register_route('/logout')
class Logout(View): class Logout(View):
async def get(self, request: Request) -> Response: async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn: with self.database.session(True) as conn:
conn.del_token(request['token']) conn.del_token(request['token'])
@ -82,14 +78,14 @@ class Logout(View):
@register_route('/admin') @register_route('/admin')
class Admin(View): class Admin(View):
async def get(self, request: Request) -> Response: async def get(self, request: web.Request) -> Response:
return Response.new('', 302, {'Location': '/admin/instances'}) return Response.new('', 302, {'Location': '/admin/instances'})
@register_route('/admin/instances') @register_route('/admin/instances')
class AdminInstances(View): class AdminInstances(View):
async def get(self, async def get(self,
request: Request, request: web.Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -112,7 +108,7 @@ class AdminInstances(View):
@register_route('/admin/whitelist') @register_route('/admin/whitelist')
class AdminWhitelist(View): class AdminWhitelist(View):
async def get(self, async def get(self,
request: Request, request: web.Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -134,7 +130,7 @@ class AdminWhitelist(View):
@register_route('/admin/domain_bans') @register_route('/admin/domain_bans')
class AdminDomainBans(View): class AdminDomainBans(View):
async def get(self, async def get(self,
request: Request, request: web.Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -156,7 +152,7 @@ class AdminDomainBans(View):
@register_route('/admin/software_bans') @register_route('/admin/software_bans')
class AdminSoftwareBans(View): class AdminSoftwareBans(View):
async def get(self, async def get(self,
request: Request, request: web.Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -178,7 +174,7 @@ class AdminSoftwareBans(View):
@register_route('/admin/users') @register_route('/admin/users')
class AdminUsers(View): class AdminUsers(View):
async def get(self, async def get(self,
request: Request, request: web.Request,
error: str | None = None, error: str | None = None,
message: str | None = None) -> Response: message: str | None = None) -> Response:
@ -199,7 +195,7 @@ class AdminUsers(View):
@register_route('/admin/config') @register_route('/admin/config')
class AdminConfig(View): class AdminConfig(View):
async def get(self, request: Request, message: str | None = None) -> Response: async def get(self, request: web.Request, message: str | None = None) -> Response:
context: dict[str, Any] = { context: dict[str, Any] = {
'themes': tuple(THEMES.keys()), 'themes': tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel), 'levels': tuple(level.name for level in LogLevel),
@ -212,7 +208,7 @@ class AdminConfig(View):
@register_route('/manifest.json') @register_route('/manifest.json')
class ManifestJson(View): class ManifestJson(View):
async def get(self, request: Request) -> Response: async def get(self, request: web.Request) -> Response:
with self.database.session(False) as conn: with self.database.session(False) as conn:
config = conn.get_config_all() config = conn.get_config_all()
theme = THEMES[config.theme] theme = THEMES[config.theme]
@ -235,7 +231,7 @@ class ManifestJson(View):
@register_route('/theme/{theme}.css') @register_route('/theme/{theme}.css')
class ThemeCss(View): class ThemeCss(View):
async def get(self, request: Request, theme: str) -> Response: async def get(self, request: web.Request, theme: str) -> Response:
try: try:
context: dict[str, Any] = { context: dict[str, Any] = {
'theme': THEMES[theme] 'theme': THEMES[theme]

View file

@ -1,9 +1,7 @@
from __future__ import annotations
import aputils import aputils
import subprocess import subprocess
import typing
from aiohttp.web import Request
from pathlib import Path from pathlib import Path
from .base import View, register_route from .base import View, register_route
@ -11,9 +9,6 @@ from .base import View, register_route
from .. import __version__ from .. import __version__
from ..misc import Response from ..misc import Response
if typing.TYPE_CHECKING:
from aiohttp.web import Request
VERSION = __version__ VERSION = __version__