Compare commits

...

14 commits

Author      SHA1        Message                                                  Date
Izalia Mae  be21b30d61  version bump to 0.3.4                                    2025-02-11 13:20:21 -05:00
Izalia Mae  e4eadb69db  fix linter warnings in dev script                        2025-02-11 13:01:48 -05:00
Izalia Mae  f741017ef1  skip missing optional values in old config              2025-02-11 12:55:02 -05:00
Izalia Mae  760718b5dc  fix actor type checking (fixes #45)                      2025-02-11 12:39:18 -05:00
Izalia Mae  527afaca95  properly get actor and inbox urls                        2025-02-11 12:28:05 -05:00
                        * get actor url from incoming message
                        * fall back to actor inbox if there is no shared inbox
Izalia Mae  ff275a5ba4  fix linter warnings                                      2025-02-11 12:26:57 -05:00
Izalia Mae  3e6a2a4f37  fix logins                                               2025-02-11 12:25:16 -05:00
Izalia Mae  5a1af750d5  fix spacing issue on home page                           2025-02-11 12:15:21 -05:00
Izalia Mae  ade36d6e69  add buzzrelay to RELAY_SOFTWARE list                     2025-01-29 06:11:25 -05:00
Izalia Mae  e47e69b14d  add docs dependencies                                    2024-12-05 22:56:17 -05:00
Izalia Mae  f41b24406f  remove unnecessary SOFTWARE const                        2024-11-28 09:17:11 -05:00
Izalia Mae  338fd26688  replace single quotes with double quotes                 2024-11-28 07:09:53 -05:00
Izalia Mae  29ebba7999  add more classiifiers and py.typed file                  2024-11-28 06:10:35 -05:00
Izalia Mae  5131831363  sort out pyproject and replace jinja2-haml with hamlish  2024-11-28 05:43:39 -05:00
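Commit 527afaca95 is the only one above that explains its logic in prose: take the actor URL from the incoming message and deliver to the instance's shared inbox, falling back to the actor's own inbox when no shared inbox is advertised. A minimal sketch of that fallback, assuming a plain ActivityPub actor document (the endpoints.sharedInbox and inbox fields are standard ActivityPub vocabulary, not necessarily the relay's own types):

# Sketch only: choose a delivery inbox from an ActivityPub actor document.
# The relay's actual helper may be structured differently.
def resolve_inbox(actor: dict[str, object]) -> str:
    endpoints = actor.get("endpoints")
    # Prefer the shared inbox when the remote server advertises one
    if isinstance(endpoints, dict) and endpoints.get("sharedInbox"):
        return str(endpoints["sharedInbox"])
    # Otherwise fall back to the actor's personal inbox
    return str(actor["inbox"])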
27 changed files with 1215 additions and 1189 deletions
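One pattern worth noting before the file diffs: the rewritten Config.load further down passes a NOVALUE sentinel as the default for optional lookups, so keys absent from an older config file are skipped rather than overwriting the dataclass defaults with None. A rough sketch of that sentinel idiom, using names taken from the diff purely for illustration (not a guaranteed API):

# Illustrative sketch of the NOVALUE sentinel idiom seen in the Config.load hunks below.
class NOVALUE:
    """Marker meaning 'this key was not present in the YAML file'."""

def set_option(target: dict[str, object], key: str, value: object) -> None:
    # Skip missing optional values so the existing default stays in place
    if value is NOVALUE:
        return
    target[key] = value

# Example: set_option(cfg, "pg_port", yaml_section.get("port", NOVALUE))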

dev.py

@ -37,32 +37,32 @@ from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
REPO = Path(__file__).parent REPO = Path(__file__).parent
IGNORE_EXT = { IGNORE_EXT = {
'.py', ".py",
'.pyc' ".pyc"
} }
@click.group('cli') @click.group("cli")
def cli() -> None: def cli() -> None:
'Useful commands for development' "Useful commands for development"
@cli.command('install') @cli.command("install")
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies') @click.option("--no-dev", "-d", is_flag = True, help = "Do not install development dependencies")
def cli_install(no_dev: bool) -> None: def cli_install(no_dev: bool) -> None:
with open('pyproject.toml', 'r', encoding = 'utf-8') as fd: with open("pyproject.toml", "r", encoding = "utf-8") as fd:
data = tomllib.loads(fd.read()) data = tomllib.loads(fd.read())
deps = data['project']['dependencies'] deps = data["project"]["dependencies"]
deps.extend(data['project']['optional-dependencies']['dev']) deps.extend(data["project"]["optional-dependencies"]["dev"])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False) subprocess.run([sys.executable, "-m", "pip", "install", "-U", *deps], check = False)
@cli.command('lint') @cli.command("lint")
@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay')) @click.argument("path", required = False, type = Path, default = REPO.joinpath("relay"))
@click.option('--watch', '-w', is_flag = True, @click.option("--watch", "-w", is_flag = True,
help = 'Automatically, re-run the linters on source change') help = "Automatically, re-run the linters on source change")
def cli_lint(path: Path, watch: bool) -> None: def cli_lint(path: Path, watch: bool) -> None:
path = path.expanduser().resolve() path = path.expanduser().resolve()
@ -70,74 +70,73 @@ def cli_lint(path: Path, watch: bool) -> None:
handle_run_watcher([sys.executable, "dev.py", "lint", str(path)], wait = True) handle_run_watcher([sys.executable, "dev.py", "lint", str(path)], wait = True)
return return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] flake8 = [sys.executable, "-m", "flake8", "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)] mypy = [sys.executable, "-m", "mypy", "--python-version", "3.12", "dev.py", str(path)]
click.echo('----- flake8 -----') click.echo("----- flake8 -----")
subprocess.run(flake8) subprocess.run(flake8)
click.echo('\n\n----- mypy -----') click.echo("\n\n----- mypy -----")
subprocess.run(mypy) subprocess.run(mypy)
@cli.command('clean') @cli.command("clean")
def cli_clean() -> None: def cli_clean() -> None:
dirs = { dirs = {
'dist', "dist",
'build', "build",
'dist-pypi' "dist-pypi"
} }
for directory in dirs: for directory in dirs:
shutil.rmtree(directory, ignore_errors = True) shutil.rmtree(directory, ignore_errors = True)
for path in REPO.glob('*.egg-info'): for path in REPO.glob("*.egg-info"):
shutil.rmtree(path) shutil.rmtree(path)
for path in REPO.glob('*.spec'): for path in REPO.glob("*.spec"):
path.unlink() path.unlink()
@cli.command('build') @cli.command("build")
def cli_build() -> None: def cli_build() -> None:
from relay import __version__ from relay import __version__
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' arch = "amd64" if sys.maxsize >= 2**32 else "i386"
cmd = [ cmd = [
sys.executable, '-m', 'PyInstaller', sys.executable, "-m", "PyInstaller",
'--collect-data', 'relay', "--collect-data", "relay",
'--collect-data', 'aiohttp_swagger', "--hidden-import", "pg8000",
'--hidden-import', 'pg8000', "--hidden-import", "sqlite3",
'--hidden-import', 'sqlite3', "--name", f"activityrelay-{__version__}-{platform.system().lower()}-{arch}",
'--name', f'activityrelay-{__version__}-{platform.system().lower()}-{arch}', "--workpath", tmp,
'--workpath', tmp, "--onefile", "relay/__main__.py",
'--onefile', 'relay/__main__.py',
] ]
if platform.system() == 'Windows': if platform.system() == "Windows":
cmd.append('--console') cmd.append("--console")
# putting the spec path on a different drive than the source dir breaks # putting the spec path on a different drive than the source dir breaks
if str(REPO)[0] == tmp[0]: if str(REPO)[0] == tmp[0]:
cmd.extend(['--specpath', tmp]) cmd.extend(["--specpath", tmp])
else: else:
cmd.append('--strip') cmd.append("--strip")
cmd.extend(['--specpath', tmp]) cmd.extend(["--specpath", tmp])
subprocess.run(cmd, check = False) subprocess.run(cmd, check = False)
@cli.command('run') @cli.command("run")
@click.option('--dev', '-d', is_flag = True) @click.option("--dev", "-d", is_flag = True)
def cli_run(dev: bool) -> None: def cli_run(dev: bool) -> None:
print('Starting process watcher') print("Starting process watcher")
cmd = [sys.executable, '-m', 'relay', 'run'] cmd = [sys.executable, "-m", "relay", "run"]
if dev: if dev:
cmd.append('-d') cmd.append("-d")
handle_run_watcher(cmd, watch_path = REPO.joinpath("relay")) handle_run_watcher(cmd, watch_path = REPO.joinpath("relay"))
@ -151,8 +150,8 @@ def handle_run_watcher(
handler.run_procs() handler.run_procs()
watcher = Observer() watcher = Observer()
watcher.schedule(handler, str(watch_path), recursive=True) # type: ignore watcher.schedule(handler, str(watch_path), recursive=True)
watcher.start() # type: ignore watcher.start()
try: try:
while True: while True:
@ -162,16 +161,16 @@ def handle_run_watcher(
pass pass
handler.kill_procs() handler.kill_procs()
watcher.stop() # type: ignore watcher.stop()
watcher.join() watcher.join()
class WatchHandler(PatternMatchingEventHandler): class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py'] patterns = ["*.py"]
def __init__(self, *commands: Sequence[str], wait: bool = False) -> None: def __init__(self, *commands: Sequence[str], wait: bool = False) -> None:
PatternMatchingEventHandler.__init__(self) # type: ignore PatternMatchingEventHandler.__init__(self)
self.commands: Sequence[Sequence[str]] = commands self.commands: Sequence[Sequence[str]] = commands
self.wait: bool = wait self.wait: bool = wait
@ -184,7 +183,7 @@ class WatchHandler(PatternMatchingEventHandler):
if proc.poll() is not None: if proc.poll() is not None:
continue continue
print(f'Terminating process {proc.pid}') print(f"Terminating process {proc.pid}")
proc.terminate() proc.terminate()
sec = 0.0 sec = 0.0
@ -193,11 +192,11 @@ class WatchHandler(PatternMatchingEventHandler):
sec += 0.1 sec += 0.1
if sec >= 5: if sec >= 5:
print('Failed to terminate. Killing process...') print("Failed to terminate. Killing process...")
proc.kill() proc.kill()
break break
print('Process terminated') print("Process terminated")
def run_procs(self, restart: bool = False) -> None: def run_procs(self, restart: bool = False) -> None:
@ -213,21 +212,21 @@ class WatchHandler(PatternMatchingEventHandler):
self.procs = [] self.procs = []
for cmd in self.commands: for cmd in self.commands:
print('Running command:', ' '.join(cmd)) print("Running command:", " ".join(cmd))
subprocess.run(cmd) subprocess.run(cmd)
else: else:
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands) self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
pids = (str(proc.pid) for proc in self.procs) pids = (str(proc.pid) for proc in self.procs)
print('Started processes with PIDs:', ', '.join(pids)) print("Started processes with PIDs:", ", ".join(pids))
def on_any_event(self, event: FileSystemEvent) -> None: def on_any_event(self, event: FileSystemEvent) -> None:
if event.event_type not in ['modified', 'created', 'deleted']: if event.event_type not in ["modified", "created", "deleted"]:
return return
self.run_procs(restart = True) self.run_procs(restart = True)
if __name__ == '__main__': if __name__ == "__main__":
cli() cli()


@ -1,37 +1,50 @@
[build-system] [build-system]
requires = ["setuptools>=61.2"] requires = [
"setuptools>=61.2",
]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
name = "ActivityRelay" name = "ActivityRelay"
description = "Generic LitePub relay (works with all LitePub consumers and Mastodon)" description = "Generic LitePub relay (works with all LitePub consumers and Mastodon)"
license = {text = "AGPLv3"}
classifiers = [ classifiers = [
"Environment :: Console", "Development Status :: 4 - Beta",
"License :: OSI Approved :: GNU Affero General Public License v3", "Environment :: Console",
"Programming Language :: Python :: 3.10", "Framework :: aiohttp",
"Programming Language :: Python :: 3.11", "Framework :: AsyncIO",
"Programming Language :: Python :: 3.12" "License :: OSI Approved :: GNU Affero General Public License v3",
"Programming Language :: Python",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: SQL",
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
"Typing :: Typed",
] ]
dependencies = [ dependencies = [
"activitypub-utils >= 0.3.2, < 0.4", "activitypub-utils >= 0.3.2, < 0.4",
"aiohttp >= 3.9.5", "aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.2.3, < 0.3.0", "barkshark-lib >= 0.2.3, < 0.3.0",
"barkshark-sql >= 0.2.0, < 0.3.0", "barkshark-sql >= 0.2.0, < 0.3.0",
"click == 8.1.2", "click == 8.1.2",
"docstring-parser == 0.16", "docstring-parser == 0.16",
"hamlish == 0.4.0",
"hiredis == 2.3.2", "hiredis == 2.3.2",
"idna == 3.4", "idna == 3.4",
"jinja2-haml == 0.3.5",
"markdown == 3.6", "markdown == 3.6",
"platformdirs == 4.2.2", "platformdirs == 4.2.2",
"pyyaml == 6.0.1", "pyyaml == 6.0.1",
"redis == 5.0.7" "redis == 5.0.7",
] ]
requires-python = ">=3.10" requires-python = ">=3.10"
dynamic = ["version"] dynamic = [
"version",
]
[project.license]
file = "LICENSE"
[project.readme] [project.readme]
file = "README.md" file = "README.md"
@ -40,42 +53,46 @@ content-type = "text/markdown; charset=UTF-8"
[project.urls] [project.urls]
Documentation = "https://git.pleroma.social/pleroma/relay/-/blob/main/docs/index.md" Documentation = "https://git.pleroma.social/pleroma/relay/-/blob/main/docs/index.md"
Source = "https://git.pleroma.social/pleroma/relay" Source = "https://git.pleroma.social/pleroma/relay"
Tracker = "https://git.pleroma.social/pleroma/relay/-/issues"
[project.scripts] [project.scripts]
activityrelay = "relay.manage:main" activityrelay = "relay.manage:main"
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"flake8 == 7.1.0", "build == 1.2.2.post1",
"mypy == 1.11.1", "flake8 == 7.1.1",
"mypy == 1.13.0",
"pyinstaller == 6.10.0", "pyinstaller == 6.10.0",
"watchdog == 4.0.2" ]
docs = [
"furo == 2024.1.29",
"sphinx == 7.2.6",
"sphinx-external-toc == 1.0.1",
] ]
[tool.setuptools] [tool.setuptools]
zip-safe = false zip-safe = false
packages = [ packages = [
"relay", "relay",
"relay.database", "relay.database",
"relay.views", "relay.views",
] ]
include-package-data = true include-package-data = true
license-files = ["LICENSE"] license-files = [
"LICENSE",
]
[tool.setuptools.package-data] [tool.setuptools.package-data]
relay = [ relay = [
"data/*", "py.typed",
"frontend/*", "data/*",
"frontend/page/*", "frontend/*",
"frontend/static/*" "frontend/page/*",
"frontend/static/*",
] ]
[tool.setuptools.dynamic] [tool.setuptools.dynamic.version]
version = {attr = "relay.__version__"} attr = "relay.__version__"
[tool.setuptools.dynamic.optional-dependencies]
dev = {file = ["dev-requirements.txt"]}
[tool.mypy] [tool.mypy]
show_traceback = true show_traceback = true
@ -89,15 +106,3 @@ ignore_missing_imports = true
implicit_reexport = true implicit_reexport = true
strict = true strict = true
follow_imports = "silent" follow_imports = "silent"
[[tool.mypy.overrides]]
module = "relay.database"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "aputils"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "blib"
implicit_reexport = true


@ -1 +1 @@
__version__ = '0.3.3' __version__ = "0.3.4"


@ -1,8 +1,5 @@
import multiprocessing
from relay.manage import main from relay.manage import main
if __name__ == '__main__': if __name__ == "__main__":
multiprocessing.freeze_support()
main() main()


@ -41,10 +41,10 @@ def get_csp(request: web.Request) -> str:
"img-src 'self'", "img-src 'self'",
"object-src 'none'", "object-src 'none'",
"frame-ancestors 'none'", "frame-ancestors 'none'",
f"manifest-src 'self' https://{request.app['config'].domain}" f"manifest-src 'self' https://{request.app["config"].domain}"
] ]
return '; '.join(data) + ';' return "; ".join(data) + ";"
class Application(web.Application): class Application(web.Application):
@ -61,20 +61,20 @@ class Application(web.Application):
Application.DEFAULT = self Application.DEFAULT = self
self['running'] = False self["running"] = False
self['signer'] = None self["signer"] = None
self['start_time'] = None self["start_time"] = None
self['cleanup_thread'] = None self["cleanup_thread"] = None
self['dev'] = dev self["dev"] = dev
self['config'] = Config(cfgpath, load = True) self["config"] = Config(cfgpath, load = True)
self['database'] = get_database(self.config) self["database"] = get_database(self.config)
self['client'] = HttpClient() self["client"] = HttpClient()
self['cache'] = get_cache(self) self["cache"] = get_cache(self)
self['cache'].setup() self["cache"].setup()
self['template'] = Template(self) self["template"] = Template(self)
self['push_queue'] = multiprocessing.Queue() self["push_queue"] = multiprocessing.Queue()
self['workers'] = PushWorkers(self.config.workers) self["workers"] = PushWorkers(self.config.workers)
self.cache.setup() self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore self.on_cleanup.append(handle_cleanup) # type: ignore
@ -82,69 +82,69 @@ class Application(web.Application):
@property @property
def cache(self) -> Cache: def cache(self) -> Cache:
return cast(Cache, self['cache']) return cast(Cache, self["cache"])
@property @property
def client(self) -> HttpClient: def client(self) -> HttpClient:
return cast(HttpClient, self['client']) return cast(HttpClient, self["client"])
@property @property
def config(self) -> Config: def config(self) -> Config:
return cast(Config, self['config']) return cast(Config, self["config"])
@property @property
def database(self) -> Database[Connection]: def database(self) -> Database[Connection]:
return cast(Database[Connection], self['database']) return cast(Database[Connection], self["database"])
@property @property
def signer(self) -> Signer: def signer(self) -> Signer:
return cast(Signer, self['signer']) return cast(Signer, self["signer"])
@signer.setter @signer.setter
def signer(self, value: Signer | str) -> None: def signer(self, value: Signer | str) -> None:
if isinstance(value, Signer): if isinstance(value, Signer):
self['signer'] = value self["signer"] = value
return return
self['signer'] = Signer(value, self.config.keyid) self["signer"] = Signer(value, self.config.keyid)
@property @property
def template(self) -> Template: def template(self) -> Template:
return cast(Template, self['template']) return cast(Template, self["template"])
@property @property
def uptime(self) -> timedelta: def uptime(self) -> timedelta:
if not self['start_time']: if not self["start_time"]:
return timedelta(seconds=0) return timedelta(seconds=0)
uptime = datetime.now() - self['start_time'] uptime = datetime.now() - self["start_time"]
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
@property @property
def workers(self) -> PushWorkers: def workers(self) -> PushWorkers:
return cast(PushWorkers, self['workers']) return cast(PushWorkers, self["workers"])
def push_message(self, inbox: str, message: Message, instance: Instance) -> None: def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self['workers'].push_message(inbox, message, instance) self["workers"].push_message(inbox, message, instance)
def register_static_routes(self) -> None: def register_static_routes(self) -> None:
if self['dev']: if self["dev"]:
static = StaticResource('/static', File.from_resource('relay', 'frontend/static')) static = StaticResource("/static", File.from_resource("relay", "frontend/static"))
else: else:
static = CachedStaticResource( static = CachedStaticResource(
'/static', Path(File.from_resource('relay', 'frontend/static')) "/static", Path(File.from_resource("relay", "frontend/static"))
) )
self.router.register_resource(static) self.router.register_resource(static)
@ -158,18 +158,18 @@ class Application(web.Application):
host = self.config.listen host = self.config.listen
port = self.config.port port = self.config.port
if port_check(port, '127.0.0.1' if host == '0.0.0.0' else host): if port_check(port, "127.0.0.1" if host == "0.0.0.0" else host):
logging.error(f'A server is already running on {host}:{port}') logging.error(f"A server is already running on {host}:{port}")
return return
self.register_static_routes() self.register_static_routes()
logging.info(f'Starting webserver at {domain} ({host}:{port})') logging.info(f"Starting webserver at {domain} ({host}:{port})")
asyncio.run(self.handle_run()) asyncio.run(self.handle_run())
def set_signal_handler(self, startup: bool) -> None: def set_signal_handler(self, startup: bool) -> None:
for sig in ('SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'): for sig in ("SIGHUP", "SIGINT", "SIGQUIT", "SIGTERM"):
try: try:
signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL) signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL)
@ -179,22 +179,25 @@ class Application(web.Application):
def stop(self, *_: Any) -> None: def stop(self, *_: Any) -> None:
self['running'] = False self["running"] = False
async def handle_run(self) -> None: async def handle_run(self) -> None:
self['running'] = True self["running"] = True
self.set_signal_handler(True) self.set_signal_handler(True)
self['client'].open() self["client"].open()
self['database'].connect() self["database"].connect()
self['cache'].setup() self["cache"].setup()
self['cleanup_thread'] = CacheCleanupThread(self) self["cleanup_thread"] = CacheCleanupThread(self)
self['cleanup_thread'].start() self["cleanup_thread"].start()
self['workers'].start() self["workers"].start()
runner = web.AppRunner(
self, access_log_format = "%{X-Forwarded-For}i \"%r\" %s %b \"%{User-Agent}i\""
)
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(
@ -205,22 +208,22 @@ class Application(web.Application):
) )
await site.start() await site.start()
self['starttime'] = datetime.now() self["starttime"] = datetime.now()
while self['running']: while self["running"]:
await asyncio.sleep(0.25) await asyncio.sleep(0.25)
await site.stop() await site.stop()
self['workers'].stop() self["workers"].stop()
self.set_signal_handler(False) self.set_signal_handler(False)
self['starttime'] = None self["starttime"] = None
self['running'] = False self["running"] = False
self['cleanup_thread'].stop() self["cleanup_thread"].stop()
self['database'].disconnect() self["database"].disconnect()
self['cache'].close() self["cache"].close()
class CachedStaticResource(StaticResource): class CachedStaticResource(StaticResource):
@ -229,19 +232,19 @@ class CachedStaticResource(StaticResource):
self.cache: dict[str, bytes] = {} self.cache: dict[str, bytes] = {}
for filename in path.rglob('*'): for filename in path.rglob("*"):
if filename.is_dir(): if filename.is_dir():
continue continue
rel_path = str(filename.relative_to(path)) rel_path = str(filename.relative_to(path))
with filename.open('rb') as fd: with filename.open("rb") as fd:
logging.debug('Loading static resource "%s"', rel_path) logging.debug("Loading static resource \"%s\"", rel_path)
self.cache[rel_path] = fd.read() self.cache[rel_path] = fd.read()
async def _handle(self, request: web.Request) -> web.StreamResponse: async def _handle(self, request: web.Request) -> web.StreamResponse:
rel_url = request.match_info['filename'] rel_url = request.match_info["filename"]
if Path(rel_url).anchor: if Path(rel_url).anchor:
raise web.HTTPForbidden() raise web.HTTPForbidden()
@ -281,8 +284,8 @@ class CacheCleanupThread(Thread):
def format_error(request: web.Request, error: HttpError) -> Response: def format_error(request: web.Request, error: HttpError) -> Response:
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''): if request.path.startswith(JSON_PATHS) or "json" in request.headers.get("accept", ""):
return Response.new({'error': error.message}, error.status, ctype = 'json') return Response.new({"error": error.message}, error.status, ctype = "json")
else: else:
context = {"e": error} context = {"e": error}
@ -294,27 +297,27 @@ async def handle_response_headers(
request: web.Request, request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') request["hash"] = b64encode(get_random_bytes(16)).decode("ascii")
request['token'] = None request["token"] = None
request['user'] = None request["user"] = None
app: Application = request.app # type: ignore[assignment] app: Application = request.app # type: ignore[assignment]
if request.path in {"/", "/docs"} or request.path.startswith(TOKEN_PATHS): if request.path in {"/", "/docs"} or request.path.startswith(TOKEN_PATHS):
with app.database.session() as conn: with app.database.session() as conn:
tokens = ( tokens = (
request.headers.get('Authorization', '').replace('Bearer', '').strip(), request.headers.get("Authorization", "").replace("Bearer", "").strip(),
request.cookies.get('user-token') request.cookies.get("user-token")
) )
for token in tokens: for token in tokens:
if not token: if not token:
continue continue
request['token'] = conn.get_app_by_token(token) request["token"] = conn.get_app_by_token(token)
if request['token'] is not None: if request["token"] is not None:
request['user'] = conn.get_user(request['token'].user) request["user"] = conn.get_user(request["token"].user)
break break
@ -338,21 +341,21 @@ async def handle_response_headers(
raise raise
except Exception: except Exception:
resp = format_error(request, HttpError(500, 'Internal server error')) resp = format_error(request, HttpError(500, "Internal server error"))
traceback.print_exc() traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay' resp.headers["Server"] = "ActivityRelay"
# Still have to figure out how csp headers work # Still have to figure out how csp headers work
if resp.content_type == 'text/html' and not request.path.startswith("/api"): if resp.content_type == "text/html" and not request.path.startswith("/api"):
resp.headers['Content-Security-Policy'] = get_csp(request) resp.headers["Content-Security-Policy"] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')): if not request.app["dev"] and request.path.endswith((".css", ".js", ".woff2")):
# cache for 2 weeks # cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' resp.headers["Cache-Control"] = "public,max-age=1209600,immutable"
else: else:
resp.headers['Cache-Control'] = 'no-store' resp.headers["Cache-Control"] = "no-store"
return resp return resp
@ -362,25 +365,25 @@ async def handle_frontend_path(
request: web.Request, request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
if request['user'] is not None and request.path == '/login': if request["user"] is not None and request.path == "/login":
return Response.new_redir('/') return Response.new_redir("/")
if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None: if request.path.startswith(TOKEN_PATHS[:2]) and request["user"] is None:
if request.path == '/logout': if request.path == "/logout":
return Response.new_redir('/') return Response.new_redir("/")
response = Response.new_redir(f'/login?redir={request.path}') response = Response.new_redir(f"/login?redir={request.path}")
if request['token'] is not None: if request["token"] is not None:
response.del_cookie('user-token') response.del_cookie("user-token")
return response return response
response = await handler(request) response = await handler(request)
if not request.path.startswith('/api'): if not request.path.startswith("/api"):
if request['user'] is None and request['token'] is not None: if request["user"] is None and request["token"] is not None:
response.del_cookie('user-token') response.del_cookie("user-token")
return response return response


@ -24,11 +24,11 @@ DeserializerCallback = Callable[[str], Any]
BACKENDS: dict[str, type[Cache]] = {} BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = { CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
'str': (str, str), "str": (str, str),
'int': (str, int), "int": (str, int),
'bool': (str, convert_to_boolean), "bool": (str, convert_to_boolean),
'json': (json.dumps, json.loads), "json": (json.dumps, json.loads),
'message': (lambda x: x.to_json(), Message.parse) "message": (lambda x: x.to_json(), Message.parse)
} }
@ -49,14 +49,14 @@ def register_cache(backend: type[Cache]) -> type[Cache]:
return backend return backend
def serialize_value(value: Any, value_type: str = 'str') -> str: def serialize_value(value: Any, value_type: str = "str") -> str:
if isinstance(value, str): if isinstance(value, str):
return value return value
return CONVERTERS[value_type][0](value) return CONVERTERS[value_type][0](value)
def deserialize_value(value: str, value_type: str = 'str') -> Any: def deserialize_value(value: str, value_type: str = "str") -> Any:
return CONVERTERS[value_type][1](value) return CONVERTERS[value_type][1](value)
@ -93,7 +93,7 @@ class Item:
class Cache(ABC): class Cache(ABC):
name: str = 'null' name: str
def __init__(self, app: Application): def __init__(self, app: Application):
@ -116,7 +116,7 @@ class Cache(ABC):
@abstractmethod @abstractmethod
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = "key") -> Item:
... ...
@ -160,7 +160,7 @@ class Cache(ABC):
@register_cache @register_cache
class SqlCache(Cache): class SqlCache(Cache):
name: str = 'database' name: str = "database"
def __init__(self, app: Application): def __init__(self, app: Application):
@ -173,16 +173,16 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, "namespace": namespace,
'key': key "key": key
} }
with self._db.session(False) as conn: with self._db.session(False) as conn:
with conn.run('get-cache-item', params) as cur: with conn.run("get-cache-item", params) as cur:
if not (row := cur.one(Row)): if not (row := cur.one(Row)):
raise KeyError(f'{namespace}:{key}') raise KeyError(f"{namespace}:{key}")
row.pop('id', None) row.pop("id", None)
return Item.from_data(*tuple(row.values())) return Item.from_data(*tuple(row.values()))
@ -191,8 +191,8 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn: with self._db.session(False) as conn:
for row in conn.run('get-cache-keys', {'namespace': namespace}): for row in conn.run("get-cache-keys", {"namespace": namespace}):
yield row['key'] yield row["key"]
def get_namespaces(self) -> Iterator[str]: def get_namespaces(self) -> Iterator[str]:
@ -200,28 +200,28 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn: with self._db.session(False) as conn:
for row in conn.run('get-cache-namespaces', None): for row in conn.run("get-cache-namespaces", None):
yield row['namespace'] yield row["namespace"]
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = "str") -> Item:
if self._db is None: if self._db is None:
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, "namespace": namespace,
'key': key, "key": key,
'value': serialize_value(value, value_type), "value": serialize_value(value, value_type),
'type': value_type, "type": value_type,
'date': Date.new_utc() "date": Date.new_utc()
} }
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as cur: with conn.run("set-cache-item", params) as cur:
if (row := cur.one(Row)) is None: if (row := cur.one(Row)) is None:
raise RuntimeError("Cache item not set") raise RuntimeError("Cache item not set")
row.pop('id', None) row.pop("id", None)
return Item.from_data(*tuple(row.values())) return Item.from_data(*tuple(row.values()))
@ -230,12 +230,12 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
params = { params = {
'namespace': namespace, "namespace": namespace,
'key': key "key": key
} }
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.run('del-cache-item', params): with conn.run("del-cache-item", params):
pass pass
@ -267,7 +267,7 @@ class SqlCache(Cache):
self._db.connect() self._db.connect()
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.run(f'create-cache-table-{self._db.backend_type.value}', None): with conn.run(f"create-cache-table-{self._db.backend_type.value}", None):
pass pass
@ -281,7 +281,7 @@ class SqlCache(Cache):
@register_cache @register_cache
class RedisCache(Cache): class RedisCache(Cache):
name: str = 'redis' name: str = "redis"
def __init__(self, app: Application): def __init__(self, app: Application):
@ -295,7 +295,7 @@ class RedisCache(Cache):
def get_key_name(self, namespace: str, key: str) -> str: def get_key_name(self, namespace: str, key: str) -> str:
return f'{self.prefix}:{namespace}:{key}' return f"{self.prefix}:{namespace}:{key}"
def get(self, namespace: str, key: str) -> Item: def get(self, namespace: str, key: str) -> Item:
@ -305,9 +305,9 @@ class RedisCache(Cache):
key_name = self.get_key_name(namespace, key) key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)): if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}') raise KeyError(f"{namespace}:{key}")
value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr] value_type, updated, value = raw_value.split(":", 2) # type: ignore[union-attr]
return Item.from_data( return Item.from_data(
namespace, namespace,
@ -322,8 +322,8 @@ class RedisCache(Cache):
if self._rd is None: if self._rd is None:
raise ConnectionError("Not connected") raise ConnectionError("Not connected")
for key in self._rd.scan_iter(self.get_key_name(namespace, '*')): for key in self._rd.scan_iter(self.get_key_name(namespace, "*")):
*_, key_name = key.split(':', 2) *_, key_name = key.split(":", 2)
yield key_name yield key_name
@ -333,15 +333,15 @@ class RedisCache(Cache):
namespaces = [] namespaces = []
for key in self._rd.scan_iter(f'{self.prefix}:*'): for key in self._rd.scan_iter(f"{self.prefix}:*"):
_, namespace, _ = key.split(':', 2) _, namespace, _ = key.split(":", 2)
if namespace not in namespaces: if namespace not in namespaces:
namespaces.append(namespace) namespaces.append(namespace)
yield namespace yield namespace
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = "key") -> Item:
if self._rd is None: if self._rd is None:
raise ConnectionError("Not connected") raise ConnectionError("Not connected")
@ -350,7 +350,7 @@ class RedisCache(Cache):
self._rd.set( self._rd.set(
self.get_key_name(namespace, key), self.get_key_name(namespace, key),
f'{value_type}:{date}:{value}' f"{value_type}:{date}:{value}"
) )
return self.get(namespace, key) return self.get(namespace, key)
@ -369,8 +369,8 @@ class RedisCache(Cache):
limit = Date.new_utc() - timedelta(days = days) limit = Date.new_utc() - timedelta(days = days)
for full_key in self._rd.scan_iter(f'{self.prefix}:*'): for full_key in self._rd.scan_iter(f"{self.prefix}:*"):
_, namespace, key = full_key.split(':', 2) _, namespace, key = full_key.split(":", 2)
item = self.get(namespace, key) item = self.get(namespace, key)
if item.updated < limit: if item.updated < limit:
@ -389,11 +389,11 @@ class RedisCache(Cache):
return return
options: RedisConnectType = { options: RedisConnectType = {
'client_name': f'ActivityRelay_{self.app.config.domain}', "client_name": f"ActivityRelay_{self.app.config.domain}",
'decode_responses': True, "decode_responses": True,
'username': self.app.config.rd_user, "username": self.app.config.rd_user,
'password': self.app.config.rd_pass, "password": self.app.config.rd_pass,
'db': self.app.config.rd_database "db": self.app.config.rd_database
} }
if os.path.exists(self.app.config.rd_host): if os.path.exists(self.app.config.rd_host):


@ -14,21 +14,21 @@ class RelayConfig(dict[str, Any]):
dict.__init__(self, {}) dict.__init__(self, {})
if self.is_docker: if self.is_docker:
path = '/data/config.yaml' path = "/data/config.yaml"
self._path = Path(path).expanduser().resolve() self._path = Path(path).expanduser().resolve()
self.reset() self.reset()
def __setitem__(self, key: str, value: Any) -> None: def __setitem__(self, key: str, value: Any) -> None:
if key in {'blocked_instances', 'blocked_software', 'whitelist'}: if key in {"blocked_instances", "blocked_software", "whitelist"}:
assert isinstance(value, (list, set, tuple)) assert isinstance(value, (list, set, tuple))
elif key in {'port', 'workers', 'json_cache', 'timeout'}: elif key in {"port", "workers", "json_cache", "timeout"}:
if not isinstance(value, int): if not isinstance(value, int):
value = int(value) value = int(value)
elif key == 'whitelist_enabled': elif key == "whitelist_enabled":
if not isinstance(value, bool): if not isinstance(value, bool):
value = convert_to_boolean(value) value = convert_to_boolean(value)
@ -37,45 +37,45 @@ class RelayConfig(dict[str, Any]):
@property @property
def db(self) -> Path: def db(self) -> Path:
return Path(self['db']).expanduser().resolve() return Path(self["db"]).expanduser().resolve()
@property @property
def actor(self) -> str: def actor(self) -> str:
return f'https://{self["host"]}/actor' return f"https://{self['host']}/actor"
@property @property
def inbox(self) -> str: def inbox(self) -> str:
return f'https://{self["host"]}/inbox' return f"https://{self['host']}/inbox"
@property @property
def keyid(self) -> str: def keyid(self) -> str:
return f'{self.actor}#main-key' return f"{self.actor}#main-key"
@cached_property @cached_property
def is_docker(self) -> bool: def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING')) return bool(os.environ.get("DOCKER_RUNNING"))
def reset(self) -> None: def reset(self) -> None:
self.clear() self.clear()
self.update({ self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')), "db": str(self._path.parent.joinpath(f"{self._path.stem}.jsonld")),
'listen': '0.0.0.0', "listen": "0.0.0.0",
'port': 8080, "port": 8080,
'note': 'Make a note about your instance here.', "note": "Make a note about your instance here.",
'push_limit': 512, "push_limit": 512,
'json_cache': 1024, "json_cache": 1024,
'timeout': 10, "timeout": 10,
'workers': 0, "workers": 0,
'host': 'relay.example.com', "host": "relay.example.com",
'whitelist_enabled': False, "whitelist_enabled": False,
'blocked_software': [], "blocked_software": [],
'blocked_instances': [], "blocked_instances": [],
'whitelist': [] "whitelist": []
}) })
@ -85,13 +85,13 @@ class RelayConfig(dict[str, Any]):
options = {} options = {}
try: try:
options['Loader'] = yaml.FullLoader options["Loader"] = yaml.FullLoader
except AttributeError: except AttributeError:
pass pass
try: try:
with self._path.open('r', encoding = 'UTF-8') as fd: with self._path.open("r", encoding = "UTF-8") as fd:
config = yaml.load(fd, **options) config = yaml.load(fd, **options)
except FileNotFoundError: except FileNotFoundError:
@ -101,7 +101,7 @@ class RelayConfig(dict[str, Any]):
return return
for key, value in config.items(): for key, value in config.items():
if key == 'ap': if key == "ap":
for k, v in value.items(): for k, v in value.items():
if k not in self: if k not in self:
continue continue
@ -119,10 +119,10 @@ class RelayConfig(dict[str, Any]):
class RelayDatabase(dict[str, Any]): class RelayDatabase(dict[str, Any]):
def __init__(self, config: RelayConfig): def __init__(self, config: RelayConfig):
dict.__init__(self, { dict.__init__(self, {
'relay-list': {}, "relay-list": {},
'private-key': None, "private-key": None,
'follow-requests': {}, "follow-requests": {},
'version': 1 "version": 1
}) })
self.config = config self.config = config
@ -131,12 +131,12 @@ class RelayDatabase(dict[str, Any]):
@property @property
def hostnames(self) -> tuple[str]: def hostnames(self) -> tuple[str]:
return tuple(self['relay-list'].keys()) return tuple(self["relay-list"].keys())
@property @property
def inboxes(self) -> tuple[dict[str, str]]: def inboxes(self) -> tuple[dict[str, str]]:
return tuple(data['inbox'] for data in self['relay-list'].values()) return tuple(data["inbox"] for data in self["relay-list"].values())
def load(self) -> None: def load(self) -> None:
@ -144,29 +144,29 @@ class RelayDatabase(dict[str, Any]):
with self.config.db.open() as fd: with self.config.db.open() as fd:
data = json.load(fd) data = json.load(fd)
self['version'] = data.get('version', None) self["version"] = data.get("version", None)
self['private-key'] = data.get('private-key') self["private-key"] = data.get("private-key")
if self['version'] is None: if self["version"] is None:
self['version'] = 1 self["version"] = 1
if 'actorKeys' in data: if "actorKeys" in data:
self['private-key'] = data['actorKeys']['privateKey'] self["private-key"] = data["actorKeys"]["privateKey"]
for item in data.get('relay-list', []): for item in data.get("relay-list", []):
domain = urlparse(item).hostname domain = urlparse(item).hostname
self['relay-list'][domain] = { self["relay-list"][domain] = {
'domain': domain, "domain": domain,
'inbox': item, "inbox": item,
'followid': None "followid": None
} }
else: else:
self['relay-list'] = data.get('relay-list', {}) self["relay-list"] = data.get("relay-list", {})
for domain, instance in self['relay-list'].items(): for domain, instance in self["relay-list"].items():
if not instance.get('domain'): if not instance.get("domain"):
instance['domain'] = domain instance["domain"] = domain
except FileNotFoundError: except FileNotFoundError:
pass pass


@ -16,7 +16,7 @@ if TYPE_CHECKING:
from typing import Self from typing import Self
if platform.system() == 'Windows': if platform.system() == "Windows":
import multiprocessing import multiprocessing
CORE_COUNT = multiprocessing.cpu_count() CORE_COUNT = multiprocessing.cpu_count()
@ -25,9 +25,9 @@ else:
DOCKER_VALUES = { DOCKER_VALUES = {
'listen': '0.0.0.0', "listen": "0.0.0.0",
'port': 8080, "port": 8080,
'sq_path': '/data/relay.sqlite3' "sq_path": "/data/relay.sqlite3"
} }
@ -37,26 +37,26 @@ class NOVALUE:
@dataclass(init = False) @dataclass(init = False)
class Config: class Config:
listen: str = '0.0.0.0' listen: str = "0.0.0.0"
port: int = 8080 port: int = 8080
domain: str = 'relay.example.com' domain: str = "relay.example.com"
workers: int = CORE_COUNT workers: int = CORE_COUNT
db_type: str = 'sqlite' db_type: str = "sqlite"
ca_type: str = 'database' ca_type: str = "database"
sq_path: str = 'relay.sqlite3' sq_path: str = "relay.sqlite3"
pg_host: str = '/var/run/postgresql' pg_host: str = "/var/run/postgresql"
pg_port: int = 5432 pg_port: int = 5432
pg_user: str = getpass.getuser() pg_user: str = getpass.getuser()
pg_pass: str | None = None pg_pass: str | None = None
pg_name: str = 'activityrelay' pg_name: str = "activityrelay"
rd_host: str = 'localhost' rd_host: str = "localhost"
rd_port: int = 6470 rd_port: int = 6470
rd_user: str | None = None rd_user: str | None = None
rd_pass: str | None = None rd_pass: str | None = None
rd_database: int = 0 rd_database: int = 0
rd_prefix: str = 'activityrelay' rd_prefix: str = "activityrelay"
def __init__(self, path: Path | None = None, load: bool = False): def __init__(self, path: Path | None = None, load: bool = False):
@ -116,17 +116,17 @@ class Config:
@property @property
def actor(self) -> str: def actor(self) -> str:
return f'https://{self.domain}/actor' return f"https://{self.domain}/actor"
@property @property
def inbox(self) -> str: def inbox(self) -> str:
return f'https://{self.domain}/inbox' return f"https://{self.domain}/inbox"
@property @property
def keyid(self) -> str: def keyid(self) -> str:
return f'{self.actor}#main-key' return f"{self.actor}#main-key"
def load(self) -> None: def load(self) -> None:
@ -134,43 +134,43 @@ class Config:
options = {} options = {}
try: try:
options['Loader'] = yaml.FullLoader options["Loader"] = yaml.FullLoader
except AttributeError: except AttributeError:
pass pass
with self.path.open('r', encoding = 'UTF-8') as fd: with self.path.open("r", encoding = "UTF-8") as fd:
config = yaml.load(fd, **options) config = yaml.load(fd, **options)
if not config: if not config:
raise ValueError('Config is empty') raise ValueError("Config is empty")
pgcfg = config.get('postgres', {}) pgcfg = config.get("postgres", {})
rdcfg = config.get('redis', {}) rdcfg = config.get("redis", {})
for key in type(self).KEYS(): for key in type(self).KEYS():
if IS_DOCKER and key in {'listen', 'port', 'sq_path'}: if IS_DOCKER and key in {"listen", "port", "sq_path"}:
self.set(key, DOCKER_VALUES[key]) self.set(key, DOCKER_VALUES[key])
continue continue
if key.startswith('pg'): if key.startswith("pg"):
self.set(key, pgcfg.get(key[3:], NOVALUE)) self.set(key, pgcfg.get(key[3:], NOVALUE))
continue continue
elif key.startswith('rd'): elif key.startswith("rd"):
self.set(key, rdcfg.get(key[3:], NOVALUE)) self.set(key, rdcfg.get(key[3:], NOVALUE))
continue continue
cfgkey = key cfgkey = key
if key == 'db_type': if key == "db_type":
cfgkey = 'database_type' cfgkey = "database_type"
elif key == 'ca_type': elif key == "ca_type":
cfgkey = 'cache_type' cfgkey = "cache_type"
elif key == 'sq_path': elif key == "sq_path":
cfgkey = 'sqlite_path' cfgkey = "sqlite_path"
self.set(key, config.get(cfgkey, NOVALUE)) self.set(key, config.get(cfgkey, NOVALUE))
@ -186,32 +186,32 @@ class Config:
data: dict[str, Any] = {} data: dict[str, Any] = {}
for key, value in asdict(self).items(): for key, value in asdict(self).items():
if key.startswith('pg_'): if key.startswith("pg_"):
if 'postgres' not in data: if "postgres" not in data:
data['postgres'] = {} data["postgres"] = {}
data['postgres'][key[3:]] = value data["postgres"][key[3:]] = value
continue continue
if key.startswith('rd_'): if key.startswith("rd_"):
if 'redis' not in data: if "redis" not in data:
data['redis'] = {} data["redis"] = {}
data['redis'][key[3:]] = value data["redis"][key[3:]] = value
continue continue
if key == 'db_type': if key == "db_type":
key = 'database_type' key = "database_type"
elif key == 'ca_type': elif key == "ca_type":
key = 'cache_type' key = "cache_type"
elif key == 'sq_path': elif key == "sq_path":
key = 'sqlite_path' key = "sqlite_path"
data[key] = value data[key] = value
with self.path.open('w', encoding = 'utf-8') as fd: with self.path.open("w", encoding = "utf-8") as fd:
yaml.dump(data, fd, sort_keys = False) yaml.dump(data, fd, sort_keys = False)


@ -16,17 +16,17 @@ sqlite3.register_adapter(Date, Date.timestamp)
def get_database(config: Config, migrate: bool = True) -> Database[Connection]: def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
options = { options = {
'connection_class': Connection, "connection_class": Connection,
'pool_size': 5, "pool_size": 5,
'tables': TABLES "tables": TABLES
} }
db: Database[Connection] db: Database[Connection]
if config.db_type == 'sqlite': if config.db_type == "sqlite":
db = Database.sqlite(config.sqlite_path, **options) db = Database.sqlite(config.sqlite_path, **options)
elif config.db_type == 'postgres': elif config.db_type == "postgres":
db = Database.postgresql( db = Database.postgresql(
config.pg_name, config.pg_name,
config.pg_host, config.pg_host,
@ -36,30 +36,30 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
**options **options
) )
db.load_prepared_statements(File.from_resource('relay', 'data/statements.sql')) db.load_prepared_statements(File.from_resource("relay", "data/statements.sql"))
db.connect() db.connect()
if not migrate: if not migrate:
return db return db
with db.session(True) as conn: with db.session(True) as conn:
if 'config' not in conn.get_tables(): if "config" not in conn.get_tables():
logging.info("Creating database tables") logging.info("Creating database tables")
migrate_0(conn) migrate_0(conn)
return db return db
if (schema_ver := conn.get_config('schema-version')) < ConfigData.DEFAULT('schema-version'): if (schema_ver := conn.get_config("schema-version")) < ConfigData.DEFAULT("schema-version"):
logging.info("Migrating database from version '%i'", schema_ver) logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS.items(): for ver, func in VERSIONS.items():
if schema_ver < ver: if schema_ver < ver:
func(conn) func(conn)
conn.put_config('schema-version', ver) conn.put_config("schema-version", ver)
logging.info("Updated database to %i", ver) logging.info("Updated database to %i", ver)
if (privkey := conn.get_config('private-key')): if (privkey := conn.get_config("private-key")):
conn.app.signer = privkey conn.app.signer = privkey
logging.set_level(conn.get_config('log-level')) logging.set_level(conn.get_config("log-level"))
return db return db


@ -15,76 +15,76 @@ if TYPE_CHECKING:
THEMES = { THEMES = {
'default': { "default": {
'text': '#DDD', "text": "#DDD",
'background': '#222', "background": "#222",
'primary': '#D85', "primary": "#D85",
'primary-hover': '#DA8', "primary-hover": "#DA8",
'section-background': '#333', "section-background": "#333",
'table-background': '#444', "table-background": "#444",
'border': '#444', "border": "#444",
'message-text': '#DDD', "message-text": "#DDD",
'message-background': '#335', "message-background": "#335",
'message-border': '#446', "message-border": "#446",
'error-text': '#DDD', "error-text": "#DDD",
'error-background': '#533', "error-background": "#533",
'error-border': '#644' "error-border": "#644"
}, },
'pink': { "pink": {
'text': '#DDD', "text": "#DDD",
'background': '#222', "background": "#222",
'primary': '#D69', "primary": "#D69",
'primary-hover': '#D36', "primary-hover": "#D36",
'section-background': '#333', "section-background": "#333",
'table-background': '#444', "table-background": "#444",
'border': '#444', "border": "#444",
'message-text': '#DDD', "message-text": "#DDD",
'message-background': '#335', "message-background": "#335",
'message-border': '#446', "message-border": "#446",
'error-text': '#DDD', "error-text": "#DDD",
'error-background': '#533', "error-background": "#533",
'error-border': '#644' "error-border": "#644"
}, },
'blue': { "blue": {
'text': '#DDD', "text": "#DDD",
'background': '#222', "background": "#222",
'primary': '#69D', "primary": "#69D",
'primary-hover': '#36D', "primary-hover": "#36D",
'section-background': '#333', "section-background": "#333",
'table-background': '#444', "table-background": "#444",
'border': '#444', "border": "#444",
'message-text': '#DDD', "message-text": "#DDD",
'message-background': '#335', "message-background": "#335",
'message-border': '#446', "message-border": "#446",
'error-text': '#DDD', "error-text": "#DDD",
'error-background': '#533', "error-background": "#533",
'error-border': '#644' "error-border": "#644"
} }
} }
# serializer | deserializer # serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
'str': (str, str), "str": (str, str),
'int': (str, int), "int": (str, int),
'bool': (str, convert_to_boolean), "bool": (str, convert_to_boolean),
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse) "logging.LogLevel": (lambda x: x.name, logging.LogLevel.parse)
} }
@dataclass() @dataclass()
class ConfigData: class ConfigData:
schema_version: int = 20240625 schema_version: int = 20240625
private_key: str = '' private_key: str = ""
approval_required: bool = False approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO log_level: logging.LogLevel = logging.LogLevel.INFO
name: str = 'ActivityRelay' name: str = "ActivityRelay"
note: str = '' note: str = ""
theme: str = 'default' theme: str = "default"
whitelist_enabled: bool = False whitelist_enabled: bool = False
def __getitem__(self, key: str) -> Any: def __getitem__(self, key: str) -> Any:
if (value := getattr(self, key.replace('-', '_'), None)) is None: if (value := getattr(self, key.replace("-", "_"), None)) is None:
raise KeyError(key) raise KeyError(key)
return value return value
@ -101,7 +101,7 @@ class ConfigData:
@staticmethod @staticmethod
def SYSTEM_KEYS() -> Sequence[str]: def SYSTEM_KEYS() -> Sequence[str]:
return ('schema-version', 'schema_version', 'private-key', 'private_key') return ("schema-version", "schema_version", "private-key", "private_key")
@classmethod @classmethod
@ -111,12 +111,12 @@ class ConfigData:
@classmethod @classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool: def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value] return cls.FIELD(key.replace("-", "_")).default # type: ignore[return-value]
@classmethod @classmethod
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]: def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
parsed_key = key.replace('-', '_') parsed_key = key.replace("-", "_")
for field in fields(cls): for field in fields(cls):
if field.name == parsed_key: if field.name == parsed_key:
@ -131,9 +131,9 @@ class ConfigData:
set_schema_version = False set_schema_version = False
for row in rows: for row in rows:
data.set(row['key'], row['value']) data.set(row["key"], row["value"])
if row['key'] == 'schema-version': if row["key"] == "schema-version":
set_schema_version = True set_schema_version = True
if not set_schema_version: if not set_schema_version:
@ -161,4 +161,4 @@ class ConfigData:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return {key.replace('_', '-'): value for key, value in asdict(self).items()} return {key.replace("_", "-"): value for key, value in asdict(self).items()}


@ -23,16 +23,17 @@ if TYPE_CHECKING:
from ..application import Application from ..application import Application
RELAY_SOFTWARE = [ RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay "activityrelay", # https://git.pleroma.social/pleroma/relay
'activity-relay', # https://github.com/yukimochi/Activity-Relay "activity-relay", # https://github.com/yukimochi/Activity-Relay
'aoderelay', # https://git.asonix.dog/asonix/relay "aoderelay", # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay "feditools-relay", # https://git.ptzo.gdn/feditools/relay
"buzzrelay" # https://github.com/astro/buzzrelay
] ]
class Connection(SqlConnection): class Connection(SqlConnection):
hasher = PasswordHasher( hasher = PasswordHasher(
encoding = 'utf-8' encoding = "utf-8"
) )
@property @property
@ -63,49 +64,49 @@ class Connection(SqlConnection):
def fix_timestamps(self) -> None: def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App): for app in self.select("apps").all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()} data = {"created": app.created.timestamp(), "accessed": app.accessed.timestamp()}
self.update('apps', data, client_id = app.client_id) self.update("apps", data, client_id = app.client_id)
for item in self.select('cache'): for item in self.select("cache"):
data = {'updated': Date.parse(item['updated']).timestamp()} data = {"updated": Date.parse(item["updated"]).timestamp()}
self.update('cache', data, id = item['id']) self.update("cache", data, id = item["id"])
for dban in self.select('domain_bans').all(schema.DomainBan): for dban in self.select("domain_bans").all(schema.DomainBan):
data = {'created': dban.created.timestamp()} data = {"created": dban.created.timestamp()}
self.update('domain_bans', data, domain = dban.domain) self.update("domain_bans", data, domain = dban.domain)
for instance in self.select('inboxes').all(schema.Instance): for instance in self.select("inboxes").all(schema.Instance):
data = {'created': instance.created.timestamp()} data = {"created": instance.created.timestamp()}
self.update('inboxes', data, domain = instance.domain) self.update("inboxes", data, domain = instance.domain)
for sban in self.select('software_bans').all(schema.SoftwareBan): for sban in self.select("software_bans").all(schema.SoftwareBan):
data = {'created': sban.created.timestamp()} data = {"created": sban.created.timestamp()}
self.update('software_bans', data, name = sban.name) self.update("software_bans", data, name = sban.name)
for user in self.select('users').all(schema.User): for user in self.select("users").all(schema.User):
data = {'created': user.created.timestamp()} data = {"created": user.created.timestamp()}
self.update('users', data, username = user.username) self.update("users", data, username = user.username)
for wlist in self.select('whitelist').all(schema.Whitelist): for wlist in self.select("whitelist").all(schema.Whitelist):
data = {'created': wlist.created.timestamp()} data = {"created": wlist.created.timestamp()}
self.update('whitelist', data, domain = wlist.domain) self.update("whitelist", data, domain = wlist.domain)
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
key = key.replace('_', '-') key = key.replace("_", "-")
with self.run('get-config', {'key': key}) as cur: with self.run("get-config", {"key": key}) as cur:
if (row := cur.one(Row)) is None: if (row := cur.one(Row)) is None:
return ConfigData.DEFAULT(key) return ConfigData.DEFAULT(key)
data = ConfigData() data = ConfigData()
data.set(row['key'], row['value']) data.set(row["key"], row["value"])
return data.get(key) return data.get(key)
def get_config_all(self) -> ConfigData: def get_config_all(self) -> ConfigData:
rows = tuple(self.run('get-config-all', None).all(schema.Row)) rows = tuple(self.run("get-config-all", None).all(schema.Row))
return ConfigData.from_rows(rows) return ConfigData.from_rows(rows)
@ -119,7 +120,7 @@ class Connection(SqlConnection):
case "log_level": case "log_level":
value = logging.LogLevel.parse(value) value = logging.LogLevel.parse(value)
logging.set_level(value) logging.set_level(value)
self.app['workers'].set_log_level(value) self.app["workers"].set_log_level(value)
case "approval_required": case "approval_required":
value = convert_to_boolean(value) value = convert_to_boolean(value)
@ -129,25 +130,25 @@ class Connection(SqlConnection):
case "theme": case "theme":
if value not in THEMES: if value not in THEMES:
raise ValueError(f'"{value}" is not a valid theme') raise ValueError(f"\"{value}\" is not a valid theme")
data = ConfigData() data = ConfigData()
data.set(key, value) data.set(key, value)
params = { params = {
'key': key, "key": key,
'value': data.get(key, serialize = True), "value": data.get(key, serialize = True),
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore "type": "LogLevel" if field.type == "logging.LogLevel" else field.type
} }
with self.run('put-config', params): with self.run("put-config", params):
pass pass
return data.get(key) return data.get(key)
def get_inbox(self, value: str) -> schema.Instance | None: def get_inbox(self, value: str) -> schema.Instance | None:
with self.run('get-inbox', {'value': value}) as cur: with self.run("get-inbox", {"value": value}) as cur:
return cur.one(schema.Instance) return cur.one(schema.Instance)
@ -165,21 +166,21 @@ class Connection(SqlConnection):
accepted: bool = True) -> schema.Instance: accepted: bool = True) -> schema.Instance:
params: dict[str, Any] = { params: dict[str, Any] = {
'inbox': inbox, "inbox": inbox,
'actor': actor, "actor": actor,
'followid': followid, "followid": followid,
'software': software, "software": software,
'accepted': accepted "accepted": accepted
} }
if self.get_inbox(domain) is None: if self.get_inbox(domain) is None:
if not inbox: if not inbox:
raise ValueError("Missing inbox") raise ValueError("Missing inbox")
params['domain'] = domain params["domain"] = domain
params['created'] = datetime.now(tz = timezone.utc) params["created"] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur: with self.run("put-inbox", params) as cur:
if (row := cur.one(schema.Instance)) is None: if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert instance: {domain}") raise RuntimeError(f"Failed to insert instance: {domain}")
@ -189,7 +190,7 @@ class Connection(SqlConnection):
if value is None: if value is None:
del params[key] del params[key]
with self.update('inboxes', params, domain = domain) as cur: with self.update("inboxes", params, domain = domain) as cur:
if (row := cur.one(schema.Instance)) is None: if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to update instance: {domain}") raise RuntimeError(f"Failed to update instance: {domain}")
@ -197,20 +198,20 @@ class Connection(SqlConnection):
def del_inbox(self, value: str) -> bool: def del_inbox(self, value: str) -> bool:
with self.run('del-inbox', {'value': value}) as cur: with self.run("del-inbox", {"value": value}) as cur:
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError("More than one row was modified")
return cur.row_count == 1 return cur.row_count == 1
def get_request(self, domain: str) -> schema.Instance | None: def get_request(self, domain: str) -> schema.Instance | None:
with self.run('get-request', {'domain': domain}) as cur: with self.run("get-request", {"domain": domain}) as cur:
return cur.one(schema.Instance) return cur.one(schema.Instance)
def get_requests(self) -> Iterator[schema.Instance]: def get_requests(self) -> Iterator[schema.Instance]:
return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance) return self.execute("SELECT * FROM inboxes WHERE accepted = false").all(schema.Instance)
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
@ -219,16 +220,16 @@ class Connection(SqlConnection):
if not accepted: if not accepted:
if not self.del_inbox(domain): if not self.del_inbox(domain):
raise RuntimeError(f'Failed to delete request: {domain}') raise RuntimeError(f"Failed to delete request: {domain}")
return instance return instance
params = { params = {
'domain': domain, "domain": domain,
'accepted': accepted "accepted": accepted
} }
with self.run('put-inbox-accept', params) as cur: with self.run("put-inbox-accept", params) as cur:
if (row := cur.one(schema.Instance)) is None: if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert response for domain: {domain}") raise RuntimeError(f"Failed to insert response for domain: {domain}")
@ -236,12 +237,12 @@ class Connection(SqlConnection):
def get_user(self, value: str) -> schema.User | None: def get_user(self, value: str) -> schema.User | None:
with self.run('get-user', {'value': value}) as cur: with self.run("get-user", {"value": value}) as cur:
return cur.one(schema.User) return cur.one(schema.User)
def get_user_by_token(self, token: str) -> schema.User | None: def get_user_by_token(self, token: str) -> schema.User | None:
with self.run('get-user-by-token', {'token': token}) as cur: with self.run("get-user-by-token", {"token": token}) as cur:
return cur.one(schema.User) return cur.one(schema.User)
@ -254,10 +255,10 @@ class Connection(SqlConnection):
data: dict[str, str | datetime | None] = {} data: dict[str, str | datetime | None] = {}
if password: if password:
data['hash'] = self.hasher.hash(password) data["hash"] = self.hasher.hash(password)
if handle: if handle:
data['handle'] = handle data["handle"] = handle
stmt = Update("users", data) stmt = Update("users", data)
stmt.set_where("username", username) stmt.set_where("username", username)
@ -269,16 +270,16 @@ class Connection(SqlConnection):
return row return row
if password is None: if password is None:
raise ValueError('Password cannot be empty') raise ValueError("Password cannot be empty")
data = { data = {
'username': username, "username": username,
'hash': self.hasher.hash(password), "hash": self.hasher.hash(password),
'handle': handle, "handle": handle,
'created': datetime.now(tz = timezone.utc) "created": datetime.now(tz = timezone.utc)
} }
with self.run('put-user', data) as cur: with self.run("put-user", data) as cur:
if (row := cur.one(schema.User)) is None: if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to insert user: {username}") raise RuntimeError(f"Failed to insert user: {username}")
@ -289,10 +290,10 @@ class Connection(SqlConnection):
if (user := self.get_user(username)) is None: if (user := self.get_user(username)) is None:
raise KeyError(username) raise KeyError(username)
with self.run('del-token-user', {'username': user.username}): with self.run("del-token-user", {"username": user.username}):
pass pass
with self.run('del-user', {'username': user.username}): with self.run("del-user", {"username": user.username}):
pass pass
@ -302,61 +303,61 @@ class Connection(SqlConnection):
token: str | None = None) -> schema.App | None: token: str | None = None) -> schema.App | None:
params = { params = {
'id': client_id, "id": client_id,
'secret': client_secret "secret": client_secret
} }
if token is not None: if token is not None:
command = 'get-app-with-token' command = "get-app-with-token"
params['token'] = token params["token"] = token
else: else:
command = 'get-app' command = "get-app"
with self.run(command, params) as cur: with self.run(command, params) as cur:
return cur.one(schema.App) return cur.one(schema.App)
def get_app_by_token(self, token: str) -> schema.App | None: def get_app_by_token(self, token: str) -> schema.App | None:
with self.run('get-app-by-token', {'token': token}) as cur: with self.run("get-app-by-token", {"token": token}) as cur:
return cur.one(schema.App) return cur.one(schema.App)
def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App: def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
params = { params = {
'name': name, "name": name,
'redirect_uri': redirect_uri, "redirect_uri": redirect_uri,
'website': website, "website": website,
'client_id': secrets.token_hex(20), "client_id": secrets.token_hex(20),
'client_secret': secrets.token_hex(20), "client_secret": secrets.token_hex(20),
'created': Date.new_utc(), "created": Date.new_utc(),
'accessed': Date.new_utc() "accessed": Date.new_utc()
} }
with self.insert('apps', params) as cur: with self.insert("apps", params) as cur:
if (row := cur.one(schema.App)) is None: if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to insert app: {name}') raise RuntimeError(f"Failed to insert app: {name}")
return row return row
def put_app_login(self, user: schema.User) -> schema.App: def put_app_login(self, user: schema.User) -> schema.App:
params = { params = {
'name': 'Web', "name": "Web",
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob', "redirect_uri": "urn:ietf:wg:oauth:2.0:oob",
'website': None, "website": None,
'user': user.username, "user": user.username,
'client_id': secrets.token_hex(20), "client_id": secrets.token_hex(20),
'client_secret': secrets.token_hex(20), "client_secret": secrets.token_hex(20),
'auth_code': None, "auth_code": None,
'token': secrets.token_hex(20), "token": secrets.token_hex(20),
'created': Date.new_utc(), "created": Date.new_utc(),
'accessed': Date.new_utc() "accessed": Date.new_utc()
} }
with self.insert('apps', params) as cur: with self.insert("apps", params) as cur:
if (row := cur.one(schema.App)) is None: if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to create app for "{user.username}"') raise RuntimeError(f"Failed to create app for \"{user.username}\"")
return row return row
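A rough sketch of how put_app_login is presumably consumed by the "v1/login" endpoint that the updated frontend script later in this diff calls. The handler below is hypothetical; only the Connection methods are taken from this file:

    # hypothetical login helper, assuming an open Connection named conn
    def login(conn: Connection, username: str, password: str) -> dict[str, Any]:
        if (user := conn.get_user(username)) is None:
            raise KeyError(username)

        conn.hasher.verify(user.hash, password)   # argon2 verify, raises on a bad password
        app_row = conn.put_app_login(user)        # per-login "Web" app with a fresh token
        return app_row.get_api_data(include_token = True)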
@ -365,52 +366,52 @@ class Connection(SqlConnection):
data: dict[str, str | None] = {} data: dict[str, str | None] = {}
if user is not None: if user is not None:
data['user'] = user.username data["user"] = user.username
if set_auth: if set_auth:
data['auth_code'] = secrets.token_hex(20) data["auth_code"] = secrets.token_hex(20)
else: else:
data['token'] = secrets.token_hex(20) data["token"] = secrets.token_hex(20)
data['auth_code'] = None data["auth_code"] = None
params = { params = {
'client_id': app.client_id, "client_id": app.client_id,
'client_secret': app.client_secret "client_secret": app.client_secret
} }
with self.update('apps', data, **params) as cur: # type: ignore[arg-type] with self.update("apps", data, **params) as cur: # type: ignore[arg-type]
if (row := cur.one(schema.App)) is None: if (row := cur.one(schema.App)) is None:
raise RuntimeError('Failed to update row') raise RuntimeError("Failed to update row")
return row return row
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool: def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
params = { params = {
'id': client_id, "id": client_id,
'secret': client_secret "secret": client_secret
} }
if token is not None: if token is not None:
command = 'del-app-with-token' command = "del-app-with-token"
params['token'] = token params["token"] = token
else: else:
command = 'del-app' command = "del-app"
with self.run(command, params) as cur: with self.run(command, params) as cur:
if cur.row_count > 1: if cur.row_count > 1:
raise RuntimeError('More than 1 row was deleted') raise RuntimeError("More than 1 row was deleted")
return cur.row_count == 0 return cur.row_count == 0
def get_domain_ban(self, domain: str) -> schema.DomainBan | None: def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
if domain.startswith('http'): if domain.startswith("http"):
domain = urlparse(domain).netloc domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur: with self.run("get-domain-ban", {"domain": domain}) as cur:
return cur.one(schema.DomainBan) return cur.one(schema.DomainBan)
@ -424,13 +425,13 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.DomainBan: note: str | None = None) -> schema.DomainBan:
params = { params = {
'domain': domain, "domain": domain,
'reason': reason, "reason": reason,
'note': note, "note": note,
'created': datetime.now(tz = timezone.utc) "created": datetime.now(tz = timezone.utc)
} }
with self.run('put-domain-ban', params) as cur: with self.run("put-domain-ban", params) as cur:
if (row := cur.one(schema.DomainBan)) is None: if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to insert domain ban: {domain}") raise RuntimeError(f"Failed to insert domain ban: {domain}")
@ -443,22 +444,22 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.DomainBan: note: str | None = None) -> schema.DomainBan:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError("\"reason\" and/or \"note\" must be specified")
params = {} params = {}
if reason is not None: if reason is not None:
params['reason'] = reason params["reason"] = reason
if note is not None: if note is not None:
params['note'] = note params["note"] = note
statement = Update('domain_bans', params) statement = Update("domain_bans", params)
statement.set_where("domain", domain) statement.set_where("domain", domain)
with self.query(statement) as cur: with self.query(statement) as cur:
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError("More than one row was modified")
if (row := cur.one(schema.DomainBan)) is None: if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to update domain ban: {domain}") raise RuntimeError(f"Failed to update domain ban: {domain}")
@ -467,20 +468,20 @@ class Connection(SqlConnection):
def del_domain_ban(self, domain: str) -> bool: def del_domain_ban(self, domain: str) -> bool:
with self.run('del-domain-ban', {'domain': domain}) as cur: with self.run("del-domain-ban", {"domain": domain}) as cur:
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError("More than one row was modified")
return cur.row_count == 1 return cur.row_count == 1
def get_software_ban(self, name: str) -> schema.SoftwareBan | None: def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
with self.run('get-software-ban', {'name': name}) as cur: with self.run("get-software-ban", {"name": name}) as cur:
return cur.one(schema.SoftwareBan) return cur.one(schema.SoftwareBan)
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]: def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan) return self.execute("SELECT * FROM software_bans").all(schema.SoftwareBan)
def put_software_ban(self, def put_software_ban(self,
@ -489,15 +490,15 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.SoftwareBan: note: str | None = None) -> schema.SoftwareBan:
params = { params = {
'name': name, "name": name,
'reason': reason, "reason": reason,
'note': note, "note": note,
'created': datetime.now(tz = timezone.utc) "created": datetime.now(tz = timezone.utc)
} }
with self.run('put-software-ban', params) as cur: with self.run("put-software-ban", params) as cur:
if (row := cur.one(schema.SoftwareBan)) is None: if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to insert software ban: {name}') raise RuntimeError(f"Failed to insert software ban: {name}")
return row return row
@ -508,39 +509,39 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.SoftwareBan: note: str | None = None) -> schema.SoftwareBan:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError("\"reason\" and/or \"note\" must be specified")
params = {} params = {}
if reason is not None: if reason is not None:
params['reason'] = reason params["reason"] = reason
if note is not None: if note is not None:
params['note'] = note params["note"] = note
statement = Update('software_bans', params) statement = Update("software_bans", params)
statement.set_where("name", name) statement.set_where("name", name)
with self.query(statement) as cur: with self.query(statement) as cur:
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError("More than one row was modified")
if (row := cur.one(schema.SoftwareBan)) is None: if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to update software ban: {name}') raise RuntimeError(f"Failed to update software ban: {name}")
return row return row
def del_software_ban(self, name: str) -> bool: def del_software_ban(self, name: str) -> bool:
with self.run('del-software-ban', {'name': name}) as cur: with self.run("del-software-ban", {"name": name}) as cur:
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError("More than one row was modified")
return cur.row_count == 1 return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None: def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
with self.run('get-domain-whitelist', {'domain': domain}) as cur: with self.run("get-domain-whitelist", {"domain": domain}) as cur:
return cur.one() return cur.one()
@ -550,20 +551,20 @@ class Connection(SqlConnection):
def put_domain_whitelist(self, domain: str) -> schema.Whitelist: def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
params = { params = {
'domain': domain, "domain": domain,
'created': datetime.now(tz = timezone.utc) "created": datetime.now(tz = timezone.utc)
} }
with self.run('put-domain-whitelist', params) as cur: with self.run("put-domain-whitelist", params) as cur:
if (row := cur.one(schema.Whitelist)) is None: if (row := cur.one(schema.Whitelist)) is None:
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}') raise RuntimeError(f"Failed to insert whitelisted domain: {domain}")
return row return row
def del_domain_whitelist(self, domain: str) -> bool: def del_domain_whitelist(self, domain: str) -> bool:
with self.run('del-domain-whitelist', {'domain': domain}) as cur: with self.run("del-domain-whitelist", {"domain": domain}) as cur:
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError("More than one row was modified")
return cur.row_count == 1 return cur.row_count == 1
View file
@ -32,113 +32,113 @@ def deserialize_timestamp(value: Any) -> Date:
@TABLES.add_row @TABLES.add_row
class Config(Row): class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) key: Column[str] = Column("key", "text", primary_key = True, unique = True, nullable = False)
value: Column[str] = Column('value', 'text') value: Column[str] = Column("value", "text")
type: Column[str] = Column('type', 'text', default = 'str') type: Column[str] = Column("type", "text", default = "str")
@TABLES.add_row @TABLES.add_row
class Instance(Row): class Instance(Row):
table_name: str = 'inboxes' table_name: str = "inboxes"
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False) "domain", "text", primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True) actor: Column[str] = Column("actor", "text", unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) inbox: Column[str] = Column("inbox", "text", unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text') followid: Column[str] = Column("followid", "text")
software: Column[str] = Column('software', 'text') software: Column[str] = Column("software", "text")
accepted: Column[Date] = Column('accepted', 'boolean') accepted: Column[Date] = Column("accepted", "boolean")
created: Column[Date] = Column( created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) "created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
class Whitelist(Row): class Whitelist(Row):
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) "domain", "text", primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column( created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) "created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
class DomainBan(Row): class DomainBan(Row):
table_name: str = 'domain_bans' table_name: str = "domain_bans"
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) "domain", "text", primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column("reason", "text")
note: Column[str] = Column('note', 'text') note: Column[str] = Column("note", "text")
created: Column[Date] = Column( created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) "created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
class SoftwareBan(Row): class SoftwareBan(Row):
table_name: str = 'software_bans' table_name: str = "software_bans"
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) name: Column[str] = Column("name", "text", primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column("reason", "text")
note: Column[str] = Column('note', 'text') note: Column[str] = Column("note", "text")
created: Column[Date] = Column( created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) "created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
class User(Row): class User(Row):
table_name: str = 'users' table_name: str = "users"
username: Column[str] = Column( username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False) "username", "text", primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False) hash: Column[str] = Column("hash", "text", nullable = False)
handle: Column[str] = Column('handle', 'text') handle: Column[str] = Column("handle", "text")
created: Column[Date] = Column( created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) "created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
class App(Row): class App(Row):
table_name: str = 'apps' table_name: str = "apps"
client_id: Column[str] = Column( client_id: Column[str] = Column(
'client_id', 'text', primary_key = True, unique = True, nullable = False) "client_id", "text", primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column('client_secret', 'text', nullable = False) client_secret: Column[str] = Column("client_secret", "text", nullable = False)
name: Column[str] = Column('name', 'text') name: Column[str] = Column("name", "text")
website: Column[str] = Column('website', 'text') website: Column[str] = Column("website", "text")
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False) redirect_uri: Column[str] = Column("redirect_uri", "text", nullable = False)
token: Column[str | None] = Column('token', 'text') token: Column[str | None] = Column("token", "text")
auth_code: Column[str | None] = Column('auth_code', 'text') auth_code: Column[str | None] = Column("auth_code", "text")
user: Column[str | None] = Column('user', 'text') user: Column[str | None] = Column("user", "text")
created: Column[Date] = Column( created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) "created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
accessed: Column[Date] = Column( accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp) "accessed", "timestamp", nullable = False, deserializer = deserialize_timestamp)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]: def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
data = deepcopy(self) data = deepcopy(self)
data.pop('user') data.pop("user")
data.pop('auth_code') data.pop("auth_code")
if not include_token: if not include_token:
data.pop('token') data.pop("token")
return data return data
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
ver = int(func.__name__.replace('migrate_', '')) ver = int(func.__name__.replace("migrate_", ""))
VERSIONS[ver] = func VERSIONS[ver] = func
return func return func
def migrate_0(conn: Connection) -> None: def migrate_0(conn: Connection) -> None:
conn.create_tables() conn.create_tables()
conn.put_config('schema-version', ConfigData.DEFAULT('schema-version')) conn.put_config("schema-version", ConfigData.DEFAULT("schema-version"))
@migration @migration
@ -148,11 +148,11 @@ def migrate_20240206(conn: Connection) -> None:
@migration @migration
def migrate_20240310(conn: Connection) -> None: def migrate_20240310(conn: Connection) -> None:
conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close() conn.execute("ALTER TABLE \"inboxes\" ADD COLUMN \"accepted\" BOOLEAN").close()
conn.execute('UPDATE "inboxes" SET "accepted" = true').close() conn.execute("UPDATE \"inboxes\" SET \"accepted\" = true").close()
@migration @migration
def migrate_20240625(conn: Connection) -> None: def migrate_20240625(conn: Connection) -> None:
conn.create_tables() conn.create_tables()
conn.execute('DROP TABLE "tokens"').close() conn.execute("DROP TABLE \"tokens\"").close()
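For context on the migration decorator defined above: the schema version is parsed straight out of the function name, so VERSIONS maps an integer date to each migration. A hypothetical future migration (the added column is made up, purely to illustrate the registration pattern) would follow the same shape:

    @migration
    def migrate_20250211(conn: Connection) -> None:
        # illustrative only; not a real migration in this diff
        conn.execute("ALTER TABLE \"apps\" ADD COLUMN \"scopes\" TEXT").close()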
View file
@ -16,15 +16,13 @@
-if config.approval_required -if config.approval_required
%div.section.message %div.section.message
Follow requests require approval. You will need to wait for an admin to accept or deny Follow requests require approval. You will need to wait for an admin to accept or deny your request.
your request.
-elif config.whitelist_enabled -elif config.whitelist_enabled
%fieldset.section.message %fieldset.section.message
%legend << Whitelist Enabled %legend << Whitelist Enabled
The whitelist is enabled on this instance. Ask the admin to add your instance before The whitelist is enabled on this instance. Ask the admin to add your instance before joining.
joining.
%fieldset.section %fieldset.section
%legend << Instances %legend << Instances
View file
@ -499,8 +499,7 @@ function page_login() {
async function login(event) { async function login(event) {
const values = { const values = {
username: fields.username.value.trim(), username: fields.username.value.trim(),
password: fields.password.value.trim(), password: fields.password.value.trim()
redir: fields.redir.value.trim()
} }
if (values.username === "" | values.password === "") { if (values.username === "" | values.password === "") {
@ -509,14 +508,16 @@ function page_login() {
} }
try { try {
await request("POST", "v1/login", values); application = await request("POST", "v1/login", values);
} catch (error) { } catch (error) {
toast(error); toast(error);
return; return;
} }
document.location = values.redir; const max_age = 60 * 60 * 24 * 30;
document.cookie = `user-token=${application.token};Secure;SameSite=Strict;Domain=${document.location.host};Max-Age=${max_age}`;
document.location = fields.redir.value.trim();
} }
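With this change the token returned by "v1/login" is kept client-side: the response (presumably the application row created by put_app_login earlier in this diff) is stored in a user-token cookie with a 30-day lifetime (60 × 60 × 24 × 30 = 2,592,000 seconds), and only then is the browser redirected to the value of the redir field.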
View file
@ -17,6 +17,13 @@ if TYPE_CHECKING:
from .application import Application from .application import Application
T = TypeVar("T", bound = JsonBase[Any])
HEADERS = {
"Accept": f"{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9",
"User-Agent": f"ActivityRelay/{__version__}"
}
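Assuming the MIMETYPES table defined later in this diff, the Accept header built here comes out as "application/activity+json, application/json;q=0.9".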
SUPPORTS_HS2019 = { SUPPORTS_HS2019 = {
'friendica', 'friendica',
'gotosocial', 'gotosocial',
@ -32,12 +39,6 @@ SUPPORTS_HS2019 = {
'sharkey' 'sharkey'
} }
T = TypeVar('T', bound = JsonBase[Any])
HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}'
}
class HttpClient: class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10): def __init__(self, limit: int = 100, timeout: int = 10):
@ -106,25 +107,25 @@ class HttpClient:
old_algo: bool) -> str | None: old_algo: bool) -> str | None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError("Client not open")
url = url.split("#", 1)[0] url = url.split("#", 1)[0]
if not force: if not force:
try: try:
if not (item := self.cache.get('request', url)).older_than(48): if not (item := self.cache.get("request", url)).older_than(48):
return item.value # type: ignore [no-any-return] return item.value # type: ignore [no-any-return]
except KeyError: except KeyError:
logging.verbose('No cached data for url: %s', url) logging.verbose("No cached data for url: %s", url)
headers = {} headers = {}
if sign_headers: if sign_headers:
algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019 algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
headers = self.signer.sign_headers('GET', url, algorithm = algo) headers = self.signer.sign_headers("GET", url, algorithm = algo)
logging.debug('Fetching resource: %s', url) logging.debug("Fetching resource: %s", url)
async with self._session.get(url, headers = headers) as resp: async with self._session.get(url, headers = headers) as resp:
# Not expecting a response with 202s, so just return # Not expecting a response with 202s, so just return
@ -142,7 +143,7 @@ class HttpClient:
raise HttpError(resp.status, error) raise HttpError(resp.status, error)
self.cache.set('request', url, data, 'str') self.cache.set("request", url, data, "str")
return data return data
@ -172,13 +173,13 @@ class HttpClient:
old_algo: bool = True) -> T | str | None: old_algo: bool = True) -> T | str | None:
if cls is not None and not issubclass(cls, JsonBase): if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"') raise TypeError("cls must be a sub-class of \"blib.JsonBase\"")
data = await self._get(url, sign_headers, force, old_algo) data = await self._get(url, sign_headers, force, old_algo)
if cls is not None: if cls is not None:
if data is None: if data is None:
# this shouldn't actually get raised, but keeping just in case # this shouldn't actually get raised, but keeping just in case
raise EmptyBodyError(f"GET {url}") raise EmptyBodyError(f"GET {url}")
return cls.parse(data) return cls.parse(data)
@ -188,7 +189,7 @@ class HttpClient:
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None: async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError("Client not open")
# akkoma and pleroma do not support HS2019 and other software still needs to be tested # akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance is not None and instance.software in SUPPORTS_HS2019: if instance is not None and instance.software in SUPPORTS_HS2019:
@ -210,14 +211,14 @@ class HttpClient:
mtype = message.type.value if isinstance(message.type, ObjectType) else message.type mtype = message.type.value if isinstance(message.type, ObjectType) else message.type
headers = self.signer.sign_headers( headers = self.signer.sign_headers(
'POST', "POST",
url, url,
body, body,
headers = {'Content-Type': 'application/activity+json'}, headers = {"Content-Type": "application/activity+json"},
algorithm = algorithm algorithm = algorithm
) )
logging.verbose('Sending "%s" to %s', mtype, url) logging.verbose("Sending \"%s\" to %s", mtype, url)
async with self._session.post(url, headers = headers, data = body) as resp: async with self._session.post(url, headers = headers, data = body) as resp:
if resp.status not in (200, 202): if resp.status not in (200, 202):
@ -231,10 +232,10 @@ class HttpClient:
async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo: async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:
nodeinfo_url = None nodeinfo_url = None
wk_nodeinfo = await self.get( wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force f"https://{domain}/.well-known/nodeinfo", False, WellKnownNodeinfo, force
) )
for version in ('20', '21'): for version in ("20", "21"):
try: try:
nodeinfo_url = wk_nodeinfo.get_url(version) nodeinfo_url = wk_nodeinfo.get_url(version)
@ -242,7 +243,7 @@ class HttpClient:
pass pass
if nodeinfo_url is None: if nodeinfo_url is None:
raise ValueError(f'Failed to fetch nodeinfo url for {domain}') raise ValueError(f"Failed to fetch nodeinfo url for {domain}")
return await self.get(nodeinfo_url, False, Nodeinfo, force) return await self.get(nodeinfo_url, False, Nodeinfo, force)
View file
@ -54,7 +54,7 @@ class LogLevel(IntEnum):
except ValueError: except ValueError:
pass pass
raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}') raise AttributeError(f"Invalid enum property for {cls.__name__}: {data}")
def get_level() -> LogLevel: def get_level() -> LogLevel:
@ -80,7 +80,7 @@ critical: LoggingMethod = logging.critical
try: try:
env_log_file: Path | None = Path(os.environ['LOG_FILE']).expanduser().resolve() env_log_file: Path | None = Path(os.environ["LOG_FILE"]).expanduser().resolve()
except KeyError: except KeyError:
env_log_file = None env_log_file = None
@ -90,16 +90,16 @@ handlers: list[Any] = [logging.StreamHandler()]
if env_log_file: if env_log_file:
handlers.append(logging.FileHandler(env_log_file)) handlers.append(logging.FileHandler(env_log_file))
if os.environ.get('IS_SYSTEMD'): if os.environ.get("IS_SYSTEMD"):
logging_format = '%(levelname)s: %(message)s' logging_format = "%(levelname)s: %(message)s"
else: else:
logging_format = '[%(asctime)s] %(levelname)s: %(message)s' logging_format = "[%(asctime)s] %(levelname)s: %(message)s"
logging.addLevelName(LogLevel.VERBOSE, 'VERBOSE') logging.addLevelName(LogLevel.VERBOSE, "VERBOSE")
logging.basicConfig( logging.basicConfig(
level = LogLevel.INFO, level = LogLevel.INFO,
format = logging_format, format = logging_format,
datefmt = '%Y-%m-%d %H:%M:%S', datefmt = "%Y-%m-%d %H:%M:%S",
handlers = handlers handlers = handlers
) )
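As an example of the environment handling above: running with LOG_FILE=~/relay.log adds a FileHandler alongside the stream handler, while setting IS_SYSTEMD to any non-empty value drops the timestamp from the log format, presumably because journald records its own.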
File diff suppressed because it is too large
View file
@ -16,63 +16,53 @@ if TYPE_CHECKING:
from .application import Application from .application import Application
T = TypeVar('T') T = TypeVar("T")
ResponseType = TypedDict('ResponseType', {
'status': int, IS_DOCKER = bool(os.environ.get("DOCKER_RUNNING"))
'headers': dict[str, Any] | None, IS_WINDOWS = platform.system() == "Windows"
'content_type': str,
'body': bytes | None, ResponseType = TypedDict("ResponseType", {
'text': str | None "status": int,
"headers": dict[str, Any] | None,
"content_type": str,
"body": bytes | None,
"text": str | None
}) })
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
IS_WINDOWS = platform.system() == 'Windows'
MIMETYPES = { MIMETYPES = {
'activity': 'application/activity+json', "activity": "application/activity+json",
'css': 'text/css', "css": "text/css",
'html': 'text/html', "html": "text/html",
'json': 'application/json', "json": "application/json",
'text': 'text/plain', "text": "text/plain",
'webmanifest': 'application/manifest+json' "webmanifest": "application/manifest+json"
} }
ACTOR_FORMATS = { ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor', "mastodon": "https://{domain}/actor",
'akkoma': 'https://{domain}/relay', "akkoma": "https://{domain}/relay",
'pleroma': 'https://{domain}/relay' "pleroma": "https://{domain}/relay"
} }
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
JSON_PATHS: tuple[str, ...] = ( JSON_PATHS: tuple[str, ...] = (
'/api/v1', "/api/v1",
'/actor', "/actor",
'/inbox', "/inbox",
'/outbox', "/outbox",
'/following', "/following",
'/followers', "/followers",
'/.well-known', "/.well-known",
'/nodeinfo', "/nodeinfo",
'/oauth/token', "/oauth/token",
'/oauth/revoke' "/oauth/revoke"
) )
TOKEN_PATHS: tuple[str, ...] = ( TOKEN_PATHS: tuple[str, ...] = (
'/logout', "/logout",
'/admin', "/admin",
'/api', "/api",
'/oauth/authorize', "/oauth/authorize",
'/oauth/revoke' "/oauth/revoke"
) )
@ -80,7 +70,7 @@ def get_app() -> Application:
from .application import Application from .application import Application
if not Application.DEFAULT: if not Application.DEFAULT:
raise ValueError('No default application set') raise ValueError("No default application set")
return Application.DEFAULT return Application.DEFAULT
@ -136,23 +126,23 @@ class Message(aputils.Message):
approves: bool = False) -> Self: approves: bool = False) -> Self:
return cls.new(aputils.ObjectType.APPLICATION, { return cls.new(aputils.ObjectType.APPLICATION, {
'id': f'https://{host}/actor', "id": f"https://{host}/actor",
'preferredUsername': 'relay', "preferredUsername": "relay",
'name': 'ActivityRelay', "name": "ActivityRelay",
'summary': description or 'ActivityRelay bot', "summary": description or "ActivityRelay bot",
'manuallyApprovesFollowers': approves, "manuallyApprovesFollowers": approves,
'followers': f'https://{host}/followers', "followers": f"https://{host}/followers",
'following': f'https://{host}/following', "following": f"https://{host}/following",
'inbox': f'https://{host}/inbox', "inbox": f"https://{host}/inbox",
'outbox': f'https://{host}/outbox', "outbox": f"https://{host}/outbox",
'url': f'https://{host}/', "url": f"https://{host}/",
'endpoints': { "endpoints": {
'sharedInbox': f'https://{host}/inbox' "sharedInbox": f"https://{host}/inbox"
}, },
'publicKey': { "publicKey": {
'id': f'https://{host}/actor#main-key', "id": f"https://{host}/actor#main-key",
'owner': f'https://{host}/actor', "owner": f"https://{host}/actor",
'publicKeyPem': pubkey "publicKeyPem": pubkey
} }
}) })
@ -160,44 +150,44 @@ class Message(aputils.Message):
@classmethod @classmethod
def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self: def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self:
return cls.new(aputils.ObjectType.ANNOUNCE, { return cls.new(aputils.ObjectType.ANNOUNCE, {
'id': f'https://{host}/activities/{uuid4()}', "id": f"https://{host}/activities/{uuid4()}",
'to': [f'https://{host}/followers'], "to": [f"https://{host}/followers"],
'actor': f'https://{host}/actor', "actor": f"https://{host}/actor",
'object': obj "object": obj
}) })
@classmethod @classmethod
def new_follow(cls: type[Self], host: str, actor: str) -> Self: def new_follow(cls: type[Self], host: str, actor: str) -> Self:
return cls.new(aputils.ObjectType.FOLLOW, { return cls.new(aputils.ObjectType.FOLLOW, {
'id': f'https://{host}/activities/{uuid4()}', "id": f"https://{host}/activities/{uuid4()}",
'to': [actor], "to": [actor],
'object': actor, "object": actor,
'actor': f'https://{host}/actor' "actor": f"https://{host}/actor"
}) })
@classmethod @classmethod
def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self: def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self:
return cls.new(aputils.ObjectType.UNDO, { return cls.new(aputils.ObjectType.UNDO, {
'id': f'https://{host}/activities/{uuid4()}', "id": f"https://{host}/activities/{uuid4()}",
'to': [actor], "to": [actor],
'actor': f'https://{host}/actor', "actor": f"https://{host}/actor",
'object': follow "object": follow
}) })
@classmethod @classmethod
def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self: def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self:
return cls.new(aputils.ObjectType.ACCEPT if accept else aputils.ObjectType.REJECT, { return cls.new(aputils.ObjectType.ACCEPT if accept else aputils.ObjectType.REJECT, {
'id': f'https://{host}/activities/{uuid4()}', "id": f"https://{host}/activities/{uuid4()}",
'to': [actor], "to": [actor],
'actor': f'https://{host}/actor', "actor": f"https://{host}/actor",
'object': { "object": {
'id': followid, "id": followid,
'type': 'Follow', "type": "Follow",
'object': f'https://{host}/actor', "object": f"https://{host}/actor",
'actor': actor "actor": actor
} }
}) })
@ -210,35 +200,35 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new(cls: type[Self], def new(cls: type[Self],
body: str | bytes | dict[str, Any] | Sequence[Any] = '', body: str | bytes | dict[str, Any] | Sequence[Any] = "",
status: int = 200, status: int = 200,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
ctype: str = 'text') -> Self: ctype: str = "text") -> Self:
kwargs: ResponseType = { kwargs: ResponseType = {
'status': status, "status": status,
'headers': headers, "headers": headers,
'content_type': MIMETYPES[ctype], "content_type": MIMETYPES[ctype],
'body': None, "body": None,
'text': None "text": None
} }
if isinstance(body, str): if isinstance(body, str):
kwargs['text'] = body kwargs["text"] = body
elif isinstance(body, bytes): elif isinstance(body, bytes):
kwargs['body'] = body kwargs["body"] = body
elif isinstance(body, (dict, Sequence)): elif isinstance(body, (dict, Sequence)):
kwargs['text'] = json.dumps(body, cls = JsonEncoder) kwargs["text"] = json.dumps(body, cls = JsonEncoder)
return cls(**kwargs) return cls(**kwargs)
@classmethod @classmethod
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self: def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>' body = f"Redirect to <a href=\"{path}\">{path}</a>"
return cls.new(body, status, {'Location': path}, ctype = 'html') return cls.new(body, status, {"Location": path}, ctype = "html")
@classmethod @classmethod
@ -256,9 +246,9 @@ class Response(AiohttpResponse):
@property @property
def location(self) -> str: def location(self) -> str:
return self.headers.get('Location', '') return self.headers.get("Location", "")
@location.setter @location.setter
def location(self, value: str) -> None: def location(self, value: str) -> None:
self.headers['Location'] = value self.headers["Location"] = value
View file
@ -11,12 +11,12 @@ if typing.TYPE_CHECKING:
from .views.activitypub import InboxData from .views.activitypub import InboxData
def actor_type_check(actor: Message, software: str | None) -> bool: def is_application(actor: Message, software: str | None) -> bool:
if actor.type == 'Application': if actor.type == "Application":
return True return True
# akkoma (< 3.6.0) and pleroma use Person for the actor type # akkoma (< 3.6.0) and pleroma use Person for the actor type
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': if software in {"akkoma", "pleroma"} and actor.id == f"https://{actor.domain}/relay":
return True return True
return False return False
@ -24,38 +24,38 @@ def actor_type_check(actor: Message, software: str | None) -> bool:
async def handle_relay(app: Application, data: InboxData, conn: Connection) -> None: async def handle_relay(app: Application, data: InboxData, conn: Connection) -> None:
try: try:
app.cache.get('handle-relay', data.message.object_id) app.cache.get("handle-relay", data.message.object_id)
logging.verbose('already relayed %s', data.message.object_id) logging.verbose("already relayed %s", data.message.object_id)
return return
except KeyError: except KeyError:
pass pass
message = Message.new_announce(app.config.domain, data.message.object_id) message = Message.new_announce(app.config.domain, data.message.object_id)
logging.debug('>> relay: %s', message) logging.debug(">> relay: %s", message)
for instance in conn.distill_inboxes(data.message): for instance in conn.distill_inboxes(data.message):
app.push_message(instance.inbox, message, instance) app.push_message(instance.inbox, message, instance)
app.cache.set('handle-relay', data.message.object_id, message.id, 'str') app.cache.set("handle-relay", data.message.object_id, message.id, "str")
async def handle_forward(app: Application, data: InboxData, conn: Connection) -> None: async def handle_forward(app: Application, data: InboxData, conn: Connection) -> None:
try: try:
app.cache.get('handle-relay', data.message.id) app.cache.get("handle-relay", data.message.id)
logging.verbose('already forwarded %s', data.message.id) logging.verbose("already forwarded %s", data.message.id)
return return
except KeyError: except KeyError:
pass pass
message = Message.new_announce(app.config.domain, data.message) message = Message.new_announce(app.config.domain, data.message)
logging.debug('>> forward: %s', message) logging.debug(">> forward: %s", message)
for instance in conn.distill_inboxes(data.message): for instance in conn.distill_inboxes(data.message):
app.push_message(instance.inbox, data.message, instance) app.push_message(instance.inbox, data.message, instance)
app.cache.set('handle-relay', data.message.id, message.id, 'str') app.cache.set("handle-relay", data.message.id, message.id, "str")
async def handle_follow(app: Application, data: InboxData, conn: Connection) -> None: async def handle_follow(app: Application, data: InboxData, conn: Connection) -> None:
@ -65,10 +65,8 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
# reject if software used by actor is banned # reject if software used by actor is banned
if software and conn.get_software_ban(software): if software and conn.get_software_ban(software):
logging.verbose('Rejected banned actor: %s', data.actor.id)
app.push_message( app.push_message(
data.actor.shared_inbox, data.shared_inbox,
Message.new_response( Message.new_response(
host = app.config.domain, host = app.config.domain,
actor = data.actor.id, actor = data.actor.id,
@ -79,7 +77,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
) )
logging.verbose( logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s', "Rejected follow from actor for using specific software: actor=%s, software=%s",
data.actor.id, data.actor.id,
software software
) )
@ -87,11 +85,11 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
return return
# reject if the actor is not an instance actor # reject if the actor is not an instance actor
if actor_type_check(data.actor, software): if not is_application(data.actor, software):
logging.verbose('Non-application actor tried to follow: %s', data.actor.id) logging.verbose("Non-application actor tried to follow: %s", data.actor.id)
app.push_message( app.push_message(
data.actor.shared_inbox, data.shared_inbox,
Message.new_response( Message.new_response(
host = app.config.domain, host = app.config.domain,
actor = data.actor.id, actor = data.actor.id,
@ -106,12 +104,12 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
if not conn.get_domain_whitelist(data.actor.domain): if not conn.get_domain_whitelist(data.actor.domain):
# add request if approval-required is enabled # add request if approval-required is enabled
if config.approval_required: if config.approval_required:
logging.verbose('New follow request fromm actor: %s', data.actor.id) logging.verbose("New follow request from actor: %s", data.actor.id)
with conn.transaction(): with conn.transaction():
data.instance = conn.put_inbox( data.instance = conn.put_inbox(
domain = data.actor.domain, domain = data.actor.domain,
inbox = data.actor.shared_inbox, inbox = data.shared_inbox,
actor = data.actor.id, actor = data.actor.id,
followid = data.message.id, followid = data.message.id,
software = software, software = software,
@ -120,12 +118,12 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
return return
# reject if the actor isn't whitelisted while the whiltelist is enabled # reject if the actor isn't whitelisted while the whitelist is enabled
if config.whitelist_enabled: if config.whitelist_enabled:
logging.verbose('Rejected actor for not being in the whitelist: %s', data.actor.id) logging.verbose("Rejected actor for not being in the whitelist: %s", data.actor.id)
app.push_message( app.push_message(
data.actor.shared_inbox, data.shared_inbox,
Message.new_response( Message.new_response(
host = app.config.domain, host = app.config.domain,
actor = data.actor.id, actor = data.actor.id,
@ -140,7 +138,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
with conn.transaction(): with conn.transaction():
data.instance = conn.put_inbox( data.instance = conn.put_inbox(
domain = data.actor.domain, domain = data.actor.domain,
inbox = data.actor.shared_inbox, inbox = data.shared_inbox,
actor = data.actor.id, actor = data.actor.id,
followid = data.message.id, followid = data.message.id,
software = software, software = software,
@ -148,7 +146,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
) )
app.push_message( app.push_message(
data.actor.shared_inbox, data.shared_inbox,
Message.new_response( Message.new_response(
host = app.config.domain, host = app.config.domain,
actor = data.actor.id, actor = data.actor.id,
@ -160,9 +158,9 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
# Are Akkoma and Pleroma the only two that expect a follow back? # Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now # Ignoring only Mastodon for now
if software != 'mastodon': if software != "mastodon":
app.push_message( app.push_message(
data.actor.shared_inbox, data.shared_inbox,
Message.new_follow( Message.new_follow(
host = app.config.domain, host = app.config.domain,
actor = data.actor.id actor = data.actor.id
@ -172,8 +170,8 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
async def handle_undo(app: Application, data: InboxData, conn: Connection) -> None: async def handle_undo(app: Application, data: InboxData, conn: Connection) -> None:
if data.message.object['type'] != 'Follow': if data.message.object["type"] != "Follow":
# forwarding deletes does not work, so don't bother # forwarding deletes does not work, so don't bother
# await handle_forward(app, data, conn) # await handle_forward(app, data, conn)
return return
@ -187,13 +185,13 @@ async def handle_undo(app: Application, data: InboxData, conn: Connection) -> No
with conn.transaction(): with conn.transaction():
if not conn.del_inbox(data.actor.id): if not conn.del_inbox(data.actor.id):
logging.verbose( logging.verbose(
'Failed to delete "%s" with follow ID "%s"', "Failed to delete \"%s\" with follow ID \"%s\"",
data.actor.id, data.actor.id,
data.message.object_id data.message.object_id
) )
app.push_message( app.push_message(
data.actor.shared_inbox, data.shared_inbox,
Message.new_unfollow( Message.new_unfollow(
host = app.config.domain, host = app.config.domain,
actor = data.actor.id, actor = data.actor.id,
@ -204,19 +202,19 @@ async def handle_undo(app: Application, data: InboxData, conn: Connection) -> No
processors = { processors = {
'Announce': handle_relay, "Announce": handle_relay,
'Create': handle_relay, "Create": handle_relay,
'Delete': handle_forward, "Delete": handle_forward,
'Follow': handle_follow, "Follow": handle_follow,
'Undo': handle_undo, "Undo": handle_undo,
'Update': handle_forward, "Update": handle_forward,
} }
async def run_processor(data: InboxData) -> None: async def run_processor(data: InboxData) -> None:
if data.message.type not in processors: if data.message.type not in processors:
logging.verbose( logging.verbose(
'Message type "%s" from actor cannot be handled: %s', "Message type \"%s\" from actor cannot be handled: %s",
data.message.type, data.message.type,
data.actor.id data.actor.id
) )
@ -242,5 +240,5 @@ async def run_processor(data: InboxData) -> None:
actor = data.actor.id actor = data.actor.id
) )
logging.verbose('New "%s" from actor: %s', data.message.type, data.actor.id) logging.verbose("New \"%s\" from actor: %s", data.message.type, data.actor.id)
await processors[data.message.type](app, data, conn) await processors[data.message.type](app, data, conn)
0
relay/py.typed Normal file
View file

View file
@ -5,7 +5,7 @@ import textwrap
from aiohttp.web import Request from aiohttp.web import Request
from blib import File from blib import File
from collections.abc import Callable from collections.abc import Callable
from hamlish_jinja import HamlishExtension from hamlish import HamlishExtension, HamlishSettings
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension from jinja2.ext import Extension
from jinja2.nodes import CallBlock, Node from jinja2.nodes import CallBlock, Node
@ -21,6 +21,7 @@ if TYPE_CHECKING:
class Template(Environment): class Template(Environment):
render_markdown: Callable[[str], str] render_markdown: Callable[[str], str]
hamlish: HamlishSettings
def __init__(self, app: Application): def __init__(self, app: Application):
@ -33,14 +34,12 @@ class Template(Environment):
MarkdownExtension MarkdownExtension
], ],
loader = FileSystemLoader([ loader = FileSystemLoader([
File.from_resource('relay', 'frontend'), File.from_resource("relay", "frontend"),
app.config.path.parent.joinpath('template') app.config.path.parent.joinpath("template")
]) ])
) )
self.app = app self.app = app
self.hamlish_enable_div_shortcut = True
self.hamlish_mode = 'indented'
def render(self, path: str, request: Request, **context: Any) -> str: def render(self, path: str, request: Request, **context: Any) -> str:
@ -48,10 +47,10 @@ class Template(Environment):
config = conn.get_config_all() config = conn.get_config_all()
new_context = { new_context = {
'request': request, "request": request,
'domain': self.app.config.domain, "domain": self.app.config.domain,
'version': __version__, "version": __version__,
'config': config, "config": config,
**(context or {}) **(context or {})
} }
@ -59,11 +58,11 @@ class Template(Environment):
class MarkdownExtension(Extension): class MarkdownExtension(Extension):
tags = {'markdown'} tags = {"markdown"}
extensions = ( extensions = (
'attr_list', "attr_list",
'smarty', "smarty",
'tables' "tables"
) )
@ -78,14 +77,14 @@ class MarkdownExtension(Extension):
def parse(self, parser: Parser) -> Node | list[Node]: def parse(self, parser: Parser) -> Node | list[Node]:
lineno = next(parser.stream).lineno lineno = next(parser.stream).lineno
body = parser.parse_statements( body = parser.parse_statements(
('name:endmarkdown',), ("name:endmarkdown",),
drop_needle = True drop_needle = True
) )
output = CallBlock(self.call_method('_render_markdown'), [], [], body) output = CallBlock(self.call_method("_render_markdown"), [], [], body)
return output.set_lineno(lineno) return output.set_lineno(lineno)
def _render_markdown(self, caller: Callable[[], str] | str) -> str: def _render_markdown(self, caller: Callable[[], str] | str) -> str:
text = caller if isinstance(caller, str) else caller() text = caller if isinstance(caller, str) else caller()
return self._markdown.convert(textwrap.dedent(text.strip('\n'))) return self._markdown.convert(textwrap.dedent(text.strip("\n")))
View file
@ -44,66 +44,86 @@ class InboxData:
signer: Signer | None = None signer: Signer | None = None
try: try:
signature = Signature.parse(request.headers['signature']) signature = Signature.parse(request.headers["signature"])
except KeyError: except KeyError:
logging.verbose('Missing signature header') logging.verbose("Missing signature header")
raise HttpError(400, 'missing signature header') raise HttpError(400, "missing signature header")
try: try:
message = await request.json(loads = Message.parse) message = await request.json(loads = Message.parse)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
logging.verbose('Failed to parse message from actor: %s', signature.keyid) logging.verbose("Failed to parse message from actor: %s", signature.keyid)
raise HttpError(400, 'failed to parse message') raise HttpError(400, "failed to parse message")
if message is None: if message is None:
logging.verbose('empty message') logging.verbose("empty message")
raise HttpError(400, 'missing message') raise HttpError(400, "missing message")
if 'actor' not in message: if "actor" not in message:
logging.verbose('actor not in message') logging.verbose("actor not in message")
raise HttpError(400, 'no actor in message') raise HttpError(400, "no actor in message")
actor_id: str
try: try:
actor = await app.client.get(signature.keyid, True, Message) actor_id = message.actor_id
except AttributeError:
actor_id = signature.keyid
try:
actor = await app.client.get(actor_id, True, Message)
except HttpError as e: except HttpError as e:
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if message.type == 'Delete': if message.type == "Delete":
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose("Instance sent a delete which cannot be handled")
raise HttpError(202, '') raise HttpError(202, "")
logging.verbose('Failed to fetch actor: %s', signature.keyid) logging.verbose("Failed to fetch actor: %s", signature.keyid)
logging.debug('HTTP Status %i: %s', e.status, e.message) logging.debug("HTTP Status %i: %s", e.status, e.message)
raise HttpError(400, 'failed to fetch actor') raise HttpError(400, "failed to fetch actor")
except ClientConnectorError as e: except ClientConnectorError as e:
logging.warning('Error when trying to fetch actor: %s, %s', signature.keyid, str(e)) logging.warning("Error when trying to fetch actor: %s, %s", signature.keyid, str(e))
raise HttpError(400, 'failed to fetch actor') raise HttpError(400, "failed to fetch actor")
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
raise HttpError(500, 'unexpected error when fetching actor') raise HttpError(500, "unexpected error when fetching actor")
try: try:
signer = actor.signer signer = actor.signer
except KeyError: except KeyError:
logging.verbose('Actor missing public key: %s', signature.keyid) logging.verbose("Actor missing public key: %s", signature.keyid)
raise HttpError(400, 'actor missing public key') raise HttpError(400, "actor missing public key")
try: try:
await signer.validate_request_async(request) await signer.validate_request_async(request)
except SignatureFailureError as e: except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', actor.id, e) logging.verbose("signature validation failed for \"%s\": %s", actor.id, e)
raise HttpError(401, str(e)) raise HttpError(401, str(e))
return cls(signature, message, actor, signer, None) return cls(signature, message, actor, signer, None)
@property
def shared_inbox(self) -> str:
if self.actor is None:
raise AttributeError("Actor not set yet")
try:
return self.actor.shared_inbox
except KeyError:
return self.actor.inbox # type: ignore[no-any-return]
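The new shared_inbox property prefers the actor's shared inbox and only falls back to the per-actor inbox when none is advertised. A standalone sketch of the same fallback over a plain dict actor (hypothetical helper, not the aputils Message type used above):

from typing import Any

def pick_inbox(actor: dict[str, Any]) -> str:
    # Prefer the shared inbox when the actor advertises one; otherwise
    # deliver to the actor's own inbox.
    endpoints = actor.get("endpoints") or {}
    return endpoints.get("sharedInbox") or actor["inbox"]

print(pick_inbox({"inbox": "https://example.com/users/alice/inbox"}))
print(pick_inbox({
    "inbox": "https://example.com/users/alice/inbox",
    "endpoints": {"sharedInbox": "https://example.com/inbox"}
}))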
@register_route(HttpMethod.GET, "/actor", "/inbox") @register_route(HttpMethod.GET, "/actor", "/inbox")
async def handle_actor(app: Application, request: Request) -> Response: async def handle_actor(app: Application, request: Request) -> Response:
with app.database.session(False) as conn: with app.database.session(False) as conn:
@ -124,34 +144,34 @@ async def handle_inbox(app: Application, request: Request) -> Response:
data = await InboxData.parse(app, request) data = await InboxData.parse(app, request)
with app.database.session() as conn: with app.database.session() as conn:
data.instance = conn.get_inbox(data.actor.shared_inbox) data.instance = conn.get_inbox(data.shared_inbox)
# reject if actor is banned # reject if actor is banned
if conn.get_domain_ban(data.actor.domain): if conn.get_domain_ban(data.actor.domain):
logging.verbose('Ignored request from banned actor: %s', data.actor.id) logging.verbose("Ignored request from banned actor: %s", data.actor.id)
raise HttpError(403, 'access denied') raise HttpError(403, "access denied")
# reject if activity type isn't 'Follow' and the actor isn't following # reject if activity type isn't "Follow" and the actor isn't following
if data.message.type != 'Follow' and not data.instance: if data.message.type != "Follow" and not data.instance:
logging.verbose( logging.verbose(
'Rejected actor for trying to post while not following: %s', "Rejected actor for trying to post while not following: %s",
data.actor.id data.actor.id
) )
raise HttpError(401, 'access denied') raise HttpError(401, "access denied")
logging.debug('>> payload %s', data.message.to_json(4)) logging.debug(">> payload %s", data.message.to_json(4))
await run_processor(data) await run_processor(data)
return Response.new(status = 202) return Response.new(status = 202)
@register_route(HttpMethod.GET, '/outbox') @register_route(HttpMethod.GET, "/outbox")
async def handle_outbox(app: Application, request: Request) -> Response: async def handle_outbox(app: Application, request: Request) -> Response:
msg = aputils.Message.new( msg = aputils.Message.new(
aputils.ObjectType.ORDERED_COLLECTION, aputils.ObjectType.ORDERED_COLLECTION,
{ {
"id": f'https://{app.config.domain}/outbox', "id": f"https://{app.config.domain}/outbox",
"totalItems": 0, "totalItems": 0,
"orderedItems": [] "orderedItems": []
} }
@ -160,15 +180,15 @@ async def handle_outbox(app: Application, request: Request) -> Response:
return Response.new(msg, ctype = "activity") return Response.new(msg, ctype = "activity")
@register_route(HttpMethod.GET, '/following', '/followers') @register_route(HttpMethod.GET, "/following", "/followers")
async def handle_follow(app: Application, request: Request) -> Response: async def handle_follow(app: Application, request: Request) -> Response:
with app.database.session(False) as s: with app.database.session(False) as s:
inboxes = [row['actor'] for row in s.get_inboxes()] inboxes = [row["actor"] for row in s.get_inboxes()]
msg = aputils.Message.new( msg = aputils.Message.new(
aputils.ObjectType.COLLECTION, aputils.ObjectType.COLLECTION,
{ {
"id": f'https://{app.config.domain}{request.path}', "id": f"https://{app.config.domain}{request.path}",
"totalItems": len(inboxes), "totalItems": len(inboxes),
"items": inboxes "items": inboxes
} }
@ -177,21 +197,21 @@ async def handle_follow(app: Application, request: Request) -> Response:
return Response.new(msg, ctype = "activity") return Response.new(msg, ctype = "activity")
@register_route(HttpMethod.GET, '/.well-known/webfinger') @register_route(HttpMethod.GET, "/.well-known/webfinger")
async def get(app: Application, request: Request) -> Response: async def get(app: Application, request: Request) -> Response:
try: try:
subject = request.query['resource'] subject = request.query["resource"]
except KeyError: except KeyError:
raise HttpError(400, 'missing "resource" query key') raise HttpError(400, "missing \"resource\" query key")
if subject != f'acct:relay@{app.config.domain}': if subject != f"acct:relay@{app.config.domain}":
raise HttpError(404, 'user not found') raise HttpError(404, "user not found")
data = aputils.Webfinger.new( data = aputils.Webfinger.new(
handle = 'relay', handle = "relay",
domain = app.config.domain, domain = app.config.domain,
actor = app.config.actor actor = app.config.actor
) )
return Response.new(data, ctype = 'json') return Response.new(data, ctype = "json")
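The webfinger handler only answers for the relay's own acct: URI. A hedged client-side sketch of the lookup, with a placeholder domain:

import json
import urllib.parse
import urllib.request

domain = "relay.example.com"  # placeholder; use the relay's configured domain
params = urllib.parse.urlencode({"resource": f"acct:relay@{domain}"})

with urllib.request.urlopen(f"https://{domain}/.well-known/webfinger?{params}") as resp:
    print(json.load(resp))  # webfinger document linking to the relay actor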
View file
@ -181,7 +181,16 @@ async def handle_login(
application = s.put_app_login(user) application = s.put_app_login(user)
return objects.Application.from_row(application) return objects.Application(
application.client_id,
application.client_secret,
application.name,
application.website,
application.redirect_uri,
application.token,
application.created,
application.accessed
)
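Building the response object field by field (instead of the earlier from_row call) keeps the mapping between database columns and API fields explicit. A generic illustration of the pattern with a hypothetical stand-in class, not the real objects.Application:

from dataclasses import dataclass

@dataclass
class AppOut:
    # Hypothetical stand-in; field names are illustrative only.
    client_id: str
    client_secret: str
    name: str

row = {"client_id": "abc123", "client_secret": "s3cr3t", "name": "relay-web", "user": "admin"}

# Naming each field keeps extra row columns (such as "user") out of the API object.
print(AppOut(row["client_id"], row["client_secret"], row["name"]))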
@Route(HttpMethod.GET, "/api/v1/app", "Application", True) @Route(HttpMethod.GET, "/api/v1/app", "Application", True)
@ -343,7 +352,7 @@ async def handle_instance_add(
with app.database.session(False) as s: with app.database.session(False) as s:
if s.get_inbox(domain) is not None: if s.get_inbox(domain) is not None:
raise HttpError(404, 'Instance already in database') raise HttpError(404, "Instance already in database")
if inbox is None: if inbox is None:
try: try:
@ -396,7 +405,7 @@ async def handle_instance_update(
with app.database.session(False) as s: with app.database.session(False) as s:
if (instance := s.get_inbox(domain)) is None: if (instance := s.get_inbox(domain)) is None:
raise HttpError(404, 'Instance with domain not found') raise HttpError(404, "Instance with domain not found")
row = s.put_inbox( row = s.put_inbox(
instance.domain, instance.domain,
View file
@ -31,11 +31,11 @@ if TYPE_CHECKING:
METHODS: dict[str, Method] = {} METHODS: dict[str, Method] = {}
ROUTES: list[tuple[str, str, HandlerCallback]] = [] ROUTES: list[tuple[str, str, HandlerCallback]] = []
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' DEFAULT_REDIRECT: str = "urn:ietf:wg:oauth:2.0:oob"
ALLOWED_HEADERS: set[str] = { ALLOWED_HEADERS: set[str] = {
'accept', "accept",
'authorization', "authorization",
'content-type' "content-type"
} }
@ -100,14 +100,14 @@ class Method:
return_type = get_origin(return_type) return_type = get_origin(return_type)
if not issubclass(return_type, (Response, ApiObject, list)): if not issubclass(return_type, (Response, ApiObject, list)):
raise ValueError(f"Invalid return type '{return_type.__name__}' for {func.__name__}") raise ValueError(f"Invalid return type \"{return_type.__name__}\" for {func.__name__}")
args = {key: value for key, value in inspect.signature(func).parameters.items()} args = {key: value for key, value in inspect.signature(func).parameters.items()}
docstring, paramdocs = parse_docstring(func.__doc__ or "") docstring, paramdocs = parse_docstring(func.__doc__ or "")
params = [] params = []
if func.__doc__ is None: if func.__doc__ is None:
logging.warning(f"Missing docstring for '{func.__name__}'") logging.warning(f"Missing docstring for \"{func.__name__}\"")
for key, value in args.items(): for key, value in args.items():
types: list[type[Any]] = [] types: list[type[Any]] = []
@ -134,7 +134,7 @@ class Method:
)) ))
if not paramdocs.get(key): if not paramdocs.get(key):
logging.warning(f"Missing docs for '{key}' parameter in '{func.__name__}'") logging.warning(f"Missing docs for \"{key}\" parameter in \"{func.__name__}\"")
rtype = annotations.get("return") or type(None) rtype = annotations.get("return") or type(None)
return cls(func.__name__, category, docstring, method, path, rtype, tuple(params)) return cls(func.__name__, category, docstring, method, path, rtype, tuple(params))
@ -222,7 +222,7 @@ class Route:
if request.method != "OPTIONS" and self.require_token: if request.method != "OPTIONS" and self.require_token:
if (auth := request.headers.getone("Authorization", None)) is None: if (auth := request.headers.getone("Authorization", None)) is None:
raise HttpError(401, 'Missing token') raise HttpError(401, "Missing token")
try: try:
authtype, code = auth.split(" ", 1) authtype, code = auth.split(" ", 1)
@ -245,15 +245,15 @@ class Route:
request["application"] = application request["application"] = application
if request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}: if request.content_type in {"application/x-www-form-urlencoded", "multipart/form-data"}:
post_data = {key: value for key, value in (await request.post()).items()} post_data = {key: value for key, value in (await request.post()).items()}
elif request.content_type == 'application/json': elif request.content_type == "application/json":
try: try:
post_data = await request.json() post_data = await request.json()
except JSONDecodeError: except JSONDecodeError:
raise HttpError(400, 'Invalid JSON data') raise HttpError(400, "Invalid JSON data")
else: else:
post_data = {key: str(value) for key, value in request.query.items()} post_data = {key: str(value) for key, value in request.query.items()}
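From a client's point of view this means the same endpoint accepts form data, JSON, or query arguments. A rough sketch of a JSON call; the path, payload, and token are placeholders, not values from this diff:

import json
import urllib.request

req = urllib.request.Request(
    "https://relay.example.com/api/v1/instance",
    data = json.dumps({"domain": "example.net"}).encode(),
    headers = {
        "Authorization": "Bearer TOKEN",
        "Content-Type": "application/json"
    },
    method = "POST"
)

with urllib.request.urlopen(req) as resp:
    print(json.load(resp))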
@ -262,7 +262,7 @@ class Route:
response = await self.handler(get_app(), request, **post_data) response = await self.handler(get_app(), request, **post_data)
except HttpError as error: except HttpError as error:
return Response.new({'error': error.message}, error.status, ctype = "json") return Response.new({"error": error.message}, error.status, ctype = "json")
headers = { headers = {
"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Origin": "*",
View file
@ -19,7 +19,7 @@ if TYPE_CHECKING:
async def handle_home(app: Application, request: Request) -> Response: async def handle_home(app: Application, request: Request) -> Response:
with app.database.session() as conn: with app.database.session() as conn:
context: dict[str, Any] = { context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes()) "instances": tuple(conn.get_inboxes())
} }
return Response.new_template(200, "page/home.haml", request, context) return Response.new_template(200, "page/home.haml", request, context)
@ -34,28 +34,28 @@ async def handle_api_doc(app: Application, request: Request) -> Response:
return Response.new_template(200, "page/docs.haml", request, context) return Response.new_template(200, "page/docs.haml", request, context)
@register_route(HttpMethod.GET, '/login') @register_route(HttpMethod.GET, "/login")
async def handle_login(app: Application, request: Request) -> Response: async def handle_login(app: Application, request: Request) -> Response:
context = {"redir": unquote(request.query.get("redir", "/"))} context = {"redir": unquote(request.query.get("redir", "/"))}
return Response.new_template(200, "page/login.haml", request, context) return Response.new_template(200, "page/login.haml", request, context)
@register_route(HttpMethod.GET, '/logout') @register_route(HttpMethod.GET, "/logout")
async def handle_logout(app: Application, request: Request) -> Response: async def handle_logout(app: Application, request: Request) -> Response:
with app.database.session(True) as conn: with app.database.session(True) as conn:
conn.del_app(request['token'].client_id, request['token'].client_secret) conn.del_app(request["token"].client_id, request["token"].client_secret)
resp = Response.new_redir('/') resp = Response.new_redir("/")
resp.del_cookie('user-token', domain = app.config.domain, path = '/') resp.del_cookie("user-token", domain = app.config.domain, path = "/")
return resp return resp
@register_route(HttpMethod.GET, '/admin') @register_route(HttpMethod.GET, "/admin")
async def handle_admin(app: Application, request: Request) -> Response: async def handle_admin(app: Application, request: Request) -> Response:
return Response.new_redir(f'/login?redir={request.path}', 301) return Response.new_redir(f"/login?redir={request.path}", 301)
@register_route(HttpMethod.GET, '/admin/instances') @register_route(HttpMethod.GET, "/admin/instances")
async def handle_admin_instances( async def handle_admin_instances(
app: Application, app: Application,
request: Request, request: Request,
@ -64,20 +64,20 @@ async def handle_admin_instances(
with app.database.session() as conn: with app.database.session() as conn:
context: dict[str, Any] = { context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes()), "instances": tuple(conn.get_inboxes()),
'requests': tuple(conn.get_requests()) "requests": tuple(conn.get_requests())
} }
if error: if error:
context['error'] = error context["error"] = error
if message: if message:
context['message'] = message context["message"] = message
return Response.new_template(200, "page/admin-instances.haml", request, context) return Response.new_template(200, "page/admin-instances.haml", request, context)
@register_route(HttpMethod.GET, '/admin/whitelist') @register_route(HttpMethod.GET, "/admin/whitelist")
async def handle_admin_whitelist( async def handle_admin_whitelist(
app: Application, app: Application,
request: Request, request: Request,
@ -86,19 +86,19 @@ async def handle_admin_whitelist(
with app.database.session() as conn: with app.database.session() as conn:
context: dict[str, Any] = { context: dict[str, Any] = {
'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC')) "whitelist": tuple(conn.execute("SELECT * FROM whitelist ORDER BY domain ASC"))
} }
if error: if error:
context['error'] = error context["error"] = error
if message: if message:
context['message'] = message context["message"] = message
return Response.new_template(200, "page/admin-whitelist.haml", request, context) return Response.new_template(200, "page/admin-whitelist.haml", request, context)
@register_route(HttpMethod.GET, '/admin/domain_bans') @register_route(HttpMethod.GET, "/admin/domain_bans")
async def handle_admin_instance_bans( async def handle_admin_instance_bans(
app: Application, app: Application,
request: Request, request: Request,
@ -107,19 +107,19 @@ async def handle_admin_instance_bans(
with app.database.session() as conn: with app.database.session() as conn:
context: dict[str, Any] = { context: dict[str, Any] = {
'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC')) "bans": tuple(conn.execute("SELECT * FROM domain_bans ORDER BY domain ASC"))
} }
if error: if error:
context['error'] = error context["error"] = error
if message: if message:
context['message'] = message context["message"] = message
return Response.new_template(200, "page/admin-domain_bans.haml", request, context) return Response.new_template(200, "page/admin-domain_bans.haml", request, context)
@register_route(HttpMethod.GET, '/admin/software_bans') @register_route(HttpMethod.GET, "/admin/software_bans")
async def handle_admin_software_bans( async def handle_admin_software_bans(
app: Application, app: Application,
request: Request, request: Request,
@ -128,19 +128,19 @@ async def handle_admin_software_bans(
with app.database.session() as conn: with app.database.session() as conn:
context: dict[str, Any] = { context: dict[str, Any] = {
'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC')) "bans": tuple(conn.execute("SELECT * FROM software_bans ORDER BY name ASC"))
} }
if error: if error:
context['error'] = error context["error"] = error
if message: if message:
context['message'] = message context["message"] = message
return Response.new_template(200, "page/admin-software_bans.haml", request, context) return Response.new_template(200, "page/admin-software_bans.haml", request, context)
@register_route(HttpMethod.GET, '/admin/users') @register_route(HttpMethod.GET, "/admin/users")
async def handle_admin_users( async def handle_admin_users(
app: Application, app: Application,
request: Request, request: Request,
@ -149,29 +149,29 @@ async def handle_admin_users(
with app.database.session() as conn: with app.database.session() as conn:
context: dict[str, Any] = { context: dict[str, Any] = {
'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC')) "users": tuple(conn.execute("SELECT * FROM users ORDER BY username ASC"))
} }
if error: if error:
context['error'] = error context["error"] = error
if message: if message:
context['message'] = message context["message"] = message
return Response.new_template(200, "page/admin-users.haml", request, context) return Response.new_template(200, "page/admin-users.haml", request, context)
@register_route(HttpMethod.GET, '/admin/config') @register_route(HttpMethod.GET, "/admin/config")
async def handle_admin_config( async def handle_admin_config(
app: Application, app: Application,
request: Request, request: Request,
message: str | None = None) -> Response: message: str | None = None) -> Response:
context: dict[str, Any] = { context: dict[str, Any] = {
'themes': tuple(THEMES.keys()), "themes": tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel), "levels": tuple(level.name for level in LogLevel),
'message': message, "message": message,
'desc': { "desc": {
"name": "Name of the relay to be displayed in the header of the pages and in " + "name": "Name of the relay to be displayed in the header of the pages and in " +
"the actor endpoint.", # noqa: E131 "the actor endpoint.", # noqa: E131
"note": "Description of the relay to be displayed on the front page and as the " + "note": "Description of the relay to be displayed on the front page and as the " +
@ -187,36 +187,36 @@ async def handle_admin_config(
return Response.new_template(200, "page/admin-config.haml", request, context) return Response.new_template(200, "page/admin-config.haml", request, context)
@register_route(HttpMethod.GET, '/manifest.json') @register_route(HttpMethod.GET, "/manifest.json")
async def handle_manifest(app: Application, request: Request) -> Response: async def handle_manifest(app: Application, request: Request) -> Response:
with app.database.session(False) as conn: with app.database.session(False) as conn:
config = conn.get_config_all() config = conn.get_config_all()
theme = THEMES[config.theme] theme = THEMES[config.theme]
data = { data = {
'background_color': theme['background'], "background_color": theme["background"],
'categories': ['activitypub'], "categories": ["activitypub"],
'description': 'Message relay for the ActivityPub network', "description": "Message relay for the ActivityPub network",
'display': 'standalone', "display": "standalone",
'name': config['name'], "name": config["name"],
'orientation': 'portrait', "orientation": "portrait",
'scope': f"https://{app.config.domain}/", "scope": f"https://{app.config.domain}/",
'short_name': 'ActivityRelay', "short_name": "ActivityRelay",
'start_url': f"https://{app.config.domain}/", "start_url": f"https://{app.config.domain}/",
'theme_color': theme['primary'] "theme_color": theme["primary"]
} }
return Response.new(data, ctype = 'webmanifest') return Response.new(data, ctype = "webmanifest")
@register_route(HttpMethod.GET, '/theme/{theme}.css') # type: ignore[arg-type] @register_route(HttpMethod.GET, "/theme/{theme}.css") # type: ignore[arg-type]
async def handle_theme(app: Application, request: Request, theme: str) -> Response: async def handle_theme(app: Application, request: Request, theme: str) -> Response:
try: try:
context: dict[str, Any] = { context: dict[str, Any] = {
'theme': THEMES[theme] "theme": THEMES[theme]
} }
except KeyError: except KeyError:
return Response.new('Invalid theme', 404) return Response.new("Invalid theme", 404)
return Response.new_template(200, "variables.css", request, context, ctype = "css") return Response.new_template(200, "variables.css", request, context, ctype = "css")
View file
@ -21,16 +21,18 @@ VERSION = __version__
if File(__file__).join("../../../.git").resolve().exists: if File(__file__).join("../../../.git").resolve().exists:
try: try:
commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii') commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("ascii")
VERSION = f'{__version__} {commit_label}' VERSION = f"{__version__} {commit_label}"
del commit_label
except Exception: except Exception:
pass pass
NODEINFO_PATHS = [ NODEINFO_PATHS = [
'/nodeinfo/{niversion:\\d.\\d}.json', "/nodeinfo/{niversion:\\d.\\d}.json",
'/nodeinfo/{niversion:\\d.\\d}' "/nodeinfo/{niversion:\\d.\\d}"
] ]
@ -40,23 +42,23 @@ async def handle_nodeinfo(app: Application, request: Request, niversion: str) ->
inboxes = conn.get_inboxes() inboxes = conn.get_inboxes()
nodeinfo = aputils.Nodeinfo.new( nodeinfo = aputils.Nodeinfo.new(
name = 'activityrelay', name = "activityrelay",
version = VERSION, version = VERSION,
protocols = ['activitypub'], protocols = ["activitypub"],
open_regs = not conn.get_config('whitelist-enabled'), open_regs = not conn.get_config("whitelist-enabled"),
users = 1, users = 1,
repo = 'https://git.pleroma.social/pleroma/relay' if niversion == '2.1' else None, repo = "https://codeberg.org/barkshark/activityrelay" if niversion == "2.1" else None,
metadata = { metadata = {
'approval_required': conn.get_config('approval-required'), "approval_required": conn.get_config("approval-required"),
'peers': [inbox['domain'] for inbox in inboxes] "peers": [inbox["domain"] for inbox in inboxes]
} }
) )
return Response.new(nodeinfo, ctype = 'json') return Response.new(nodeinfo, ctype = "json")
@register_route(HttpMethod.GET, '/.well-known/nodeinfo') @register_route(HttpMethod.GET, "/.well-known/nodeinfo")
async def handle_wk_nodeinfo(app: Application, request: Request) -> Response: async def handle_wk_nodeinfo(app: Application, request: Request) -> Response:
data = aputils.WellKnownNodeinfo.new_template(app.config.domain) data = aputils.WellKnownNodeinfo.new_template(app.config.domain)
return Response.new(data, ctype = 'json') return Response.new(data, ctype = "json")
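Nodeinfo discovery takes two requests: the well-known document above lists the available schema versions, and the linked URL returns the document built by handle_nodeinfo. A hedged sketch with a placeholder domain:

import json
import urllib.request

domain = "relay.example.com"  # placeholder
with urllib.request.urlopen(f"https://{domain}/.well-known/nodeinfo") as resp:
    links = json.load(resp)["links"]

# Follow the first advertised nodeinfo document (2.0 or 2.1, as routed above).
with urllib.request.urlopen(links[0]["href"]) as resp:
    info = json.load(resp)

print(info["metadata"].get("peers", []))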
View file
@ -96,16 +96,16 @@ class PushWorker(Process):
await self.client.post(item.inbox, item.message, item.instance) await self.client.post(item.inbox, item.message, item.instance)
except HttpError as e: except HttpError as e:
logging.error('HTTP Error when pushing to %s: %i %s', item.inbox, e.status, e.message) logging.error("HTTP Error when pushing to %s: %i %s", item.inbox, e.status, e.message)
except AsyncTimeoutError: except AsyncTimeoutError:
logging.error('Timeout when pushing to %s', item.domain) logging.error("Timeout when pushing to %s", item.domain)
except ClientConnectionError as e: except ClientConnectionError as e:
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e)) logging.error("Failed to connect to %s for message push: %s", item.domain, str(e))
except ClientSSLError as e: except ClientSSLError as e:
logging.error('SSL error when pushing to %s: %s', item.domain, str(e)) logging.error("SSL error when pushing to %s: %s", item.domain, str(e))
class PushWorkers(list[PushWorker]): class PushWorkers(list[PushWorker]):