mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-09-18 22:51:58 +00:00
Compare commits
28 commits
dec7c6a674
...
5217516c8a
Author | SHA1 | Date | |
---|---|---|---|
5217516c8a | |||
5765753b59 | |||
1d72f2a254 | |||
5e962be057 | |||
bdc7d41d7a | |||
45b0de26c7 | |||
7e08e18785 | |||
e67ebd75ed | |||
9a3e3768e7 | |||
c508257981 | |||
b308b03546 | |||
5407027af8 | |||
f49bc0ae90 | |||
cad7f47e7e | |||
058df0ac78 | |||
e825a01795 | |||
ab9b8abbd2 | |||
15882f3e49 | |||
a2b96d03dc | |||
98a975550a | |||
ed03779a11 | |||
e44108f341 | |||
a0d84b5ae5 | |||
478e21fb15 | |||
0d50215fc1 | |||
62555b3591 | |||
d55cf4d1d0 | |||
bd50baa639 |
64
dev.py
64
dev.py
|
@ -5,16 +5,17 @@ import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import tomllib
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from relay import __version__, logger as logging
|
from relay import __version__, logger as logging
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Sequence
|
from typing import Any, Sequence
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from watchdog.observers import Observer
|
from watchdog.observers import Observer
|
||||||
from watchdog.events import PatternMatchingEventHandler
|
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
class PatternMatchingEventHandler: # type: ignore
|
class PatternMatchingEventHandler: # type: ignore
|
||||||
|
@ -29,39 +30,38 @@ IGNORE_EXT = {
|
||||||
|
|
||||||
|
|
||||||
@click.group('cli')
|
@click.group('cli')
|
||||||
def cli():
|
def cli() -> None:
|
||||||
'Useful commands for development'
|
'Useful commands for development'
|
||||||
|
|
||||||
|
|
||||||
@cli.command('install')
|
@cli.command('install')
|
||||||
def cli_install():
|
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
|
||||||
cmd = [
|
def cli_install(no_dev: bool) -> None:
|
||||||
sys.executable, '-m', 'pip', 'install',
|
with open('pyproject.toml', 'rb') as fd:
|
||||||
'-r', 'requirements.txt',
|
data = tomllib.load(fd)
|
||||||
'-r', 'dev-requirements.txt'
|
|
||||||
]
|
|
||||||
|
|
||||||
subprocess.run(cmd, check = False)
|
deps = data['project']['dependencies']
|
||||||
|
|
||||||
|
if not no_dev:
|
||||||
|
deps.extend(data['project']['optional-dependencies']['dev'])
|
||||||
|
|
||||||
|
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
|
||||||
|
|
||||||
|
|
||||||
@cli.command('lint')
|
@cli.command('lint')
|
||||||
@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay'))
|
@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay'))
|
||||||
@click.option('--strict', '-s', is_flag = True, help = 'Enable strict mode for mypy')
|
|
||||||
@click.option('--watch', '-w', is_flag = True,
|
@click.option('--watch', '-w', is_flag = True,
|
||||||
help = 'Automatically, re-run the linters on source change')
|
help = 'Automatically, re-run the linters on source change')
|
||||||
def cli_lint(path: Path, strict: bool, watch: bool) -> None:
|
def cli_lint(path: Path, watch: bool) -> None:
|
||||||
path = path.expanduser().resolve()
|
path = path.expanduser().resolve()
|
||||||
|
|
||||||
if watch:
|
if watch:
|
||||||
handle_run_watcher([sys.executable, "-m", "relay.dev", "lint", str(path)], wait = True)
|
handle_run_watcher([sys.executable, "dev.py", "lint", str(path)], wait = True)
|
||||||
return
|
return
|
||||||
|
|
||||||
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
|
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
|
||||||
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)]
|
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)]
|
||||||
|
|
||||||
if strict:
|
|
||||||
mypy.append('--strict')
|
|
||||||
|
|
||||||
click.echo('----- flake8 -----')
|
click.echo('----- flake8 -----')
|
||||||
subprocess.run(flake8)
|
subprocess.run(flake8)
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ def cli_lint(path: Path, strict: bool, watch: bool) -> None:
|
||||||
|
|
||||||
|
|
||||||
@cli.command('clean')
|
@cli.command('clean')
|
||||||
def cli_clean():
|
def cli_clean() -> None:
|
||||||
dirs = {
|
dirs = {
|
||||||
'dist',
|
'dist',
|
||||||
'build',
|
'build',
|
||||||
|
@ -88,7 +88,7 @@ def cli_clean():
|
||||||
|
|
||||||
|
|
||||||
@cli.command('build')
|
@cli.command('build')
|
||||||
def cli_build():
|
def cli_build() -> None:
|
||||||
with TemporaryDirectory() as tmp:
|
with TemporaryDirectory() as tmp:
|
||||||
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
|
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
|
||||||
cmd = [
|
cmd = [
|
||||||
|
@ -118,7 +118,7 @@ def cli_build():
|
||||||
|
|
||||||
@cli.command('run')
|
@cli.command('run')
|
||||||
@click.option('--dev', '-d', is_flag = True)
|
@click.option('--dev', '-d', is_flag = True)
|
||||||
def cli_run(dev: bool):
|
def cli_run(dev: bool) -> None:
|
||||||
print('Starting process watcher')
|
print('Starting process watcher')
|
||||||
|
|
||||||
cmd = [sys.executable, '-m', 'relay', 'run']
|
cmd = [sys.executable, '-m', 'relay', 'run']
|
||||||
|
@ -126,16 +126,20 @@ def cli_run(dev: bool):
|
||||||
if dev:
|
if dev:
|
||||||
cmd.append('-d')
|
cmd.append('-d')
|
||||||
|
|
||||||
handle_run_watcher(cmd)
|
handle_run_watcher(cmd, watch_path = REPO.joinpath("relay"))
|
||||||
|
|
||||||
|
|
||||||
def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
|
def handle_run_watcher(
|
||||||
|
*commands: Sequence[str],
|
||||||
|
watch_path: Path | str = REPO,
|
||||||
|
wait: bool = False) -> None:
|
||||||
|
|
||||||
handler = WatchHandler(*commands, wait = wait)
|
handler = WatchHandler(*commands, wait = wait)
|
||||||
handler.run_procs()
|
handler.run_procs()
|
||||||
|
|
||||||
watcher = Observer()
|
watcher = Observer()
|
||||||
watcher.schedule(handler, str(REPO), recursive=True)
|
watcher.schedule(handler, str(watch_path), recursive=True) # type: ignore
|
||||||
watcher.start()
|
watcher.start() # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
@ -145,7 +149,7 @@ def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
handler.kill_procs()
|
handler.kill_procs()
|
||||||
watcher.stop()
|
watcher.stop() # type: ignore
|
||||||
watcher.join()
|
watcher.join()
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,16 +157,16 @@ class WatchHandler(PatternMatchingEventHandler):
|
||||||
patterns = ['*.py']
|
patterns = ['*.py']
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, *commands: Sequence[str], wait: bool = False):
|
def __init__(self, *commands: Sequence[str], wait: bool = False) -> None:
|
||||||
PatternMatchingEventHandler.__init__(self)
|
PatternMatchingEventHandler.__init__(self) # type: ignore
|
||||||
|
|
||||||
self.commands: Sequence[Sequence[str]] = commands
|
self.commands: Sequence[Sequence[str]] = commands
|
||||||
self.wait: bool = wait
|
self.wait: bool = wait
|
||||||
self.procs: list[subprocess.Popen] = []
|
self.procs: list[subprocess.Popen[Any]] = []
|
||||||
self.last_restart: datetime = datetime.now()
|
self.last_restart: datetime = datetime.now()
|
||||||
|
|
||||||
|
|
||||||
def kill_procs(self):
|
def kill_procs(self) -> None:
|
||||||
for proc in self.procs:
|
for proc in self.procs:
|
||||||
if proc.poll() is not None:
|
if proc.poll() is not None:
|
||||||
continue
|
continue
|
||||||
|
@ -183,7 +187,7 @@ class WatchHandler(PatternMatchingEventHandler):
|
||||||
logging.info('Process terminated')
|
logging.info('Process terminated')
|
||||||
|
|
||||||
|
|
||||||
def run_procs(self, restart: bool = False):
|
def run_procs(self, restart: bool = False) -> None:
|
||||||
if restart:
|
if restart:
|
||||||
if datetime.now() - timedelta(seconds = 3) < self.last_restart:
|
if datetime.now() - timedelta(seconds = 3) < self.last_restart:
|
||||||
return
|
return
|
||||||
|
@ -205,7 +209,7 @@ class WatchHandler(PatternMatchingEventHandler):
|
||||||
logging.info('Started processes with PIDs: %s', ', '.join(pids))
|
logging.info('Started processes with PIDs: %s', ', '.join(pids))
|
||||||
|
|
||||||
|
|
||||||
def on_any_event(self, event):
|
def on_any_event(self, event: FileSystemEvent) -> None:
|
||||||
if event.event_type not in ['modified', 'created', 'deleted']:
|
if event.event_type not in ['modified', 'created', 'deleted']:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -16,19 +16,21 @@ classifiers = [
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"activitypub-utils == 0.2.1",
|
"activitypub-utils >= 0.3.1, < 0.4.0",
|
||||||
"aiohttp >= 3.9.1",
|
"aiohttp >= 3.9.5",
|
||||||
"aiohttp-swagger[performance] == 1.0.16",
|
"aiohttp-swagger[performance] == 1.0.16",
|
||||||
"argon2-cffi == 23.1.0",
|
"argon2-cffi == 23.1.0",
|
||||||
"barkshark-sql == 0.1.2",
|
"barkshark-lib >= 0.1.4, < 0.2.0",
|
||||||
"click >= 8.1.2",
|
"barkshark-sql >= 0.2.0-rc1, < 0.3.0",
|
||||||
|
"click == 8.1.2",
|
||||||
"hiredis == 2.3.2",
|
"hiredis == 2.3.2",
|
||||||
|
"idna == 3.4",
|
||||||
"jinja2-haml == 0.3.5",
|
"jinja2-haml == 0.3.5",
|
||||||
"markdown == 3.5.2",
|
"markdown == 3.6",
|
||||||
"platformdirs == 4.2.0",
|
"platformdirs == 4.2.2",
|
||||||
"pyyaml >= 6.0",
|
"pyyaml == 6.0",
|
||||||
"redis == 5.0.1",
|
"redis == 5.0.5",
|
||||||
"importlib_resources == 6.1.1; python_version < '3.9'"
|
"importlib-resources == 6.4.0; python_version < '3.9'"
|
||||||
]
|
]
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
@ -48,10 +50,10 @@ activityrelay = "relay.manage:main"
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
"flake8 == 7.0.0",
|
"flake8 == 7.0.0",
|
||||||
"mypy == 1.9.0",
|
"mypy == 1.10.0",
|
||||||
"pyinstaller == 6.3.0",
|
"pyinstaller == 6.8.0",
|
||||||
"watchdog == 4.0.0",
|
"watchdog == 4.0.1",
|
||||||
"typing_extensions >= 4.10.0; python_version < '3.11.0'"
|
"typing-extensions == 4.12.2; python_version < '3.11.0'"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
|
@ -87,4 +89,18 @@ warn_redundant_casts = true
|
||||||
warn_unreachable = true
|
warn_unreachable = true
|
||||||
warn_unused_ignores = true
|
warn_unused_ignores = true
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
implicit_reexport = true
|
||||||
|
strict = true
|
||||||
follow_imports = "silent"
|
follow_imports = "silent"
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "relay.database"
|
||||||
|
implicit_reexport = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "aputils"
|
||||||
|
implicit_reexport = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "blib"
|
||||||
|
implicit_reexport = true
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
__version__ = '0.3.1'
|
__version__ = '0.3.3'
|
||||||
|
|
|
@ -4,40 +4,35 @@ import asyncio
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import signal
|
import signal
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.web import StaticResource
|
from aiohttp.web import StaticResource
|
||||||
from aiohttp_swagger import setup_swagger
|
from aiohttp_swagger import setup_swagger
|
||||||
from aputils.signer import Signer
|
from aputils.signer import Signer
|
||||||
|
from bsql import Database
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty
|
|
||||||
from threading import Event, Thread
|
from threading import Event, Thread
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from . import logger as logging
|
from . import logger as logging, workers
|
||||||
from .cache import get_cache
|
from .cache import Cache, get_cache
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .database import get_database
|
from .database import Connection, get_database
|
||||||
|
from .database.schema import Instance
|
||||||
from .http_client import HttpClient
|
from .http_client import HttpClient
|
||||||
from .misc import check_open_port, get_resource
|
from .misc import Message, Response, check_open_port, get_resource
|
||||||
from .template import Template
|
from .template import Template
|
||||||
from .views import VIEWS
|
from .views import VIEWS
|
||||||
from .views.api import handle_api_path
|
from .views.api import handle_api_path
|
||||||
from .views.frontend import handle_frontend_path
|
from .views.frontend import handle_frontend_path
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from collections.abc import Callable
|
|
||||||
from bsql import Database, Row
|
|
||||||
from .cache import Cache
|
|
||||||
from .misc import Message, Response
|
|
||||||
|
|
||||||
|
|
||||||
def get_csp(request: web.Request) -> str:
|
def get_csp(request: web.Request) -> str:
|
||||||
data = [
|
data = [
|
||||||
"default-src 'none'",
|
"default-src 'self'",
|
||||||
f"script-src 'nonce-{request['hash']}'",
|
f"script-src 'nonce-{request['hash']}'",
|
||||||
f"style-src 'self' 'nonce-{request['hash']}'",
|
f"style-src 'self' 'nonce-{request['hash']}'",
|
||||||
"form-action 'self'",
|
"form-action 'self'",
|
||||||
|
@ -58,9 +53,9 @@ class Application(web.Application):
|
||||||
def __init__(self, cfgpath: Path | None, dev: bool = False):
|
def __init__(self, cfgpath: Path | None, dev: bool = False):
|
||||||
web.Application.__init__(self,
|
web.Application.__init__(self,
|
||||||
middlewares = [
|
middlewares = [
|
||||||
handle_api_path,
|
handle_api_path, # type: ignore[list-item]
|
||||||
handle_frontend_path,
|
handle_frontend_path, # type: ignore[list-item]
|
||||||
handle_response_headers
|
handle_response_headers # type: ignore[list-item]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -79,7 +74,7 @@ class Application(web.Application):
|
||||||
self['cache'].setup()
|
self['cache'].setup()
|
||||||
self['template'] = Template(self)
|
self['template'] = Template(self)
|
||||||
self['push_queue'] = multiprocessing.Queue()
|
self['push_queue'] = multiprocessing.Queue()
|
||||||
self['workers'] = []
|
self['workers'] = workers.PushWorkers(self.config.workers)
|
||||||
|
|
||||||
self.cache.setup()
|
self.cache.setup()
|
||||||
self.on_cleanup.append(handle_cleanup) # type: ignore
|
self.on_cleanup.append(handle_cleanup) # type: ignore
|
||||||
|
@ -96,27 +91,27 @@ class Application(web.Application):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cache(self) -> Cache:
|
def cache(self) -> Cache:
|
||||||
return self['cache']
|
return self['cache'] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> HttpClient:
|
def client(self) -> HttpClient:
|
||||||
return self['client']
|
return self['client'] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config(self) -> Config:
|
def config(self) -> Config:
|
||||||
return self['config']
|
return self['config'] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database(self) -> Database:
|
def database(self) -> Database[Connection]:
|
||||||
return self['database']
|
return self['database'] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def signer(self) -> Signer:
|
def signer(self) -> Signer:
|
||||||
return self['signer']
|
return self['signer'] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
@signer.setter
|
@signer.setter
|
||||||
|
@ -130,7 +125,7 @@ class Application(web.Application):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def template(self) -> Template:
|
def template(self) -> Template:
|
||||||
return self['template']
|
return self['template'] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -143,8 +138,8 @@ class Application(web.Application):
|
||||||
return timedelta(seconds=uptime.seconds)
|
return timedelta(seconds=uptime.seconds)
|
||||||
|
|
||||||
|
|
||||||
def push_message(self, inbox: str, message: Message | bytes, instance: Row) -> None:
|
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
|
||||||
self['push_queue'].put((inbox, message, instance))
|
self['workers'].push_message(inbox, message, instance)
|
||||||
|
|
||||||
|
|
||||||
def register_static_routes(self) -> None:
|
def register_static_routes(self) -> None:
|
||||||
|
@ -185,11 +180,11 @@ class Application(web.Application):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def stop(self, *_):
|
def stop(self, *_: Any) -> None:
|
||||||
self['running'] = False
|
self['running'] = False
|
||||||
|
|
||||||
|
|
||||||
async def handle_run(self):
|
async def handle_run(self) -> None:
|
||||||
self['running'] = True
|
self['running'] = True
|
||||||
|
|
||||||
self.set_signal_handler(True)
|
self.set_signal_handler(True)
|
||||||
|
@ -199,12 +194,7 @@ class Application(web.Application):
|
||||||
self['cache'].setup()
|
self['cache'].setup()
|
||||||
self['cleanup_thread'] = CacheCleanupThread(self)
|
self['cleanup_thread'] = CacheCleanupThread(self)
|
||||||
self['cleanup_thread'].start()
|
self['cleanup_thread'].start()
|
||||||
|
self['workers'].start()
|
||||||
for _ in range(self.config.workers):
|
|
||||||
worker = PushWorker(self['push_queue'])
|
|
||||||
worker.start()
|
|
||||||
|
|
||||||
self['workers'].append(worker)
|
|
||||||
|
|
||||||
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
|
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
|
@ -224,15 +214,13 @@ class Application(web.Application):
|
||||||
|
|
||||||
await site.stop()
|
await site.stop()
|
||||||
|
|
||||||
for worker in self['workers']:
|
self['workers'].stop()
|
||||||
worker.stop()
|
|
||||||
|
|
||||||
self.set_signal_handler(False)
|
self.set_signal_handler(False)
|
||||||
|
|
||||||
self['starttime'] = None
|
self['starttime'] = None
|
||||||
self['running'] = False
|
self['running'] = False
|
||||||
self['cleanup_thread'].stop()
|
self['cleanup_thread'].stop()
|
||||||
self['workers'].clear()
|
|
||||||
self['database'].disconnect()
|
self['database'].disconnect()
|
||||||
self['cache'].close()
|
self['cache'].close()
|
||||||
|
|
||||||
|
@ -294,42 +282,11 @@ class CacheCleanupThread(Thread):
|
||||||
self.running.clear()
|
self.running.clear()
|
||||||
|
|
||||||
|
|
||||||
class PushWorker(multiprocessing.Process):
|
|
||||||
def __init__(self, queue: multiprocessing.Queue):
|
|
||||||
multiprocessing.Process.__init__(self)
|
|
||||||
self.queue = queue
|
|
||||||
self.shutdown = multiprocessing.Event()
|
|
||||||
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
self.shutdown.set()
|
|
||||||
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
asyncio.run(self.handle_queue())
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_queue(self) -> None:
|
|
||||||
client = HttpClient()
|
|
||||||
client.open()
|
|
||||||
|
|
||||||
while not self.shutdown.is_set():
|
|
||||||
try:
|
|
||||||
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
|
|
||||||
asyncio.create_task(client.post(inbox, message, instance))
|
|
||||||
|
|
||||||
except Empty:
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
# make sure an exception doesn't bring down the worker
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def handle_response_headers(request: web.Request, handler: Callable) -> Response:
|
async def handle_response_headers(
|
||||||
|
request: web.Request,
|
||||||
|
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
|
||||||
|
|
||||||
resp = await handler(request)
|
resp = await handler(request)
|
||||||
resp.headers['Server'] = 'ActivityRelay'
|
resp.headers['Server'] = 'ActivityRelay'
|
||||||
|
|
||||||
|
|
|
@ -2,28 +2,27 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import typing
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from bsql import Database, Row
|
||||||
|
from collections.abc import Callable, Iterator
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from .database import get_database
|
from .database import Connection, get_database
|
||||||
from .misc import Message, boolean
|
from .misc import Message, boolean
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from blib import Database
|
|
||||||
from collections.abc import Callable, Iterator
|
|
||||||
from typing import Any
|
|
||||||
from .application import Application
|
from .application import Application
|
||||||
|
|
||||||
|
|
||||||
# todo: implement more caching backends
|
SerializerCallback = Callable[[Any], str]
|
||||||
|
DeserializerCallback = Callable[[str], Any]
|
||||||
|
|
||||||
BACKENDS: dict[str, type[Cache]] = {}
|
BACKENDS: dict[str, type[Cache]] = {}
|
||||||
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
|
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
|
||||||
'str': (str, str),
|
'str': (str, str),
|
||||||
'int': (str, int),
|
'int': (str, int),
|
||||||
'bool': (str, boolean),
|
'bool': (str, boolean),
|
||||||
|
@ -61,13 +60,13 @@ class Item:
|
||||||
updated: datetime
|
updated: datetime
|
||||||
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
if isinstance(self.updated, str):
|
if isinstance(self.updated, str): # type: ignore[unreachable]
|
||||||
self.updated = datetime.fromisoformat(self.updated)
|
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_data(cls: type[Item], *args) -> Item:
|
def from_data(cls: type[Item], *args: Any) -> Item:
|
||||||
data = cls(*args)
|
data = cls(*args)
|
||||||
data.value = deserialize_value(data.value, data.value_type)
|
data.value = deserialize_value(data.value, data.value_type)
|
||||||
|
|
||||||
|
@ -159,10 +158,13 @@ class SqlCache(Cache):
|
||||||
|
|
||||||
def __init__(self, app: Application):
|
def __init__(self, app: Application):
|
||||||
Cache.__init__(self, app)
|
Cache.__init__(self, app)
|
||||||
self._db: Database = None
|
self._db: Database[Connection] | None = None
|
||||||
|
|
||||||
|
|
||||||
def get(self, namespace: str, key: str) -> Item:
|
def get(self, namespace: str, key: str) -> Item:
|
||||||
|
if self._db is None:
|
||||||
|
raise RuntimeError("Database has not been setup")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'namespace': namespace,
|
'namespace': namespace,
|
||||||
'key': key
|
'key': key
|
||||||
|
@ -170,7 +172,7 @@ class SqlCache(Cache):
|
||||||
|
|
||||||
with self._db.session(False) as conn:
|
with self._db.session(False) as conn:
|
||||||
with conn.run('get-cache-item', params) as cur:
|
with conn.run('get-cache-item', params) as cur:
|
||||||
if not (row := cur.one()):
|
if not (row := cur.one(Row)):
|
||||||
raise KeyError(f'{namespace}:{key}')
|
raise KeyError(f'{namespace}:{key}')
|
||||||
|
|
||||||
row.pop('id', None)
|
row.pop('id', None)
|
||||||
|
@ -178,18 +180,27 @@ class SqlCache(Cache):
|
||||||
|
|
||||||
|
|
||||||
def get_keys(self, namespace: str) -> Iterator[str]:
|
def get_keys(self, namespace: str) -> Iterator[str]:
|
||||||
|
if self._db is None:
|
||||||
|
raise RuntimeError("Database has not been setup")
|
||||||
|
|
||||||
with self._db.session(False) as conn:
|
with self._db.session(False) as conn:
|
||||||
for row in conn.run('get-cache-keys', {'namespace': namespace}):
|
for row in conn.run('get-cache-keys', {'namespace': namespace}):
|
||||||
yield row['key']
|
yield row['key']
|
||||||
|
|
||||||
|
|
||||||
def get_namespaces(self) -> Iterator[str]:
|
def get_namespaces(self) -> Iterator[str]:
|
||||||
|
if self._db is None:
|
||||||
|
raise RuntimeError("Database has not been setup")
|
||||||
|
|
||||||
with self._db.session(False) as conn:
|
with self._db.session(False) as conn:
|
||||||
for row in conn.run('get-cache-namespaces', None):
|
for row in conn.run('get-cache-namespaces', None):
|
||||||
yield row['namespace']
|
yield row['namespace']
|
||||||
|
|
||||||
|
|
||||||
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
|
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
|
||||||
|
if self._db is None:
|
||||||
|
raise RuntimeError("Database has not been setup")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'namespace': namespace,
|
'namespace': namespace,
|
||||||
'key': key,
|
'key': key,
|
||||||
|
@ -199,13 +210,18 @@ class SqlCache(Cache):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self._db.session(True) as conn:
|
with self._db.session(True) as conn:
|
||||||
with conn.run('set-cache-item', params) as conn:
|
with conn.run('set-cache-item', params) as cur:
|
||||||
row = conn.one()
|
if (row := cur.one(Row)) is None:
|
||||||
|
raise RuntimeError("Cache item not set")
|
||||||
|
|
||||||
row.pop('id', None)
|
row.pop('id', None)
|
||||||
return Item.from_data(*tuple(row.values()))
|
return Item.from_data(*tuple(row.values()))
|
||||||
|
|
||||||
|
|
||||||
def delete(self, namespace: str, key: str) -> None:
|
def delete(self, namespace: str, key: str) -> None:
|
||||||
|
if self._db is None:
|
||||||
|
raise RuntimeError("Database has not been setup")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'namespace': namespace,
|
'namespace': namespace,
|
||||||
'key': key
|
'key': key
|
||||||
|
@ -217,6 +233,9 @@ class SqlCache(Cache):
|
||||||
|
|
||||||
|
|
||||||
def delete_old(self, days: int = 14) -> None:
|
def delete_old(self, days: int = 14) -> None:
|
||||||
|
if self._db is None:
|
||||||
|
raise RuntimeError("Database has not been setup")
|
||||||
|
|
||||||
limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
|
limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
|
||||||
params = {"limit": limit.timestamp()}
|
params = {"limit": limit.timestamp()}
|
||||||
|
|
||||||
|
@ -226,6 +245,9 @@ class SqlCache(Cache):
|
||||||
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
|
if self._db is None:
|
||||||
|
raise RuntimeError("Database has not been setup")
|
||||||
|
|
||||||
with self._db.session(True) as conn:
|
with self._db.session(True) as conn:
|
||||||
with conn.execute("DELETE FROM cache"):
|
with conn.execute("DELETE FROM cache"):
|
||||||
pass
|
pass
|
||||||
|
@ -360,5 +382,5 @@ class RedisCache(Cache):
|
||||||
if not self._rd:
|
if not self._rd:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._rd.close()
|
self._rd.close() # type: ignore
|
||||||
self._rd = None # type: ignore
|
self._rd = None # type: ignore
|
||||||
|
|
|
@ -1,21 +1,16 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import typing
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from .misc import boolean
|
from .misc import boolean
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
class RelayConfig(dict[str, Any]):
|
||||||
class RelayConfig(dict):
|
|
||||||
def __init__(self, path: str):
|
def __init__(self, path: str):
|
||||||
dict.__init__(self, {})
|
dict.__init__(self, {})
|
||||||
|
|
||||||
|
@ -122,7 +117,7 @@ class RelayConfig(dict):
|
||||||
self[key] = value
|
self[key] = value
|
||||||
|
|
||||||
|
|
||||||
class RelayDatabase(dict):
|
class RelayDatabase(dict[str, Any]):
|
||||||
def __init__(self, config: RelayConfig):
|
def __init__(self, config: RelayConfig):
|
||||||
dict.__init__(self, {
|
dict.__init__(self, {
|
||||||
'relay-list': {},
|
'relay-list': {},
|
||||||
|
|
|
@ -3,18 +3,16 @@ from __future__ import annotations
|
||||||
import getpass
|
import getpass
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import typing
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass, fields
|
from dataclasses import asdict, dataclass, fields
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from platformdirs import user_config_dir
|
from platformdirs import user_config_dir
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from .misc import IS_DOCKER
|
from .misc import IS_DOCKER
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
@ -66,7 +64,7 @@ class Config:
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, path: Path | None = None, load: bool = False):
|
def __init__(self, path: Path | None = None, load: bool = False):
|
||||||
self.path = Config.get_config_dir(path)
|
self.path: Path = Config.get_config_dir(path)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
if load:
|
if load:
|
||||||
|
|
|
@ -1,31 +1,28 @@
|
||||||
from __future__ import annotations
|
from bsql import Database
|
||||||
|
|
||||||
import bsql
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from .config import THEMES, ConfigData
|
from .config import THEMES, ConfigData
|
||||||
from .connection import RELAY_SOFTWARE, Connection
|
from .connection import RELAY_SOFTWARE, Connection
|
||||||
from .schema import TABLES, VERSIONS, migrate_0
|
from .schema import TABLES, VERSIONS, migrate_0
|
||||||
|
|
||||||
from .. import logger as logging
|
from .. import logger as logging
|
||||||
|
from ..config import Config
|
||||||
from ..misc import get_resource
|
from ..misc import get_resource
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from ..config import Config
|
|
||||||
|
|
||||||
|
def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
|
||||||
def get_database(config: Config, migrate: bool = True) -> bsql.Database:
|
|
||||||
options = {
|
options = {
|
||||||
'connection_class': Connection,
|
'connection_class': Connection,
|
||||||
'pool_size': 5,
|
'pool_size': 5,
|
||||||
'tables': TABLES
|
'tables': TABLES
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db: Database[Connection]
|
||||||
|
|
||||||
if config.db_type == 'sqlite':
|
if config.db_type == 'sqlite':
|
||||||
db = bsql.Database.sqlite(config.sqlite_path, **options)
|
db = Database.sqlite(config.sqlite_path, **options)
|
||||||
|
|
||||||
elif config.db_type == 'postgres':
|
elif config.db_type == 'postgres':
|
||||||
db = bsql.Database.postgresql(
|
db = Database.postgresql(
|
||||||
config.pg_name,
|
config.pg_name,
|
||||||
config.pg_host,
|
config.pg_host,
|
||||||
config.pg_port,
|
config.pg_port,
|
||||||
|
|
|
@ -1,17 +1,16 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
# removing the above line turns annotations into types instead of str objects which messes with
|
||||||
|
# `Field.type`
|
||||||
|
|
||||||
import typing
|
from bsql import Row
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import Field, asdict, dataclass, fields
|
from dataclasses import Field, asdict, dataclass, fields
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from .. import logger as logging
|
from .. import logger as logging
|
||||||
from ..misc import boolean
|
from ..misc import boolean
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from bsql import Row
|
|
||||||
from collections.abc import Callable, Sequence
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
@ -120,7 +119,7 @@ class ConfigData:
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def FIELD(cls: type[Self], key: str) -> Field:
|
def FIELD(cls: type[Self], key: str) -> Field[Any]:
|
||||||
for field in fields(cls):
|
for field in fields(cls):
|
||||||
if field.name == key.replace('-', '_'):
|
if field.name == key.replace('-', '_'):
|
||||||
return field
|
return field
|
||||||
|
|
|
@ -1,28 +1,24 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from argon2 import PasswordHasher
|
from argon2 import PasswordHasher
|
||||||
from bsql import Connection as SqlConnection, Update
|
from bsql import Connection as SqlConnection, Row, Update
|
||||||
|
from collections.abc import Iterator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from . import schema
|
||||||
from .config import (
|
from .config import (
|
||||||
THEMES,
|
THEMES,
|
||||||
ConfigData
|
ConfigData
|
||||||
)
|
)
|
||||||
|
|
||||||
from .. import logger as logging
|
from .. import logger as logging
|
||||||
from ..misc import boolean, get_app
|
from ..misc import Message, boolean, get_app
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Iterator, Sequence
|
|
||||||
from bsql import Row
|
|
||||||
from typing import Any
|
|
||||||
from ..application import Application
|
from ..application import Application
|
||||||
from ..misc import Message
|
|
||||||
|
|
||||||
|
|
||||||
RELAY_SOFTWARE = [
|
RELAY_SOFTWARE = [
|
||||||
'activityrelay', # https://git.pleroma.social/pleroma/relay
|
'activityrelay', # https://git.pleroma.social/pleroma/relay
|
||||||
|
@ -42,14 +38,14 @@ class Connection(SqlConnection):
|
||||||
return get_app()
|
return get_app()
|
||||||
|
|
||||||
|
|
||||||
def distill_inboxes(self, message: Message) -> Iterator[Row]:
|
def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]:
|
||||||
src_domains = {
|
src_domains = {
|
||||||
message.domain,
|
message.domain,
|
||||||
urlparse(message.object_id).netloc
|
urlparse(message.object_id).netloc
|
||||||
}
|
}
|
||||||
|
|
||||||
for instance in self.get_inboxes():
|
for instance in self.get_inboxes():
|
||||||
if instance['domain'] not in src_domains:
|
if instance.domain not in src_domains:
|
||||||
yield instance
|
yield instance
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,7 +53,7 @@ class Connection(SqlConnection):
|
||||||
key = key.replace('_', '-')
|
key = key.replace('_', '-')
|
||||||
|
|
||||||
with self.run('get-config', {'key': key}) as cur:
|
with self.run('get-config', {'key': key}) as cur:
|
||||||
if not (row := cur.one()):
|
if (row := cur.one(Row)) is None:
|
||||||
return ConfigData.DEFAULT(key)
|
return ConfigData.DEFAULT(key)
|
||||||
|
|
||||||
data = ConfigData()
|
data = ConfigData()
|
||||||
|
@ -66,8 +62,8 @@ class Connection(SqlConnection):
|
||||||
|
|
||||||
|
|
||||||
def get_config_all(self) -> ConfigData:
|
def get_config_all(self) -> ConfigData:
|
||||||
with self.run('get-config-all', None) as cur:
|
rows = tuple(self.run('get-config-all', None).all(schema.Row))
|
||||||
return ConfigData.from_rows(tuple(cur.all()))
|
return ConfigData.from_rows(rows)
|
||||||
|
|
||||||
|
|
||||||
def put_config(self, key: str, value: Any) -> Any:
|
def put_config(self, key: str, value: Any) -> Any:
|
||||||
|
@ -80,6 +76,7 @@ class Connection(SqlConnection):
|
||||||
elif key == 'log-level':
|
elif key == 'log-level':
|
||||||
value = logging.LogLevel.parse(value)
|
value = logging.LogLevel.parse(value)
|
||||||
logging.set_level(value)
|
logging.set_level(value)
|
||||||
|
self.app['workers'].set_log_level(value)
|
||||||
|
|
||||||
elif key in {'approval-required', 'whitelist-enabled'}:
|
elif key in {'approval-required', 'whitelist-enabled'}:
|
||||||
value = boolean(value)
|
value = boolean(value)
|
||||||
|
@ -94,7 +91,7 @@ class Connection(SqlConnection):
|
||||||
params = {
|
params = {
|
||||||
'key': key,
|
'key': key,
|
||||||
'value': data.get(key, serialize = True),
|
'value': data.get(key, serialize = True),
|
||||||
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type
|
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-config', params):
|
with self.run('put-config', params):
|
||||||
|
@ -103,14 +100,13 @@ class Connection(SqlConnection):
|
||||||
return data.get(key)
|
return data.get(key)
|
||||||
|
|
||||||
|
|
||||||
def get_inbox(self, value: str) -> Row:
|
def get_inbox(self, value: str) -> schema.Instance | None:
|
||||||
with self.run('get-inbox', {'value': value}) as cur:
|
with self.run('get-inbox', {'value': value}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.Instance)
|
||||||
|
|
||||||
|
|
||||||
def get_inboxes(self) -> Sequence[Row]:
|
def get_inboxes(self) -> Iterator[schema.Instance]:
|
||||||
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur:
|
return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance)
|
||||||
return tuple(cur.all())
|
|
||||||
|
|
||||||
|
|
||||||
def put_inbox(self,
|
def put_inbox(self,
|
||||||
|
@ -119,7 +115,7 @@ class Connection(SqlConnection):
|
||||||
actor: str | None = None,
|
actor: str | None = None,
|
||||||
followid: str | None = None,
|
followid: str | None = None,
|
||||||
software: str | None = None,
|
software: str | None = None,
|
||||||
accepted: bool = True) -> Row:
|
accepted: bool = True) -> schema.Instance:
|
||||||
|
|
||||||
params: dict[str, Any] = {
|
params: dict[str, Any] = {
|
||||||
'inbox': inbox,
|
'inbox': inbox,
|
||||||
|
@ -129,7 +125,7 @@ class Connection(SqlConnection):
|
||||||
'accepted': accepted
|
'accepted': accepted
|
||||||
}
|
}
|
||||||
|
|
||||||
if not self.get_inbox(domain):
|
if self.get_inbox(domain) is None:
|
||||||
if not inbox:
|
if not inbox:
|
||||||
raise ValueError("Missing inbox")
|
raise ValueError("Missing inbox")
|
||||||
|
|
||||||
|
@ -137,14 +133,20 @@ class Connection(SqlConnection):
|
||||||
params['created'] = datetime.now(tz = timezone.utc)
|
params['created'] = datetime.now(tz = timezone.utc)
|
||||||
|
|
||||||
with self.run('put-inbox', params) as cur:
|
with self.run('put-inbox', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Instance)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert instance: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
for key, value in tuple(params.items()):
|
for key, value in tuple(params.items()):
|
||||||
if value is None:
|
if value is None:
|
||||||
del params[key]
|
del params[key]
|
||||||
|
|
||||||
with self.update('inboxes', params, domain = domain) as cur:
|
with self.update('inboxes', params, domain = domain) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Instance)) is None:
|
||||||
|
raise RuntimeError(f"Failed to update instance: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_inbox(self, value: str) -> bool:
|
def del_inbox(self, value: str) -> bool:
|
||||||
|
@ -155,24 +157,23 @@ class Connection(SqlConnection):
|
||||||
return cur.row_count == 1
|
return cur.row_count == 1
|
||||||
|
|
||||||
|
|
||||||
def get_request(self, domain: str) -> Row:
|
def get_request(self, domain: str) -> schema.Instance | None:
|
||||||
with self.run('get-request', {'domain': domain}) as cur:
|
with self.run('get-request', {'domain': domain}) as cur:
|
||||||
if not (row := cur.one()):
|
return cur.one(schema.Instance)
|
||||||
raise KeyError(domain)
|
|
||||||
|
|
||||||
return row
|
|
||||||
|
|
||||||
|
|
||||||
def get_requests(self) -> Sequence[Row]:
|
def get_requests(self) -> Iterator[schema.Instance]:
|
||||||
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
|
return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance)
|
||||||
return tuple(cur.all())
|
|
||||||
|
|
||||||
|
|
||||||
def put_request_response(self, domain: str, accepted: bool) -> Row:
|
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
|
||||||
instance = self.get_request(domain)
|
if (instance := self.get_request(domain)) is None:
|
||||||
|
raise KeyError(domain)
|
||||||
|
|
||||||
if not accepted:
|
if not accepted:
|
||||||
self.del_inbox(domain)
|
if not self.del_inbox(domain):
|
||||||
|
raise RuntimeError(f'Failed to delete request: {domain}')
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
|
@ -181,21 +182,28 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-inbox-accept', params) as cur:
|
with self.run('put-inbox-accept', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Instance)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert response for domain: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def get_user(self, value: str) -> Row:
|
def get_user(self, value: str) -> schema.User | None:
|
||||||
with self.run('get-user', {'value': value}) as cur:
|
with self.run('get-user', {'value': value}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.User)
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_token(self, code: str) -> Row:
|
def get_user_by_token(self, code: str) -> schema.User | None:
|
||||||
with self.run('get-user-by-token', {'code': code}) as cur:
|
with self.run('get-user-by-token', {'code': code}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.User)
|
||||||
|
|
||||||
|
|
||||||
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row:
|
def get_users(self) -> Iterator[schema.User]:
|
||||||
if self.get_user(username):
|
return self.execute("SELECT * FROM users").all(schema.User)
|
||||||
|
|
||||||
|
|
||||||
|
def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User:
|
||||||
|
if self.get_user(username) is not None:
|
||||||
data: dict[str, str | datetime | None] = {}
|
data: dict[str, str | datetime | None] = {}
|
||||||
|
|
||||||
if password:
|
if password:
|
||||||
|
@ -208,7 +216,10 @@ class Connection(SqlConnection):
|
||||||
stmt.set_where("username", username)
|
stmt.set_where("username", username)
|
||||||
|
|
||||||
with self.query(stmt) as cur:
|
with self.query(stmt) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.User)) is None:
|
||||||
|
raise RuntimeError(f"Failed to update user: {username}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
if password is None:
|
if password is None:
|
||||||
raise ValueError('Password cannot be empty')
|
raise ValueError('Password cannot be empty')
|
||||||
|
@ -221,25 +232,36 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-user', data) as cur:
|
with self.run('put-user', data) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.User)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert user: {username}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_user(self, username: str) -> None:
|
def del_user(self, username: str) -> None:
|
||||||
user = self.get_user(username)
|
if (user := self.get_user(username)) is None:
|
||||||
|
raise KeyError(username)
|
||||||
|
|
||||||
with self.run('del-user', {'value': user['username']}):
|
with self.run('del-user', {'value': user.username}):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
with self.run('del-token-user', {'username': user['username']}):
|
with self.run('del-token-user', {'username': user.username}):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_token(self, code: str) -> Row:
|
def get_token(self, code: str) -> schema.Token | None:
|
||||||
with self.run('get-token', {'code': code}) as cur:
|
with self.run('get-token', {'code': code}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.Token)
|
||||||
|
|
||||||
|
|
||||||
def put_token(self, username: str) -> Row:
|
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
|
||||||
|
if username is not None:
|
||||||
|
return self.select('tokens').all(schema.Token)
|
||||||
|
|
||||||
|
return self.select('tokens', username = username).all(schema.Token)
|
||||||
|
|
||||||
|
|
||||||
|
def put_token(self, username: str) -> schema.Token:
|
||||||
data = {
|
data = {
|
||||||
'code': uuid4().hex,
|
'code': uuid4().hex,
|
||||||
'user': username,
|
'user': username,
|
||||||
|
@ -247,7 +269,10 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-token', data) as cur:
|
with self.run('put-token', data) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Token)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert token for user: {username}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_token(self, code: str) -> None:
|
def del_token(self, code: str) -> None:
|
||||||
|
@ -255,18 +280,22 @@ class Connection(SqlConnection):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_domain_ban(self, domain: str) -> Row:
|
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
|
||||||
if domain.startswith('http'):
|
if domain.startswith('http'):
|
||||||
domain = urlparse(domain).netloc
|
domain = urlparse(domain).netloc
|
||||||
|
|
||||||
with self.run('get-domain-ban', {'domain': domain}) as cur:
|
with self.run('get-domain-ban', {'domain': domain}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.DomainBan)
|
||||||
|
|
||||||
|
|
||||||
|
def get_domain_bans(self) -> Iterator[schema.DomainBan]:
|
||||||
|
return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan)
|
||||||
|
|
||||||
|
|
||||||
def put_domain_ban(self,
|
def put_domain_ban(self,
|
||||||
domain: str,
|
domain: str,
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
note: str | None = None) -> Row:
|
note: str | None = None) -> schema.DomainBan:
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'domain': domain,
|
'domain': domain,
|
||||||
|
@ -276,13 +305,16 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-domain-ban', params) as cur:
|
with self.run('put-domain-ban', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.DomainBan)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert domain ban: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def update_domain_ban(self,
|
def update_domain_ban(self,
|
||||||
domain: str,
|
domain: str,
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
note: str | None = None) -> Row:
|
note: str | None = None) -> schema.DomainBan:
|
||||||
|
|
||||||
if not (reason or note):
|
if not (reason or note):
|
||||||
raise ValueError('"reason" and/or "note" must be specified')
|
raise ValueError('"reason" and/or "note" must be specified')
|
||||||
|
@ -302,7 +334,10 @@ class Connection(SqlConnection):
|
||||||
if cur.row_count > 1:
|
if cur.row_count > 1:
|
||||||
raise ValueError('More than one row was modified')
|
raise ValueError('More than one row was modified')
|
||||||
|
|
||||||
return self.get_domain_ban(domain)
|
if (row := cur.one(schema.DomainBan)) is None:
|
||||||
|
raise RuntimeError(f"Failed to update domain ban: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_domain_ban(self, domain: str) -> bool:
|
def del_domain_ban(self, domain: str) -> bool:
|
||||||
|
@ -313,15 +348,19 @@ class Connection(SqlConnection):
|
||||||
return cur.row_count == 1
|
return cur.row_count == 1
|
||||||
|
|
||||||
|
|
||||||
def get_software_ban(self, name: str) -> Row:
|
def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
|
||||||
with self.run('get-software-ban', {'name': name}) as cur:
|
with self.run('get-software-ban', {'name': name}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.SoftwareBan)
|
||||||
|
|
||||||
|
|
||||||
|
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
|
||||||
|
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
|
||||||
|
|
||||||
|
|
||||||
def put_software_ban(self,
|
def put_software_ban(self,
|
||||||
name: str,
|
name: str,
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
note: str | None = None) -> Row:
|
note: str | None = None) -> schema.SoftwareBan:
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'name': name,
|
'name': name,
|
||||||
|
@ -331,13 +370,16 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-software-ban', params) as cur:
|
with self.run('put-software-ban', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.SoftwareBan)) is None:
|
||||||
|
raise RuntimeError(f'Failed to insert software ban: {name}')
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def update_software_ban(self,
|
def update_software_ban(self,
|
||||||
name: str,
|
name: str,
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
note: str | None = None) -> Row:
|
note: str | None = None) -> schema.SoftwareBan:
|
||||||
|
|
||||||
if not (reason or note):
|
if not (reason or note):
|
||||||
raise ValueError('"reason" and/or "note" must be specified')
|
raise ValueError('"reason" and/or "note" must be specified')
|
||||||
|
@ -357,7 +399,10 @@ class Connection(SqlConnection):
|
||||||
if cur.row_count > 1:
|
if cur.row_count > 1:
|
||||||
raise ValueError('More than one row was modified')
|
raise ValueError('More than one row was modified')
|
||||||
|
|
||||||
return self.get_software_ban(name)
|
if (row := cur.one(schema.SoftwareBan)) is None:
|
||||||
|
raise RuntimeError(f'Failed to update software ban: {name}')
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_software_ban(self, name: str) -> bool:
|
def del_software_ban(self, name: str) -> bool:
|
||||||
|
@ -368,19 +413,26 @@ class Connection(SqlConnection):
|
||||||
return cur.row_count == 1
|
return cur.row_count == 1
|
||||||
|
|
||||||
|
|
||||||
def get_domain_whitelist(self, domain: str) -> Row:
|
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
|
||||||
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
|
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one()
|
||||||
|
|
||||||
|
|
||||||
def put_domain_whitelist(self, domain: str) -> Row:
|
def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]:
|
||||||
|
return self.execute("SELECT * FROM whitelist").all(schema.Whitelist)
|
||||||
|
|
||||||
|
|
||||||
|
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
|
||||||
params = {
|
params = {
|
||||||
'domain': domain,
|
'domain': domain,
|
||||||
'created': datetime.now(tz = timezone.utc)
|
'created': datetime.now(tz = timezone.utc)
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-domain-whitelist', params) as cur:
|
with self.run('put-domain-whitelist', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Whitelist)) is None:
|
||||||
|
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_domain_whitelist(self, domain: str) -> bool:
|
def del_domain_whitelist(self, domain: str) -> bool:
|
||||||
|
|
|
@ -2,69 +2,90 @@ from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from bsql import Column, Table, Tables
|
from bsql import Column, Row, Tables
|
||||||
|
from collections.abc import Callable
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from .config import ConfigData
|
from .config import ConfigData
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from collections.abc import Callable
|
|
||||||
from .connection import Connection
|
from .connection import Connection
|
||||||
|
|
||||||
|
|
||||||
VERSIONS: dict[int, Callable] = {}
|
VERSIONS: dict[int, Callable[[Connection], None]] = {}
|
||||||
TABLES: Tables = Tables(
|
TABLES = Tables()
|
||||||
Table(
|
|
||||||
'config',
|
|
||||||
Column('key', 'text', primary_key = True, unique = True, nullable = False),
|
|
||||||
Column('value', 'text'),
|
|
||||||
Column('type', 'text', default = 'str')
|
|
||||||
),
|
|
||||||
Table(
|
|
||||||
'inboxes',
|
|
||||||
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
|
|
||||||
Column('actor', 'text', unique = True),
|
|
||||||
Column('inbox', 'text', unique = True, nullable = False),
|
|
||||||
Column('followid', 'text'),
|
|
||||||
Column('software', 'text'),
|
|
||||||
Column('accepted', 'boolean'),
|
|
||||||
Column('created', 'timestamp', nullable = False)
|
|
||||||
),
|
|
||||||
Table(
|
|
||||||
'whitelist',
|
|
||||||
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
|
|
||||||
Column('created', 'timestamp')
|
|
||||||
),
|
|
||||||
Table(
|
|
||||||
'domain_bans',
|
|
||||||
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
|
|
||||||
Column('reason', 'text'),
|
|
||||||
Column('note', 'text'),
|
|
||||||
Column('created', 'timestamp', nullable = False)
|
|
||||||
),
|
|
||||||
Table(
|
|
||||||
'software_bans',
|
|
||||||
Column('name', 'text', primary_key = True, unique = True, nullable = True),
|
|
||||||
Column('reason', 'text'),
|
|
||||||
Column('note', 'text'),
|
|
||||||
Column('created', 'timestamp', nullable = False)
|
|
||||||
),
|
|
||||||
Table(
|
|
||||||
'users',
|
|
||||||
Column('username', 'text', primary_key = True, unique = True, nullable = False),
|
|
||||||
Column('hash', 'text', nullable = False),
|
|
||||||
Column('handle', 'text'),
|
|
||||||
Column('created', 'timestamp', nullable = False)
|
|
||||||
),
|
|
||||||
Table(
|
|
||||||
'tokens',
|
|
||||||
Column('code', 'text', primary_key = True, unique = True, nullable = False),
|
|
||||||
Column('user', 'text', nullable = False),
|
|
||||||
Column('created', 'timestmap', nullable = False)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def migration(func: Callable) -> Callable:
|
@TABLES.add_row
|
||||||
|
class Config(Row):
|
||||||
|
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
|
||||||
|
value: Column[str] = Column('value', 'text')
|
||||||
|
type: Column[str] = Column('type', 'text', default = 'str')
|
||||||
|
|
||||||
|
|
||||||
|
@TABLES.add_row
|
||||||
|
class Instance(Row):
|
||||||
|
table_name: str = 'inboxes'
|
||||||
|
|
||||||
|
domain: Column[str] = Column(
|
||||||
|
'domain', 'text', primary_key = True, unique = True, nullable = False)
|
||||||
|
actor: Column[str] = Column('actor', 'text', unique = True)
|
||||||
|
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
|
||||||
|
followid: Column[str] = Column('followid', 'text')
|
||||||
|
software: Column[str] = Column('software', 'text')
|
||||||
|
accepted: Column[datetime] = Column('accepted', 'boolean')
|
||||||
|
created: Column[datetime] = Column('created', 'timestamp', nullable = False)
|
||||||
|
|
||||||
|
|
||||||
|
@TABLES.add_row
|
||||||
|
class Whitelist(Row):
|
||||||
|
domain: Column[str] = Column(
|
||||||
|
'domain', 'text', primary_key = True, unique = True, nullable = True)
|
||||||
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
|
|
||||||
|
|
||||||
|
@TABLES.add_row
|
||||||
|
class DomainBan(Row):
|
||||||
|
table_name: str = 'domain_bans'
|
||||||
|
|
||||||
|
domain: Column[str] = Column(
|
||||||
|
'domain', 'text', primary_key = True, unique = True, nullable = True)
|
||||||
|
reason: Column[str] = Column('reason', 'text')
|
||||||
|
note: Column[str] = Column('note', 'text')
|
||||||
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
|
|
||||||
|
|
||||||
|
@TABLES.add_row
|
||||||
|
class SoftwareBan(Row):
|
||||||
|
table_name: str = 'software_bans'
|
||||||
|
|
||||||
|
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
|
||||||
|
reason: Column[str] = Column('reason', 'text')
|
||||||
|
note: Column[str] = Column('note', 'text')
|
||||||
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
|
|
||||||
|
|
||||||
|
@TABLES.add_row
|
||||||
|
class User(Row):
|
||||||
|
table_name: str = 'users'
|
||||||
|
|
||||||
|
username: Column[str] = Column(
|
||||||
|
'username', 'text', primary_key = True, unique = True, nullable = False)
|
||||||
|
hash: Column[str] = Column('hash', 'text', nullable = False)
|
||||||
|
handle: Column[str] = Column('handle', 'text')
|
||||||
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
|
|
||||||
|
|
||||||
|
@TABLES.add_row
|
||||||
|
class Token(Row):
|
||||||
|
table_name: str = 'tokens'
|
||||||
|
|
||||||
|
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
|
||||||
|
user: Column[str] = Column('user', 'text', nullable = False)
|
||||||
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
|
|
||||||
|
|
||||||
|
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
|
||||||
ver = int(func.__name__.replace('migrate_', ''))
|
ver = int(func.__name__.replace('migrate_', ''))
|
||||||
VERSIONS[ver] = func
|
VERSIONS[ver] = func
|
||||||
return func
|
return func
|
||||||
|
|
|
@ -11,10 +11,11 @@
|
||||||
%title << {{config.name}}: {{page}}
|
%title << {{config.name}}: {{page}}
|
||||||
%meta(charset="UTF-8")
|
%meta(charset="UTF-8")
|
||||||
%meta(name="viewport" content="width=device-width, initial-scale=1")
|
%meta(name="viewport" content="width=device-width, initial-scale=1")
|
||||||
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme")
|
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme")
|
||||||
%link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}")
|
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}")
|
||||||
%link(rel="manifest" href="/manifest.json")
|
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}")
|
||||||
%script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer)
|
%link(rel="manifest" href="/manifest.json?{{version}}")
|
||||||
|
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer)
|
||||||
-block head
|
-block head
|
||||||
|
|
||||||
%body
|
%body
|
||||||
|
@ -41,7 +42,7 @@
|
||||||
|
|
||||||
#container
|
#container
|
||||||
#header.section
|
#header.section
|
||||||
%span#menu-open << ⁞
|
%span#menu-open -> %i(class="bi bi-list")
|
||||||
%a.title(href="/") -> =config.name
|
%a.title(href="/") -> =config.name
|
||||||
.empty
|
.empty
|
||||||
|
|
||||||
|
|
|
@ -1,29 +1,32 @@
|
||||||
-extends "base.haml"
|
-extends "base.haml"
|
||||||
-set page="Config"
|
-set page="Config"
|
||||||
|
|
||||||
-block head
|
|
||||||
%script(type="application/javascript" src="/static/config.js" nonce="{{view.request['hash']}}" defer)
|
|
||||||
|
|
||||||
-import "functions.haml" as func
|
-import "functions.haml" as func
|
||||||
|
|
||||||
-block content
|
-block content
|
||||||
%fieldset.section
|
%fieldset.section
|
||||||
%legend << Config
|
%legend << Config
|
||||||
|
|
||||||
.grid-2col
|
.grid-2col
|
||||||
%label(for="name") << Name
|
%label(for="name") << Name
|
||||||
|
%i(class="bi bi-question-circle-fill" title="{{desc.name}}")
|
||||||
%input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}")
|
%input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}")
|
||||||
|
|
||||||
%label(for="note") << Description
|
%label(for="note") << Description
|
||||||
|
%i(class="bi bi-question-circle-fill" title="{{desc.note}}")
|
||||||
%textarea(id="note" value="{{config.note or ''}}") << {{config.note}}
|
%textarea(id="note" value="{{config.note or ''}}") << {{config.note}}
|
||||||
|
|
||||||
%label(for="theme") << Color Theme
|
%label(for="theme") << Color Theme
|
||||||
|
%i(class="bi bi-question-circle-fill" title="{{desc.theme}}")
|
||||||
=func.new_select("theme", config.theme, themes)
|
=func.new_select("theme", config.theme, themes)
|
||||||
|
|
||||||
%label(for="log-level") << Log Level
|
%label(for="log-level") << Log Level
|
||||||
|
%i(class="bi bi-question-circle-fill" title="{{desc.log_level}}")
|
||||||
=func.new_select("log-level", config.log_level.name, levels)
|
=func.new_select("log-level", config.log_level.name, levels)
|
||||||
|
|
||||||
%label(for="whitelist-enabled") << Whitelist
|
%label(for="whitelist-enabled") << Whitelist
|
||||||
|
%i(class="bi bi-question-circle-fill" title="{{desc.whitelist_enabled}}")
|
||||||
=func.new_checkbox("whitelist-enabled", config.whitelist_enabled)
|
=func.new_checkbox("whitelist-enabled", config.whitelist_enabled)
|
||||||
|
|
||||||
%label(for="approval-required") << Approval Required
|
%label(for="approval-required") << Approval Required
|
||||||
|
%i(class="bi bi-question-circle-fill" title="{{desc.approval_required}}")
|
||||||
=func.new_checkbox("approval-required", config.approval_required)
|
=func.new_checkbox("approval-required", config.approval_required)
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
-extends "base.haml"
|
-extends "base.haml"
|
||||||
-set page="Domain Bans"
|
-set page="Domain Bans"
|
||||||
|
|
||||||
-block head
|
|
||||||
%script(type="application/javascript" src="/static/domain_ban.js" nonce="{{view.request['hash']}}" defer)
|
|
||||||
|
|
||||||
-block content
|
-block content
|
||||||
%details.section
|
%details.section
|
||||||
%summary << Ban Domain
|
%summary << Ban Domain
|
||||||
|
@ -35,7 +32,7 @@
|
||||||
%tr(id="{{ban.domain}}")
|
%tr(id="{{ban.domain}}")
|
||||||
%td.domain
|
%td.domain
|
||||||
%details
|
%details
|
||||||
%summary -> =ban.domain
|
%summary -> =ban.domain.encode().decode("idna")
|
||||||
|
|
||||||
.grid-2col
|
.grid-2col
|
||||||
%label.reason(for="{{ban.domain}}-reason") << Reason
|
%label.reason(for="{{ban.domain}}-reason") << Reason
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
-extends "base.haml"
|
-extends "base.haml"
|
||||||
-set page="Instances"
|
-set page="Instances"
|
||||||
|
|
||||||
-block head
|
|
||||||
%script(type="application/javascript" src="/static/instance.js" nonce="{{view.request['hash']}}" defer)
|
|
||||||
|
|
||||||
-block content
|
-block content
|
||||||
%details.section
|
%details.section
|
||||||
%summary << Add Instance
|
%summary << Add Instance
|
||||||
|
@ -39,7 +36,7 @@
|
||||||
-for request in requests
|
-for request in requests
|
||||||
%tr(id="{{request.domain}}")
|
%tr(id="{{request.domain}}")
|
||||||
%td.instance
|
%td.instance
|
||||||
%a(href="https://{{request.domain}}" target="_new") -> =request.domain
|
%a(href="https://{{request.domain}}" target="_new") -> =request.domain.encode().decode("idna")
|
||||||
|
|
||||||
%td.software
|
%td.software
|
||||||
=request.software or "n/a"
|
=request.software or "n/a"
|
||||||
|
@ -69,7 +66,7 @@
|
||||||
-for instance in instances
|
-for instance in instances
|
||||||
%tr(id="{{instance.domain}}")
|
%tr(id="{{instance.domain}}")
|
||||||
%td.instance
|
%td.instance
|
||||||
%a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain
|
%a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain.encode().decode("idna")
|
||||||
|
|
||||||
%td.software
|
%td.software
|
||||||
=instance.software or "n/a"
|
=instance.software or "n/a"
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
-extends "base.haml"
|
-extends "base.haml"
|
||||||
-set page="Software Bans"
|
-set page="Software Bans"
|
||||||
|
|
||||||
-block head
|
|
||||||
%script(type="application/javascript" src="/static/software_ban.js" nonce="{{view.request['hash']}}" defer)
|
|
||||||
|
|
||||||
-block content
|
-block content
|
||||||
%details.section
|
%details.section
|
||||||
%summary << Ban Software
|
%summary << Ban Software
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
-extends "base.haml"
|
-extends "base.haml"
|
||||||
-set page="Users"
|
-set page="Users"
|
||||||
|
|
||||||
-block head
|
|
||||||
%script(type="application/javascript" src="/static/user.js" nonce="{{view.request['hash']}}" defer)
|
|
||||||
|
|
||||||
-block content
|
-block content
|
||||||
%details.section
|
%details.section
|
||||||
%summary << Add User
|
%summary << Add User
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
-extends "base.haml"
|
-extends "base.haml"
|
||||||
-set page="Whitelist"
|
-set page="Whitelist"
|
||||||
|
|
||||||
-block head
|
|
||||||
%script(type="application/javascript" src="/static/whitelist.js" nonce="{{view.request['hash']}}" defer)
|
|
||||||
|
|
||||||
-block content
|
-block content
|
||||||
%details.section
|
%details.section
|
||||||
%summary << Add Domain
|
%summary << Add Domain
|
||||||
|
@ -27,7 +24,7 @@
|
||||||
-for item in whitelist
|
-for item in whitelist
|
||||||
%tr(id="{{item.domain}}")
|
%tr(id="{{item.domain}}")
|
||||||
%td.domain
|
%td.domain
|
||||||
=item.domain
|
=item.domain.encode().decode("idna")
|
||||||
|
|
||||||
%td.date
|
%td.date
|
||||||
=item.created.strftime("%Y-%m-%d")
|
=item.created.strftime("%Y-%m-%d")
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
-extends "base.haml"
|
-extends "base.haml"
|
||||||
-set page = "Home"
|
-set page = "Home"
|
||||||
|
|
||||||
-block content
|
-block content
|
||||||
-if config.note
|
-if config.note
|
||||||
.section
|
.section
|
||||||
|
@ -41,7 +42,7 @@
|
||||||
-for instance in instances
|
-for instance in instances
|
||||||
%tr
|
%tr
|
||||||
%td.instance -> %a(href="https://{{instance.domain}}/" target="_new")
|
%td.instance -> %a(href="https://{{instance.domain}}/" target="_new")
|
||||||
=instance.domain
|
=instance.domain.encode().decode("idna")
|
||||||
|
|
||||||
%td.date
|
%td.date
|
||||||
=instance.created.strftime("%Y-%m-%d")
|
=instance.created.strftime("%Y-%m-%d")
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
-extends "base.haml"
|
-extends "base.haml"
|
||||||
-set page="Login"
|
-set page="Login"
|
||||||
|
|
||||||
-block head
|
|
||||||
%script(type="application/javascript" src="/static/login.js" nonce="{{view.request['hash']}}" defer)
|
|
||||||
|
|
||||||
-block content
|
-block content
|
||||||
%fieldset.section
|
%fieldset.section
|
||||||
%legend << Login
|
%legend << Login
|
||||||
|
|
|
@ -1,135 +0,0 @@
|
||||||
// toast notifications
|
|
||||||
|
|
||||||
const notifications = document.querySelector("#notifications")
|
|
||||||
|
|
||||||
|
|
||||||
function remove_toast(toast) {
|
|
||||||
toast.classList.add("hide");
|
|
||||||
|
|
||||||
if (toast.timeoutId) {
|
|
||||||
clearTimeout(toast.timeoutId);
|
|
||||||
}
|
|
||||||
|
|
||||||
setTimeout(() => toast.remove(), 300);
|
|
||||||
}
|
|
||||||
|
|
||||||
function toast(text, type="error", timeout=5) {
|
|
||||||
const toast = document.createElement("li");
|
|
||||||
toast.className = `section ${type}`
|
|
||||||
toast.innerHTML = `<span class=".text">${text}</span><a href="#">✖</span>`
|
|
||||||
|
|
||||||
toast.querySelector("a").addEventListener("click", async (event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
await remove_toast(toast);
|
|
||||||
});
|
|
||||||
|
|
||||||
notifications.appendChild(toast);
|
|
||||||
toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// menu
|
|
||||||
|
|
||||||
const body = document.getElementById("container")
|
|
||||||
const menu = document.getElementById("menu");
|
|
||||||
const menu_open = document.getElementById("menu-open");
|
|
||||||
const menu_close = document.getElementById("menu-close");
|
|
||||||
|
|
||||||
|
|
||||||
menu_open.addEventListener("click", (event) => {
|
|
||||||
var new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
|
|
||||||
menu.attributes.visible.nodeValue = new_value;
|
|
||||||
});
|
|
||||||
|
|
||||||
menu_close.addEventListener("click", (event) => {
|
|
||||||
menu.attributes.visible.nodeValue = "false"
|
|
||||||
});
|
|
||||||
|
|
||||||
body.addEventListener("click", (event) => {
|
|
||||||
if (event.target === menu_open) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
menu.attributes.visible.nodeValue = "false";
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
// misc
|
|
||||||
|
|
||||||
function get_date_string(date) {
|
|
||||||
var year = date.getFullYear().toString();
|
|
||||||
var month = date.getMonth().toString();
|
|
||||||
var day = date.getDay().toString();
|
|
||||||
|
|
||||||
if (month.length === 1) {
|
|
||||||
month = "0" + month;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (day.length === 1) {
|
|
||||||
day = "0" + day
|
|
||||||
}
|
|
||||||
|
|
||||||
return `${year}-${month}-${day}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function append_table_row(table, row_name, row) {
|
|
||||||
var table_row = table.insertRow(-1);
|
|
||||||
table_row.id = row_name;
|
|
||||||
|
|
||||||
index = 0;
|
|
||||||
|
|
||||||
for (var prop in row) {
|
|
||||||
if (Object.prototype.hasOwnProperty.call(row, prop)) {
|
|
||||||
var cell = table_row.insertCell(index);
|
|
||||||
cell.className = prop;
|
|
||||||
cell.innerHTML = row[prop];
|
|
||||||
|
|
||||||
index += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return table_row;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function request(method, path, body = null) {
|
|
||||||
var headers = {
|
|
||||||
"Accept": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
if (body !== null) {
|
|
||||||
headers["Content-Type"] = "application/json"
|
|
||||||
body = JSON.stringify(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = await fetch("/api/" + path, {
|
|
||||||
method: method,
|
|
||||||
mode: "cors",
|
|
||||||
cache: "no-store",
|
|
||||||
redirect: "follow",
|
|
||||||
body: body,
|
|
||||||
headers: headers
|
|
||||||
});
|
|
||||||
|
|
||||||
const message = await response.json();
|
|
||||||
|
|
||||||
if (Object.hasOwn(message, "error")) {
|
|
||||||
throw new Error(message.error);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Array.isArray(message)) {
|
|
||||||
message.forEach((msg) => {
|
|
||||||
if (Object.hasOwn(msg, "created")) {
|
|
||||||
msg.created = new Date(msg.created);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
} else {
|
|
||||||
if (Object.hasOwn(message, "created")) {
|
|
||||||
message.created = new Date(message.created);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return message;
|
|
||||||
}
|
|
2077
relay/frontend/static/bootstrap-icons.css
vendored
Normal file
2077
relay/frontend/static/bootstrap-icons.css
vendored
Normal file
File diff suppressed because it is too large
Load diff
BIN
relay/frontend/static/bootstrap-icons.woff2
Normal file
BIN
relay/frontend/static/bootstrap-icons.woff2
Normal file
Binary file not shown.
|
@ -1,40 +0,0 @@
|
||||||
const elems = [
|
|
||||||
document.querySelector("#name"),
|
|
||||||
document.querySelector("#note"),
|
|
||||||
document.querySelector("#theme"),
|
|
||||||
document.querySelector("#log-level"),
|
|
||||||
document.querySelector("#whitelist-enabled"),
|
|
||||||
document.querySelector("#approval-required")
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async function handle_config_change(event) {
|
|
||||||
params = {
|
|
||||||
key: event.target.id,
|
|
||||||
value: event.target.type === "checkbox" ? event.target.checked : event.target.value
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
await request("POST", "v1/config", params);
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.key === "name") {
|
|
||||||
document.querySelector("#header .title").innerHTML = params.value;
|
|
||||||
document.querySelector("title").innerHTML = params.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.key === "theme") {
|
|
||||||
document.querySelector("link.theme").href = `/theme/${params.value}.css`;
|
|
||||||
}
|
|
||||||
|
|
||||||
toast("Updated config", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
for (const elem of elems) {
|
|
||||||
elem.addEventListener("change", handle_config_change);
|
|
||||||
}
|
|
|
@ -1,123 +0,0 @@
|
||||||
function create_ban_object(domain, reason, note) {
|
|
||||||
var text = '<details>\n';
|
|
||||||
text += `<summary>${domain}</summary>\n`;
|
|
||||||
text += '<div class="grid-2col">\n';
|
|
||||||
text += `<label for="${domain}-reason" class="reason">Reason</label>\n`;
|
|
||||||
text += `<textarea id="${domain}-reason" class="reason">${reason}</textarea>\n`;
|
|
||||||
text += `<label for="${domain}-note" class="note">Note</label>\n`;
|
|
||||||
text += `<textarea id="${domain}-note" class="note">${note}</textarea>\n`;
|
|
||||||
text += `<input class="update-ban" type="button" value="Update">`;
|
|
||||||
text += '</details>';
|
|
||||||
|
|
||||||
return text;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function add_row_listeners(row) {
|
|
||||||
row.querySelector(".update-ban").addEventListener("click", async (event) => {
|
|
||||||
await update_ban(row.id);
|
|
||||||
});
|
|
||||||
|
|
||||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
await unban(row.id);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function ban() {
|
|
||||||
var table = document.querySelector("table");
|
|
||||||
var elems = {
|
|
||||||
domain: document.getElementById("new-domain"),
|
|
||||||
reason: document.getElementById("new-reason"),
|
|
||||||
note: document.getElementById("new-note")
|
|
||||||
}
|
|
||||||
|
|
||||||
var values = {
|
|
||||||
domain: elems.domain.value.trim(),
|
|
||||||
reason: elems.reason.value.trim(),
|
|
||||||
note: elems.note.value.trim()
|
|
||||||
}
|
|
||||||
|
|
||||||
if (values.domain === "") {
|
|
||||||
toast("Domain is required");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
var ban = await request("POST", "v1/domain_ban", values);
|
|
||||||
|
|
||||||
} catch (err) {
|
|
||||||
toast(err);
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var row = append_table_row(document.querySelector("table"), ban.domain, {
|
|
||||||
domain: create_ban_object(ban.domain, ban.reason, ban.note),
|
|
||||||
date: get_date_string(ban.created),
|
|
||||||
remove: `<a href="#" title="Unban domain">✖</a>`
|
|
||||||
});
|
|
||||||
|
|
||||||
add_row_listeners(row);
|
|
||||||
|
|
||||||
elems.domain.value = null;
|
|
||||||
elems.reason.value = null;
|
|
||||||
elems.note.value = null;
|
|
||||||
|
|
||||||
document.querySelector("details.section").open = false;
|
|
||||||
toast("Banned domain", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function update_ban(domain) {
|
|
||||||
var row = document.getElementById(domain);
|
|
||||||
|
|
||||||
var elems = {
|
|
||||||
"reason": row.querySelector("textarea.reason"),
|
|
||||||
"note": row.querySelector("textarea.note")
|
|
||||||
}
|
|
||||||
|
|
||||||
var values = {
|
|
||||||
"domain": domain,
|
|
||||||
"reason": elems.reason.value,
|
|
||||||
"note": elems.note.value
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
await request("PATCH", "v1/domain_ban", values)
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
row.querySelector("details").open = false;
|
|
||||||
toast("Updated baned domain", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function unban(domain) {
|
|
||||||
try {
|
|
||||||
await request("DELETE", "v1/domain_ban", {"domain": domain});
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
document.getElementById(domain).remove();
|
|
||||||
toast("Unbanned domain", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
document.querySelector("#new-ban").addEventListener("click", async (event) => {
|
|
||||||
await ban();
|
|
||||||
});
|
|
||||||
|
|
||||||
for (var row of document.querySelector("fieldset.section table").rows) {
|
|
||||||
if (!row.querySelector(".update-ban")) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
add_row_listeners(row);
|
|
||||||
}
|
|
|
@ -1,145 +0,0 @@
|
||||||
function add_instance_listeners(row) {
|
|
||||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
await del_instance(row.id);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function add_request_listeners(row) {
|
|
||||||
row.querySelector(".approve a").addEventListener("click", async (event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
await req_response(row.id, true);
|
|
||||||
});
|
|
||||||
|
|
||||||
row.querySelector(".deny a").addEventListener("click", async (event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
await req_response(row.id, false);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function add_instance() {
|
|
||||||
var elems = {
|
|
||||||
actor: document.getElementById("new-actor"),
|
|
||||||
inbox: document.getElementById("new-inbox"),
|
|
||||||
followid: document.getElementById("new-followid"),
|
|
||||||
software: document.getElementById("new-software")
|
|
||||||
}
|
|
||||||
|
|
||||||
var values = {
|
|
||||||
actor: elems.actor.value.trim(),
|
|
||||||
inbox: elems.inbox.value.trim(),
|
|
||||||
followid: elems.followid.value.trim(),
|
|
||||||
software: elems.software.value.trim()
|
|
||||||
}
|
|
||||||
|
|
||||||
if (values.actor === "") {
|
|
||||||
toast("Actor is required");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
var instance = await request("POST", "v1/instance", values);
|
|
||||||
|
|
||||||
} catch (err) {
|
|
||||||
toast(err);
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
row = append_table_row(document.getElementById("instances"), instance.domain, {
|
|
||||||
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
|
|
||||||
software: instance.software,
|
|
||||||
date: get_date_string(instance.created),
|
|
||||||
remove: `<a href="#" title="Remove Instance">✖</a>`
|
|
||||||
});
|
|
||||||
|
|
||||||
add_instance_listeners(row);
|
|
||||||
|
|
||||||
elems.actor.value = null;
|
|
||||||
elems.inbox.value = null;
|
|
||||||
elems.followid.value = null;
|
|
||||||
elems.software.value = null;
|
|
||||||
|
|
||||||
document.querySelector("details.section").open = false;
|
|
||||||
toast("Added instance", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function del_instance(domain) {
|
|
||||||
try {
|
|
||||||
await request("DELETE", "v1/instance", {"domain": domain});
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
document.getElementById(domain).remove();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function req_response(domain, accept) {
|
|
||||||
params = {
|
|
||||||
"domain": domain,
|
|
||||||
"accept": accept
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
await request("POST", "v1/request", params);
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
document.getElementById(domain).remove();
|
|
||||||
|
|
||||||
if (document.getElementById("requests").rows.length < 2) {
|
|
||||||
document.querySelector("fieldset.requests").remove()
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!accept) {
|
|
||||||
toast("Denied instance request", "message");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
instances = await request("GET", `v1/instance`, null);
|
|
||||||
instances.forEach((instance) => {
|
|
||||||
if (instance.domain === domain) {
|
|
||||||
row = append_table_row(document.getElementById("instances"), instance.domain, {
|
|
||||||
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
|
|
||||||
software: instance.software,
|
|
||||||
date: get_date_string(instance.created),
|
|
||||||
remove: `<a href="#" title="Remove Instance">✖</a>`
|
|
||||||
});
|
|
||||||
|
|
||||||
add_instance_listeners(row);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
toast("Accepted instance request", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
document.querySelector("#add-instance").addEventListener("click", async (event) => {
|
|
||||||
await add_instance();
|
|
||||||
})
|
|
||||||
|
|
||||||
for (var row of document.querySelector("#instances").rows) {
|
|
||||||
if (!row.querySelector(".remove a")) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
add_instance_listeners(row);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (document.querySelector("#requests")) {
|
|
||||||
for (var row of document.querySelector("#requests").rows) {
|
|
||||||
if (!row.querySelector(".approve a")) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
add_request_listeners(row);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,29 +0,0 @@
|
||||||
async function login(event) {
|
|
||||||
fields = {
|
|
||||||
username: document.querySelector("#username"),
|
|
||||||
password: document.querySelector("#password")
|
|
||||||
}
|
|
||||||
|
|
||||||
values = {
|
|
||||||
username: fields.username.value.trim(),
|
|
||||||
password: fields.password.value.trim()
|
|
||||||
}
|
|
||||||
|
|
||||||
if (values.username === "" | values.password === "") {
|
|
||||||
toast("Username and/or password field is blank");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
await request("POST", "v1/token", values);
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
document.location = "/";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
document.querySelector(".submit").addEventListener("click", login);
|
|
|
@ -1,122 +0,0 @@
|
||||||
function create_ban_object(name, reason, note) {
|
|
||||||
var text = '<details>\n';
|
|
||||||
text += `<summary>${name}</summary>\n`;
|
|
||||||
text += '<div class="grid-2col">\n';
|
|
||||||
text += `<label for="${name}-reason" class="reason">Reason</label>\n`;
|
|
||||||
text += `<textarea id="${name}-reason" class="reason">${reason}</textarea>\n`;
|
|
||||||
text += `<label for="${name}-note" class="note">Note</label>\n`;
|
|
||||||
text += `<textarea id="${name}-note" class="note">${note}</textarea>\n`;
|
|
||||||
text += `<input class="update-ban" type="button" value="Update">`;
|
|
||||||
text += '</details>';
|
|
||||||
|
|
||||||
return text;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function add_row_listeners(row) {
|
|
||||||
row.querySelector(".update-ban").addEventListener("click", async (event) => {
|
|
||||||
await update_ban(row.id);
|
|
||||||
});
|
|
||||||
|
|
||||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
await unban(row.id);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function ban() {
|
|
||||||
var elems = {
|
|
||||||
name: document.getElementById("new-name"),
|
|
||||||
reason: document.getElementById("new-reason"),
|
|
||||||
note: document.getElementById("new-note")
|
|
||||||
}
|
|
||||||
|
|
||||||
var values = {
|
|
||||||
name: elems.name.value.trim(),
|
|
||||||
reason: elems.reason.value,
|
|
||||||
note: elems.note.value
|
|
||||||
}
|
|
||||||
|
|
||||||
if (values.name === "") {
|
|
||||||
toast("Domain is required");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
var ban = await request("POST", "v1/software_ban", values);
|
|
||||||
|
|
||||||
} catch (err) {
|
|
||||||
toast(err);
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var row = append_table_row(document.getElementById("bans"), ban.name, {
|
|
||||||
name: create_ban_object(ban.name, ban.reason, ban.note),
|
|
||||||
date: get_date_string(ban.created),
|
|
||||||
remove: `<a href="#" title="Unban software">✖</a>`
|
|
||||||
});
|
|
||||||
|
|
||||||
add_row_listeners(row);
|
|
||||||
|
|
||||||
elems.name.value = null;
|
|
||||||
elems.reason.value = null;
|
|
||||||
elems.note.value = null;
|
|
||||||
|
|
||||||
document.querySelector("details.section").open = false;
|
|
||||||
toast("Banned software", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function update_ban(name) {
|
|
||||||
var row = document.getElementById(name);
|
|
||||||
|
|
||||||
var elems = {
|
|
||||||
"reason": row.querySelector("textarea.reason"),
|
|
||||||
"note": row.querySelector("textarea.note")
|
|
||||||
}
|
|
||||||
|
|
||||||
var values = {
|
|
||||||
"name": name,
|
|
||||||
"reason": elems.reason.value,
|
|
||||||
"note": elems.note.value
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
await request("PATCH", "v1/software_ban", values)
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
row.querySelector("details").open = false;
|
|
||||||
toast("Updated software ban", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function unban(name) {
|
|
||||||
try {
|
|
||||||
await request("DELETE", "v1/software_ban", {"name": name});
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
document.getElementById(name).remove();
|
|
||||||
toast("Unbanned software", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
document.querySelector("#new-ban").addEventListener("click", async (event) => {
|
|
||||||
await ban();
|
|
||||||
});
|
|
||||||
|
|
||||||
for (var row of document.querySelector("#bans").rows) {
|
|
||||||
if (!row.querySelector(".update-ban")) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
add_row_listeners(row);
|
|
||||||
}
|
|
|
@ -155,6 +155,7 @@ textarea {
|
||||||
z-index: 1;
|
z-index: 1;
|
||||||
font-size: 1.5em;
|
font-size: 1.5em;
|
||||||
min-width: 300px;
|
min-width: 300px;
|
||||||
|
overflow-x: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
#menu[visible="false"] {
|
#menu[visible="false"] {
|
||||||
|
@ -188,11 +189,17 @@ textarea {
|
||||||
}
|
}
|
||||||
|
|
||||||
#menu-open {
|
#menu-open {
|
||||||
color: var(--primary);
|
color: var(--background);
|
||||||
|
background: var(--primary);
|
||||||
|
font-size: 38px;
|
||||||
|
line-height: 38px;
|
||||||
|
border: 1px solid var(--primary);
|
||||||
|
border-radius: 5px;
|
||||||
}
|
}
|
||||||
|
|
||||||
#menu-open:hover {
|
#menu-open:hover {
|
||||||
color: var(--primary-hover);
|
color: var(--primary);
|
||||||
|
background: var(--background);
|
||||||
}
|
}
|
||||||
|
|
||||||
#menu-open, #menu-close {
|
#menu-open, #menu-close {
|
||||||
|
@ -290,13 +297,13 @@ textarea {
|
||||||
border: 1px solid var(--error-border) !important;
|
border: 1px solid var(--error-border) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* create .grid base class and .2col and 3col classes */
|
||||||
.grid-2col {
|
.grid-2col {
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-columns: max-content auto;
|
grid-template-columns: max-content auto;
|
||||||
grid-gap: var(--spacing);
|
grid-gap: var(--spacing);
|
||||||
margin-bottom: var(--spacing);
|
margin-bottom: var(--spacing);
|
||||||
align-items: center;
|
align-items: center;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
|
@ -326,6 +333,10 @@ textarea {
|
||||||
justify-self: left;
|
justify-self: left;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#content.page-config .grid-2col {
|
||||||
|
grid-template-columns: max-content max-content auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@keyframes show_toast {
|
@keyframes show_toast {
|
||||||
0% {
|
0% {
|
||||||
|
|
|
@ -1,85 +0,0 @@
|
||||||
function add_row_listeners(row) {
|
|
||||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
await del_user(row.id);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function add_user() {
|
|
||||||
var elems = {
|
|
||||||
username: document.getElementById("new-username"),
|
|
||||||
password: document.getElementById("new-password"),
|
|
||||||
password2: document.getElementById("new-password2"),
|
|
||||||
handle: document.getElementById("new-handle")
|
|
||||||
}
|
|
||||||
|
|
||||||
var values = {
|
|
||||||
username: elems.username.value.trim(),
|
|
||||||
password: elems.password.value.trim(),
|
|
||||||
password2: elems.password2.value.trim(),
|
|
||||||
handle: elems.handle.value.trim()
|
|
||||||
}
|
|
||||||
|
|
||||||
if (values.username === "" | values.password === "" | values.password2 === "") {
|
|
||||||
toast("Username, password, and password2 are required");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (values.password !== values.password2) {
|
|
||||||
toast("Passwords do not match");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
var user = await request("POST", "v1/user", values);
|
|
||||||
|
|
||||||
} catch (err) {
|
|
||||||
toast(err);
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
|
|
||||||
domain: user.username,
|
|
||||||
handle: user.handle ? self.handle : "n/a",
|
|
||||||
date: get_date_string(user.created),
|
|
||||||
remove: `<a href="#" title="Delete User">✖</a>`
|
|
||||||
});
|
|
||||||
|
|
||||||
add_row_listeners(row);
|
|
||||||
|
|
||||||
elems.username.value = null;
|
|
||||||
elems.password.value = null;
|
|
||||||
elems.password2.value = null;
|
|
||||||
elems.handle.value = null;
|
|
||||||
|
|
||||||
document.querySelector("details.section").open = false;
|
|
||||||
toast("Created user", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function del_user(username) {
|
|
||||||
try {
|
|
||||||
await request("DELETE", "v1/user", {"username": username});
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
document.getElementById(username).remove();
|
|
||||||
toast("Deleted user", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
document.querySelector("#new-user").addEventListener("click", async (event) => {
|
|
||||||
await add_user();
|
|
||||||
});
|
|
||||||
|
|
||||||
for (var row of document.querySelector("#users").rows) {
|
|
||||||
if (!row.querySelector(".remove a")) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
add_row_listeners(row);
|
|
||||||
}
|
|
|
@ -1,64 +0,0 @@
|
||||||
function add_row_listeners(row) {
|
|
||||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
await del_whitelist(row.id);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function add_whitelist() {
|
|
||||||
var domain_elem = document.getElementById("new-domain");
|
|
||||||
var domain = domain_elem.value.trim();
|
|
||||||
|
|
||||||
if (domain === "") {
|
|
||||||
toast("Domain is required");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
var item = await request("POST", "v1/whitelist", {"domain": domain});
|
|
||||||
|
|
||||||
} catch (err) {
|
|
||||||
toast(err);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
var row = append_table_row(document.getElementById("whitelist"), item.domain, {
|
|
||||||
domain: item.domain,
|
|
||||||
date: get_date_string(item.created),
|
|
||||||
remove: `<a href="#" title="Remove whitelisted domain">✖</a>`
|
|
||||||
});
|
|
||||||
|
|
||||||
add_row_listeners(row);
|
|
||||||
|
|
||||||
domain_elem.value = null;
|
|
||||||
document.querySelector("details.section").open = false;
|
|
||||||
toast("Added domain", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async function del_whitelist(domain) {
|
|
||||||
try {
|
|
||||||
await request("DELETE", "v1/whitelist", {"domain": domain});
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
toast(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
document.getElementById(domain).remove();
|
|
||||||
toast("Removed domain", "message");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
document.querySelector("#new-item").addEventListener("click", async (event) => {
|
|
||||||
await add_whitelist();
|
|
||||||
});
|
|
||||||
|
|
||||||
for (var row of document.querySelector("fieldset.section table").rows) {
|
|
||||||
if (!row.querySelector(".remove a")) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
add_row_listeners(row);
|
|
||||||
}
|
|
|
@ -1,29 +1,37 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import traceback
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
|
||||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
from blib import JsonBase
|
||||||
from aputils import AlgorithmType, JsonBase, Nodeinfo, ObjectType, WellKnownNodeinfo
|
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||||
from json.decoder import JSONDecodeError
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__, logger as logging
|
||||||
from . import logger as logging
|
from .cache import Cache
|
||||||
|
from .database.schema import Instance
|
||||||
from .misc import MIMETYPES, Message, get_app
|
from .misc import MIMETYPES, Message, get_app
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aputils import Signer
|
|
||||||
from bsql import Row
|
|
||||||
from typing import Any
|
|
||||||
from .application import Application
|
from .application import Application
|
||||||
from .cache import Cache
|
|
||||||
|
|
||||||
|
|
||||||
T = typing.TypeVar('T', bound = JsonBase)
|
SUPPORTS_HS2019 = {
|
||||||
|
'friendica',
|
||||||
|
'gotosocial',
|
||||||
|
'hubzilla'
|
||||||
|
'mastodon',
|
||||||
|
'socialhome',
|
||||||
|
'misskey',
|
||||||
|
'catodon',
|
||||||
|
'cherrypick',
|
||||||
|
'firefish',
|
||||||
|
'foundkey',
|
||||||
|
'iceshrimp',
|
||||||
|
'sharkey'
|
||||||
|
}
|
||||||
|
|
||||||
|
T = TypeVar('T', bound = JsonBase)
|
||||||
HEADERS = {
|
HEADERS = {
|
||||||
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
|
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
|
||||||
'User-Agent': f'ActivityRelay/{__version__}'
|
'User-Agent': f'ActivityRelay/{__version__}'
|
||||||
|
@ -90,7 +98,12 @@ class HttpClient:
|
||||||
self._session = None
|
self._session = None
|
||||||
|
|
||||||
|
|
||||||
async def _get(self, url: str, sign_headers: bool, force: bool) -> dict[str, Any] | None:
|
async def _get(self,
|
||||||
|
url: str,
|
||||||
|
sign_headers: bool,
|
||||||
|
force: bool,
|
||||||
|
old_algo: bool) -> str | None:
|
||||||
|
|
||||||
if not self._session:
|
if not self._session:
|
||||||
raise RuntimeError('Client not open')
|
raise RuntimeError('Client not open')
|
||||||
|
|
||||||
|
@ -103,7 +116,7 @@ class HttpClient:
|
||||||
if not force:
|
if not force:
|
||||||
try:
|
try:
|
||||||
if not (item := self.cache.get('request', url)).older_than(48):
|
if not (item := self.cache.get('request', url)).older_than(48):
|
||||||
return json.loads(item.value)
|
return item.value # type: ignore [no-any-return]
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logging.verbose('No cached data for url: %s', url)
|
logging.verbose('No cached data for url: %s', url)
|
||||||
|
@ -111,67 +124,72 @@ class HttpClient:
|
||||||
headers = {}
|
headers = {}
|
||||||
|
|
||||||
if sign_headers:
|
if sign_headers:
|
||||||
headers = self.signer.sign_headers('GET', url, algorithm = AlgorithmType.HS2019)
|
algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
|
||||||
|
headers = self.signer.sign_headers('GET', url, algorithm = algo)
|
||||||
|
|
||||||
try:
|
logging.debug('Fetching resource: %s', url)
|
||||||
logging.debug('Fetching resource: %s', url)
|
|
||||||
|
|
||||||
async with self._session.get(url, headers = headers) as resp:
|
async with self._session.get(url, headers = headers) as resp:
|
||||||
# Not expecting a response with 202s, so just return
|
# Not expecting a response with 202s, so just return
|
||||||
if resp.status == 202:
|
if resp.status == 202:
|
||||||
return None
|
|
||||||
|
|
||||||
data = await resp.text()
|
|
||||||
|
|
||||||
if resp.status != 200:
|
|
||||||
logging.verbose('Received error when requesting %s: %i', url, resp.status)
|
|
||||||
logging.debug(data)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
self.cache.set('request', url, data, 'str')
|
data = await resp.text()
|
||||||
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
|
|
||||||
|
|
||||||
return json.loads(data)
|
if resp.status != 200:
|
||||||
|
logging.verbose('Received error when requesting %s: %i', url, resp.status)
|
||||||
except JSONDecodeError:
|
logging.debug(data)
|
||||||
logging.verbose('Failed to parse JSON')
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except ClientSSLError as e:
|
self.cache.set('request', url, data, 'str')
|
||||||
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
|
return data
|
||||||
logging.warning(str(e))
|
|
||||||
|
|
||||||
except (AsyncTimeoutError, ClientConnectionError) as e:
|
|
||||||
logging.verbose('Failed to connect to %s', urlparse(url).netloc)
|
|
||||||
logging.warning(str(e))
|
|
||||||
|
|
||||||
except Exception:
|
@overload
|
||||||
traceback.print_exc()
|
async def get(self, # type: ignore[overload-overlap]
|
||||||
|
url: str,
|
||||||
|
sign_headers: bool,
|
||||||
|
cls: None = None,
|
||||||
|
force: bool = False,
|
||||||
|
old_algo: bool = True) -> None: ...
|
||||||
|
|
||||||
return None
|
|
||||||
|
@overload
|
||||||
|
async def get(self,
|
||||||
|
url: str,
|
||||||
|
sign_headers: bool,
|
||||||
|
cls: type[T] = JsonBase, # type: ignore[assignment]
|
||||||
|
force: bool = False,
|
||||||
|
old_algo: bool = True) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
async def get(self,
|
async def get(self,
|
||||||
url: str,
|
url: str,
|
||||||
sign_headers: bool,
|
sign_headers: bool,
|
||||||
cls: type[T],
|
cls: type[T] | None = None,
|
||||||
force: bool = False) -> T | None:
|
force: bool = False,
|
||||||
|
old_algo: bool = True) -> T | None:
|
||||||
|
|
||||||
if not issubclass(cls, JsonBase):
|
if cls is not None and not issubclass(cls, JsonBase):
|
||||||
raise TypeError('cls must be a sub-class of "aputils.JsonBase"')
|
raise TypeError('cls must be a sub-class of "blib.JsonBase"')
|
||||||
|
|
||||||
if (data := (await self._get(url, sign_headers, force))) is None:
|
data = await self._get(url, sign_headers, force, old_algo)
|
||||||
return None
|
|
||||||
|
|
||||||
return cls.parse(data)
|
if cls is not None:
|
||||||
|
if data is None:
|
||||||
|
raise ValueError("Empty response")
|
||||||
|
|
||||||
|
return cls.parse(data)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
|
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
|
||||||
if not self._session:
|
if not self._session:
|
||||||
raise RuntimeError('Client not open')
|
raise RuntimeError('Client not open')
|
||||||
|
|
||||||
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
|
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
|
||||||
if instance and instance['software'] in {'mastodon'}:
|
if instance is not None and instance.software in SUPPORTS_HS2019:
|
||||||
algorithm = AlgorithmType.HS2019
|
algorithm = AlgorithmType.HS2019
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -197,35 +215,22 @@ class HttpClient:
|
||||||
algorithm = algorithm
|
algorithm = algorithm
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
logging.verbose('Sending "%s" to %s', mtype, url)
|
||||||
logging.verbose('Sending "%s" to %s', mtype, url)
|
|
||||||
|
|
||||||
async with self._session.post(url, headers = headers, data = body) as resp:
|
async with self._session.post(url, headers = headers, data = body) as resp:
|
||||||
# Not expecting a response, so just return
|
# Not expecting a response, so just return
|
||||||
if resp.status in {200, 202}:
|
if resp.status in {200, 202}:
|
||||||
logging.verbose('Successfully sent "%s" to %s', mtype, url)
|
logging.verbose('Successfully sent "%s" to %s', mtype, url)
|
||||||
return
|
|
||||||
|
|
||||||
logging.verbose('Received error when pushing to %s: %i', url, resp.status)
|
|
||||||
logging.debug(await resp.read())
|
|
||||||
logging.debug("message: %s", body.decode("utf-8"))
|
|
||||||
logging.debug("headers: %s", json.dumps(headers, indent = 4))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
except ClientSSLError as e:
|
logging.error('Received error when pushing to %s: %i', url, resp.status)
|
||||||
logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
|
logging.debug(await resp.read())
|
||||||
logging.warning(str(e))
|
logging.debug("message: %s", body.decode("utf-8"))
|
||||||
|
logging.debug("headers: %s", json.dumps(headers, indent = 4))
|
||||||
except (AsyncTimeoutError, ClientConnectionError) as e:
|
return
|
||||||
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
|
|
||||||
logging.warning(str(e))
|
|
||||||
|
|
||||||
# prevent workers from being brought down
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None:
|
async def fetch_nodeinfo(self, domain: str) -> Nodeinfo:
|
||||||
nodeinfo_url = None
|
nodeinfo_url = None
|
||||||
wk_nodeinfo = await self.get(
|
wk_nodeinfo = await self.get(
|
||||||
f'https://{domain}/.well-known/nodeinfo',
|
f'https://{domain}/.well-known/nodeinfo',
|
||||||
|
@ -233,10 +238,6 @@ class HttpClient:
|
||||||
WellKnownNodeinfo
|
WellKnownNodeinfo
|
||||||
)
|
)
|
||||||
|
|
||||||
if wk_nodeinfo is None:
|
|
||||||
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
|
|
||||||
return None
|
|
||||||
|
|
||||||
for version in ('20', '21'):
|
for version in ('20', '21'):
|
||||||
try:
|
try:
|
||||||
nodeinfo_url = wk_nodeinfo.get_url(version)
|
nodeinfo_url = wk_nodeinfo.get_url(version)
|
||||||
|
@ -245,8 +246,7 @@ class HttpClient:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if nodeinfo_url is None:
|
if nodeinfo_url is None:
|
||||||
logging.verbose('Failed to fetch nodeinfo url for %s', domain)
|
raise ValueError(f'Failed to fetch nodeinfo url for {domain}')
|
||||||
return None
|
|
||||||
|
|
||||||
return await self.get(nodeinfo_url, False, Nodeinfo)
|
return await self.get(nodeinfo_url, False, Nodeinfo)
|
||||||
|
|
||||||
|
|
|
@ -2,15 +2,12 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import typing
|
|
||||||
|
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, Protocol
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
@ -18,6 +15,10 @@ if typing.TYPE_CHECKING:
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingMethod(Protocol):
|
||||||
|
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class LogLevel(IntEnum):
|
class LogLevel(IntEnum):
|
||||||
DEBUG = logging.DEBUG
|
DEBUG = logging.DEBUG
|
||||||
VERBOSE = 15
|
VERBOSE = 15
|
||||||
|
@ -75,11 +76,11 @@ def verbose(message: str, *args: Any, **kwargs: Any) -> None:
|
||||||
logging.log(LogLevel.VERBOSE, message, *args, **kwargs)
|
logging.log(LogLevel.VERBOSE, message, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
debug: Callable = logging.debug
|
debug: LoggingMethod = logging.debug
|
||||||
info: Callable = logging.info
|
info: LoggingMethod = logging.info
|
||||||
warning: Callable = logging.warning
|
warning: LoggingMethod = logging.warning
|
||||||
error: Callable = logging.error
|
error: LoggingMethod = logging.error
|
||||||
critical: Callable = logging.critical
|
critical: LoggingMethod = logging.critical
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
136
relay/manage.py
136
relay/manage.py
|
@ -5,10 +5,10 @@ import asyncio
|
||||||
import click
|
import click
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import typing
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
|
@ -16,13 +16,9 @@ from . import http_client as http
|
||||||
from . import logger as logging
|
from . import logger as logging
|
||||||
from .application import Application
|
from .application import Application
|
||||||
from .compat import RelayConfig, RelayDatabase
|
from .compat import RelayConfig, RelayDatabase
|
||||||
from .database import RELAY_SOFTWARE, get_database
|
from .database import RELAY_SOFTWARE, get_database, schema
|
||||||
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
|
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from bsql import Row
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def check_alphanumeric(text: str) -> str:
|
def check_alphanumeric(text: str) -> str:
|
||||||
if not text.isalnum():
|
if not text.isalnum():
|
||||||
|
@ -370,8 +366,8 @@ def cli_user_list(ctx: click.Context) -> None:
|
||||||
click.echo('Users:')
|
click.echo('Users:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for user in conn.execute('SELECT * FROM users'):
|
for row in conn.get_users():
|
||||||
click.echo(f'- {user["username"]}')
|
click.echo(f'- {row.username}')
|
||||||
|
|
||||||
|
|
||||||
@cli_user.command('create')
|
@cli_user.command('create')
|
||||||
|
@ -382,7 +378,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
|
||||||
'Create a new local user'
|
'Create a new local user'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if conn.get_user(username):
|
if conn.get_user(username) is not None:
|
||||||
click.echo(f'User already exists: {username}')
|
click.echo(f'User already exists: {username}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -409,7 +405,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
|
||||||
'Delete a local user'
|
'Delete a local user'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if not conn.get_user(username):
|
if conn.get_user(username) is None:
|
||||||
click.echo(f'User does not exist: {username}')
|
click.echo(f'User does not exist: {username}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -427,8 +423,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
|
||||||
click.echo(f'Tokens for "{username}":')
|
click.echo(f'Tokens for "{username}":')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
|
for row in conn.get_tokens(username):
|
||||||
click.echo(f'- {token["code"]}')
|
click.echo(f'- {row.code}')
|
||||||
|
|
||||||
|
|
||||||
@cli_user.command('create-token')
|
@cli_user.command('create-token')
|
||||||
|
@ -438,13 +434,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
|
||||||
'Create a new API token for a user'
|
'Create a new API token for a user'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if not (user := conn.get_user(username)):
|
if (user := conn.get_user(username)) is None:
|
||||||
click.echo(f'User does not exist: {username}')
|
click.echo(f'User does not exist: {username}')
|
||||||
return
|
return
|
||||||
|
|
||||||
token = conn.put_token(user['username'])
|
token = conn.put_token(user.username)
|
||||||
|
|
||||||
click.echo(f'New token for "{username}": {token["code"]}')
|
click.echo(f'New token for "{username}": {token.code}')
|
||||||
|
|
||||||
|
|
||||||
@cli_user.command('delete-token')
|
@cli_user.command('delete-token')
|
||||||
|
@ -454,7 +450,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
|
||||||
'Delete an API token'
|
'Delete an API token'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if not conn.get_token(code):
|
if conn.get_token(code) is None:
|
||||||
click.echo('Token does not exist')
|
click.echo('Token does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -476,8 +472,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
|
||||||
click.echo('Connected to the following instances or relays:')
|
click.echo('Connected to the following instances or relays:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for inbox in conn.get_inboxes():
|
for row in conn.get_inboxes():
|
||||||
click.echo(f'- {inbox["inbox"]}')
|
click.echo(f'- {row.inbox}')
|
||||||
|
|
||||||
|
|
||||||
@cli_inbox.command('follow')
|
@cli_inbox.command('follow')
|
||||||
|
@ -486,19 +482,21 @@ def cli_inbox_list(ctx: click.Context) -> None:
|
||||||
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||||
'Follow an actor (Relay must be running)'
|
'Follow an actor (Relay must be running)'
|
||||||
|
|
||||||
|
instance: schema.Instance | None = None
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if conn.get_domain_ban(actor):
|
if conn.get_domain_ban(actor):
|
||||||
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
||||||
return
|
return
|
||||||
|
|
||||||
if (inbox_data := conn.get_inbox(actor)):
|
if (instance := conn.get_inbox(actor)) is not None:
|
||||||
inbox = inbox_data['inbox']
|
inbox = instance.inbox
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if not actor.startswith('http'):
|
if not actor.startswith('http'):
|
||||||
actor = f'https://{actor}/actor'
|
actor = f'https://{actor}/actor'
|
||||||
|
|
||||||
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
|
if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None:
|
||||||
click.echo(f'Failed to fetch actor: {actor}')
|
click.echo(f'Failed to fetch actor: {actor}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -509,7 +507,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||||
actor = actor
|
actor = actor
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(inbox, message, inbox_data))
|
asyncio.run(http.post(inbox, message, instance))
|
||||||
click.echo(f'Sent follow message to actor: {actor}')
|
click.echo(f'Sent follow message to actor: {actor}')
|
||||||
|
|
||||||
|
|
||||||
|
@ -519,19 +517,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||||
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
||||||
'Unfollow an actor (Relay must be running)'
|
'Unfollow an actor (Relay must be running)'
|
||||||
|
|
||||||
inbox_data: Row | None = None
|
instance: schema.Instance | None = None
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if conn.get_domain_ban(actor):
|
if conn.get_domain_ban(actor):
|
||||||
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
||||||
return
|
return
|
||||||
|
|
||||||
if (inbox_data := conn.get_inbox(actor)):
|
if (instance := conn.get_inbox(actor)):
|
||||||
inbox = inbox_data['inbox']
|
inbox = instance.inbox
|
||||||
message = Message.new_unfollow(
|
message = Message.new_unfollow(
|
||||||
host = ctx.obj.config.domain,
|
host = ctx.obj.config.domain,
|
||||||
actor = actor,
|
actor = actor,
|
||||||
follow = inbox_data['followid']
|
follow = instance.followid
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -555,7 +553,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(inbox, message, inbox_data))
|
asyncio.run(http.post(inbox, message, instance))
|
||||||
click.echo(f'Sent unfollow message to: {actor}')
|
click.echo(f'Sent unfollow message to: {actor}')
|
||||||
|
|
||||||
|
|
||||||
|
@ -635,9 +633,9 @@ def cli_request_list(ctx: click.Context) -> None:
|
||||||
click.echo('Follow requests:')
|
click.echo('Follow requests:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for instance in conn.get_requests():
|
for row in conn.get_requests():
|
||||||
date = instance['created'].strftime('%Y-%m-%d')
|
date = row.created.strftime('%Y-%m-%d')
|
||||||
click.echo(f'- [{date}] {instance["domain"]}')
|
click.echo(f'- [{date}] {row.domain}')
|
||||||
|
|
||||||
|
|
||||||
@cli_request.command('accept')
|
@cli_request.command('accept')
|
||||||
|
@ -656,20 +654,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
|
||||||
|
|
||||||
message = Message.new_response(
|
message = Message.new_response(
|
||||||
host = ctx.obj.config.domain,
|
host = ctx.obj.config.domain,
|
||||||
actor = instance['actor'],
|
actor = instance.actor,
|
||||||
followid = instance['followid'],
|
followid = instance.followid,
|
||||||
accept = True
|
accept = True
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(instance['inbox'], message, instance))
|
asyncio.run(http.post(instance.inbox, message, instance))
|
||||||
|
|
||||||
if instance['software'] != 'mastodon':
|
if instance.software != 'mastodon':
|
||||||
message = Message.new_follow(
|
message = Message.new_follow(
|
||||||
host = ctx.obj.config.domain,
|
host = ctx.obj.config.domain,
|
||||||
actor = instance['actor']
|
actor = instance.actor
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(instance['inbox'], message, instance))
|
asyncio.run(http.post(instance.inbox, message, instance))
|
||||||
|
|
||||||
|
|
||||||
@cli_request.command('deny')
|
@cli_request.command('deny')
|
||||||
|
@ -688,12 +686,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
|
||||||
|
|
||||||
response = Message.new_response(
|
response = Message.new_response(
|
||||||
host = ctx.obj.config.domain,
|
host = ctx.obj.config.domain,
|
||||||
actor = instance['actor'],
|
actor = instance.actor,
|
||||||
followid = instance['followid'],
|
followid = instance.followid,
|
||||||
accept = False
|
accept = False
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(instance['inbox'], response, instance))
|
asyncio.run(http.post(instance.inbox, response, instance))
|
||||||
|
|
||||||
|
|
||||||
@cli.group('instance')
|
@cli.group('instance')
|
||||||
|
@ -709,12 +707,12 @@ def cli_instance_list(ctx: click.Context) -> None:
|
||||||
click.echo('Banned domains:')
|
click.echo('Banned domains:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for instance in conn.execute('SELECT * FROM domain_bans'):
|
for row in conn.get_domain_bans():
|
||||||
if instance['reason']:
|
if row.reason is not None:
|
||||||
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
|
click.echo(f'- {row.domain} ({row.reason})')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(f'- {instance["domain"]}')
|
click.echo(f'- {row.domain}')
|
||||||
|
|
||||||
|
|
||||||
@cli_instance.command('ban')
|
@cli_instance.command('ban')
|
||||||
|
@ -726,7 +724,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
|
||||||
'Ban an instance and remove the associated inbox if it exists'
|
'Ban an instance and remove the associated inbox if it exists'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if conn.get_domain_ban(domain):
|
if conn.get_domain_ban(domain) is not None:
|
||||||
click.echo(f'Domain already banned: {domain}')
|
click.echo(f'Domain already banned: {domain}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -742,7 +740,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
|
||||||
'Unban an instance'
|
'Unban an instance'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if not conn.del_domain_ban(domain):
|
if conn.del_domain_ban(domain) is None:
|
||||||
click.echo(f'Instance wasn\'t banned: {domain}')
|
click.echo(f'Instance wasn\'t banned: {domain}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -767,11 +765,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
|
||||||
|
|
||||||
click.echo(f'Updated domain ban: {domain}')
|
click.echo(f'Updated domain ban: {domain}')
|
||||||
|
|
||||||
if row['reason']:
|
if row.reason:
|
||||||
click.echo(f'- {row["domain"]} ({row["reason"]})')
|
click.echo(f'- {row.domain} ({row.reason})')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(f'- {row["domain"]}')
|
click.echo(f'- {row.domain}')
|
||||||
|
|
||||||
|
|
||||||
@cli.group('software')
|
@cli.group('software')
|
||||||
|
@ -787,12 +785,12 @@ def cli_software_list(ctx: click.Context) -> None:
|
||||||
click.echo('Banned software:')
|
click.echo('Banned software:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for software in conn.execute('SELECT * FROM software_bans'):
|
for row in conn.get_software_bans():
|
||||||
if software['reason']:
|
if row.reason:
|
||||||
click.echo(f'- {software["name"]} ({software["reason"]})')
|
click.echo(f'- {row.name} ({row.reason})')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(f'- {software["name"]}')
|
click.echo(f'- {row.name}')
|
||||||
|
|
||||||
|
|
||||||
@cli_software.command('ban')
|
@cli_software.command('ban')
|
||||||
|
@ -814,12 +812,12 @@ def cli_software_ban(ctx: click.Context,
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if name == 'RELAYS':
|
if name == 'RELAYS':
|
||||||
for software in RELAY_SOFTWARE:
|
for item in RELAY_SOFTWARE:
|
||||||
if conn.get_software_ban(software):
|
if conn.get_software_ban(item):
|
||||||
click.echo(f'Relay already banned: {software}')
|
click.echo(f'Relay already banned: {item}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
conn.put_software_ban(software, reason or 'relay', note)
|
conn.put_software_ban(item, reason or 'relay', note)
|
||||||
|
|
||||||
click.echo('Banned all relay software')
|
click.echo('Banned all relay software')
|
||||||
return
|
return
|
||||||
|
@ -896,11 +894,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
|
||||||
|
|
||||||
click.echo(f'Updated software ban: {name}')
|
click.echo(f'Updated software ban: {name}')
|
||||||
|
|
||||||
if row['reason']:
|
if row.reason:
|
||||||
click.echo(f'- {row["name"]} ({row["reason"]})')
|
click.echo(f'- {row.name} ({row.reason})')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(f'- {row["name"]}')
|
click.echo(f'- {row.name}')
|
||||||
|
|
||||||
|
|
||||||
@cli.group('whitelist')
|
@cli.group('whitelist')
|
||||||
|
@ -916,8 +914,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
|
||||||
click.echo('Current whitelisted domains:')
|
click.echo('Current whitelisted domains:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for domain in conn.execute('SELECT * FROM whitelist'):
|
for row in conn.get_domain_whitelist():
|
||||||
click.echo(f'- {domain["domain"]}')
|
click.echo(f'- {row.domain}')
|
||||||
|
|
||||||
|
|
||||||
@cli_whitelist.command('add')
|
@cli_whitelist.command('add')
|
||||||
|
@ -956,23 +954,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
|
||||||
@cli_whitelist.command('import')
|
@cli_whitelist.command('import')
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli_whitelist_import(ctx: click.Context) -> None:
|
def cli_whitelist_import(ctx: click.Context) -> None:
|
||||||
'Add all current inboxes to the whitelist'
|
'Add all current instances to the whitelist'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for inbox in conn.execute('SELECT * FROM inboxes').all():
|
for row in conn.get_inboxes():
|
||||||
if conn.get_domain_whitelist(inbox['domain']):
|
if conn.get_domain_whitelist(row.domain) is not None:
|
||||||
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
|
click.echo(f'Domain already in whitelist: {row.domain}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
conn.put_domain_whitelist(inbox['domain'])
|
conn.put_domain_whitelist(row.domain)
|
||||||
|
|
||||||
click.echo('Imported whitelist from inboxes')
|
click.echo('Imported whitelist from inboxes')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
cli(prog_name='relay')
|
cli(prog_name='activityrelay')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')
|
|
||||||
|
|
|
@ -3,12 +3,14 @@ from __future__ import annotations
|
||||||
import aputils
|
import aputils
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import socket
|
import socket
|
||||||
import typing
|
|
||||||
|
|
||||||
from aiohttp.web import Response as AiohttpResponse
|
from aiohttp.web import Response as AiohttpResponse
|
||||||
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -17,8 +19,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from importlib_resources import files as pkgfiles # type: ignore
|
from importlib_resources import files as pkgfiles # type: ignore
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any
|
|
||||||
from .application import Application
|
from .application import Application
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -28,16 +29,17 @@ if typing.TYPE_CHECKING:
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
T = typing.TypeVar('T')
|
T = TypeVar('T')
|
||||||
ResponseType = typing.TypedDict('ResponseType', {
|
ResponseType = TypedDict('ResponseType', {
|
||||||
'status': int,
|
'status': int,
|
||||||
'headers': dict[str, typing.Any] | None,
|
'headers': dict[str, Any] | None,
|
||||||
'content_type': str,
|
'content_type': str,
|
||||||
'body': bytes | None,
|
'body': bytes | None,
|
||||||
'text': str | None
|
'text': str | None
|
||||||
})
|
})
|
||||||
|
|
||||||
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
|
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
|
||||||
|
IS_WINDOWS = platform.system() == 'Windows'
|
||||||
|
|
||||||
MIMETYPES = {
|
MIMETYPES = {
|
||||||
'activity': 'application/activity+json',
|
'activity': 'application/activity+json',
|
||||||
|
@ -126,7 +128,7 @@ class JsonEncoder(json.JSONEncoder):
|
||||||
if isinstance(o, datetime):
|
if isinstance(o, datetime):
|
||||||
return o.isoformat()
|
return o.isoformat()
|
||||||
|
|
||||||
return json.JSONEncoder.default(self, o)
|
return json.JSONEncoder.default(self, o) # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
class Message(aputils.Message):
|
class Message(aputils.Message):
|
||||||
|
@ -146,6 +148,7 @@ class Message(aputils.Message):
|
||||||
'followers': f'https://{host}/followers',
|
'followers': f'https://{host}/followers',
|
||||||
'following': f'https://{host}/following',
|
'following': f'https://{host}/following',
|
||||||
'inbox': f'https://{host}/inbox',
|
'inbox': f'https://{host}/inbox',
|
||||||
|
'outbox': f'https://{host}/outbox',
|
||||||
'url': f'https://{host}/',
|
'url': f'https://{host}/',
|
||||||
'endpoints': {
|
'endpoints': {
|
||||||
'sharedInbox': f'https://{host}/inbox'
|
'sharedInbox': f'https://{host}/inbox'
|
||||||
|
@ -211,7 +214,7 @@ class Response(AiohttpResponse):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new(cls: type[Self],
|
def new(cls: type[Self],
|
||||||
body: str | bytes | dict | tuple | list | set = '',
|
body: str | bytes | dict[str, Any] | Sequence[Any] = '',
|
||||||
status: int = 200,
|
status: int = 200,
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
ctype: str = 'text') -> Self:
|
ctype: str = 'text') -> Self:
|
||||||
|
@ -224,22 +227,22 @@ class Response(AiohttpResponse):
|
||||||
'text': None
|
'text': None
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(body, bytes):
|
if isinstance(body, str):
|
||||||
|
kwargs['text'] = body
|
||||||
|
|
||||||
|
elif isinstance(body, bytes):
|
||||||
kwargs['body'] = body
|
kwargs['body'] = body
|
||||||
|
|
||||||
elif isinstance(body, (dict, list, tuple, set)) or ctype in {'json', 'activity'}:
|
elif isinstance(body, (dict, Sequence)):
|
||||||
kwargs['text'] = json.dumps(body, cls = JsonEncoder)
|
kwargs['text'] = json.dumps(body, cls = JsonEncoder)
|
||||||
|
|
||||||
else:
|
|
||||||
kwargs['text'] = body
|
|
||||||
|
|
||||||
return cls(**kwargs)
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new_error(cls: type[Self],
|
def new_error(cls: type[Self],
|
||||||
status: int,
|
status: int,
|
||||||
body: str | bytes | dict,
|
body: str | bytes | dict[str, Any],
|
||||||
ctype: str = 'text') -> Self:
|
ctype: str = 'text') -> Self:
|
||||||
|
|
||||||
if ctype == 'json':
|
if ctype == 'json':
|
||||||
|
|
|
@ -10,14 +10,12 @@ if typing.TYPE_CHECKING:
|
||||||
from .views.activitypub import ActorView
|
from .views.activitypub import ActorView
|
||||||
|
|
||||||
|
|
||||||
def person_check(actor: Message, software: str | None) -> bool:
|
def actor_type_check(actor: Message, software: str | None) -> bool:
|
||||||
# pleroma and akkoma may use Person for the actor type for some reason
|
if actor.type == 'Application':
|
||||||
# akkoma changed this in 3.6.0
|
return True
|
||||||
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
|
|
||||||
return False
|
|
||||||
|
|
||||||
# make sure the actor is an application
|
# akkoma (< 3.6.0) and pleroma use Person for the actor type
|
||||||
if actor.type != 'Application':
|
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
@ -36,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
|
||||||
logging.debug('>> relay: %s', message)
|
logging.debug('>> relay: %s', message)
|
||||||
|
|
||||||
for instance in conn.distill_inboxes(view.message):
|
for instance in conn.distill_inboxes(view.message):
|
||||||
view.app.push_message(instance["inbox"], message, instance)
|
view.app.push_message(instance.inbox, message, instance)
|
||||||
|
|
||||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||||
|
|
||||||
|
@ -54,7 +52,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
|
||||||
logging.debug('>> forward: %s', message)
|
logging.debug('>> forward: %s', message)
|
||||||
|
|
||||||
for instance in conn.distill_inboxes(view.message):
|
for instance in conn.distill_inboxes(view.message):
|
||||||
view.app.push_message(instance["inbox"], await view.request.read(), instance)
|
view.app.push_message(instance.inbox, view.message, instance)
|
||||||
|
|
||||||
view.cache.set('handle-relay', view.message.id, message.id, 'str')
|
view.cache.set('handle-relay', view.message.id, message.id, 'str')
|
||||||
|
|
||||||
|
@ -88,7 +86,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# reject if the actor is not an instance actor
|
# reject if the actor is not an instance actor
|
||||||
if person_check(view.actor, software):
|
if actor_type_check(view.actor, software):
|
||||||
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
|
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
|
||||||
|
|
||||||
view.app.push_message(
|
view.app.push_message(
|
||||||
|
@ -179,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# prevent past unfollows from removing an instance
|
# prevent past unfollows from removing an instance
|
||||||
if view.instance['followid'] and view.instance['followid'] != view.message.object_id:
|
if view.instance.followid and view.instance.followid != view.message.object_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
with conn.transaction():
|
with conn.transaction():
|
||||||
|
@ -223,18 +221,18 @@ async def run_processor(view: ActorView) -> None:
|
||||||
|
|
||||||
with view.database.session() as conn:
|
with view.database.session() as conn:
|
||||||
if view.instance:
|
if view.instance:
|
||||||
if not view.instance['software']:
|
if not view.instance.software:
|
||||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)):
|
||||||
with conn.transaction():
|
with conn.transaction():
|
||||||
view.instance = conn.put_inbox(
|
view.instance = conn.put_inbox(
|
||||||
domain = view.instance['domain'],
|
domain = view.instance.domain,
|
||||||
software = nodeinfo.sw_name
|
software = nodeinfo.sw_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if not view.instance['actor']:
|
if not view.instance.actor:
|
||||||
with conn.transaction():
|
with conn.transaction():
|
||||||
view.instance = conn.put_inbox(
|
view.instance = conn.put_inbox(
|
||||||
domain = view.instance['domain'],
|
domain = view.instance.domain,
|
||||||
actor = view.actor.id
|
actor = view.actor.id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,22 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import textwrap
|
import textwrap
|
||||||
import typing
|
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from hamlish_jinja import HamlishExtension
|
from hamlish_jinja import HamlishExtension
|
||||||
from jinja2 import Environment, FileSystemLoader
|
from jinja2 import Environment, FileSystemLoader
|
||||||
from jinja2.ext import Extension
|
from jinja2.ext import Extension
|
||||||
from jinja2.nodes import CallBlock
|
from jinja2.nodes import CallBlock, Node
|
||||||
|
from jinja2.parser import Parser
|
||||||
from markdown import Markdown
|
from markdown import Markdown
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .misc import get_resource
|
from .misc import get_resource
|
||||||
|
from .views.base import View
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from jinja2.nodes import Node
|
|
||||||
from jinja2.parser import Parser
|
|
||||||
from typing import Any
|
|
||||||
from .application import Application
|
from .application import Application
|
||||||
from .views.base import View
|
|
||||||
|
|
||||||
|
|
||||||
class Template(Environment):
|
class Template(Environment):
|
||||||
|
|
|
@ -1,26 +1,22 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import aputils
|
import aputils
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
|
||||||
|
from aiohttp.web import Request
|
||||||
|
|
||||||
from .base import View, register_route
|
from .base import View, register_route
|
||||||
|
|
||||||
from .. import logger as logging
|
from .. import logger as logging
|
||||||
|
from ..database import schema
|
||||||
from ..misc import Message, Response
|
from ..misc import Message, Response
|
||||||
from ..processors import run_processor
|
from ..processors import run_processor
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from aiohttp.web import Request
|
|
||||||
from bsql import Row
|
|
||||||
|
|
||||||
|
|
||||||
@register_route('/actor', '/inbox')
|
@register_route('/actor', '/inbox')
|
||||||
class ActorView(View):
|
class ActorView(View):
|
||||||
signature: aputils.Signature
|
signature: aputils.Signature
|
||||||
message: Message
|
message: Message
|
||||||
actor: Message
|
actor: Message
|
||||||
instancce: Row
|
instance: schema.Instance
|
||||||
signer: aputils.Signer
|
signer: aputils.Signer
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +43,7 @@ class ActorView(View):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
|
||||||
|
|
||||||
# reject if actor is banned
|
# reject if actor is banned
|
||||||
if conn.get_domain_ban(self.actor.domain):
|
if conn.get_domain_ban(self.actor.domain):
|
||||||
|
@ -95,9 +91,10 @@ class ActorView(View):
|
||||||
logging.verbose('actor not in message')
|
logging.verbose('actor not in message')
|
||||||
return Response.new_error(400, 'no actor in message', 'json')
|
return Response.new_error(400, 'no actor in message', 'json')
|
||||||
|
|
||||||
actor: Message | None = await self.client.get(self.signature.keyid, True, Message)
|
try:
|
||||||
|
self.actor = await self.client.get(self.signature.keyid, True, Message)
|
||||||
|
|
||||||
if actor is None:
|
except Exception:
|
||||||
# ld signatures aren't handled atm, so just ignore it
|
# ld signatures aren't handled atm, so just ignore it
|
||||||
if self.message.type == 'Delete':
|
if self.message.type == 'Delete':
|
||||||
logging.verbose('Instance sent a delete which cannot be handled')
|
logging.verbose('Instance sent a delete which cannot be handled')
|
||||||
|
@ -106,8 +103,6 @@ class ActorView(View):
|
||||||
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
|
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
|
||||||
return Response.new_error(400, 'failed to fetch actor', 'json')
|
return Response.new_error(400, 'failed to fetch actor', 'json')
|
||||||
|
|
||||||
self.actor = actor
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.signer = self.actor.signer
|
self.signer = self.actor.signer
|
||||||
|
|
||||||
|
@ -125,6 +120,39 @@ class ActorView(View):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@register_route('/outbox')
|
||||||
|
class OutboxView(View):
|
||||||
|
async def get(self, request: Request) -> Response:
|
||||||
|
msg = aputils.Message.new(
|
||||||
|
aputils.ObjectType.ORDERED_COLLECTION,
|
||||||
|
{
|
||||||
|
"id": f'https://{self.config.domain}/outbox',
|
||||||
|
"totalItems": 0,
|
||||||
|
"orderedItems": []
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response.new(msg, ctype = 'activity')
|
||||||
|
|
||||||
|
|
||||||
|
@register_route('/following', '/followers')
|
||||||
|
class RelationshipView(View):
|
||||||
|
async def get(self, request: Request) -> Response:
|
||||||
|
with self.database.session(False) as s:
|
||||||
|
inboxes = [row['actor'] for row in s.get_inboxes()]
|
||||||
|
|
||||||
|
msg = aputils.Message.new(
|
||||||
|
aputils.ObjectType.COLLECTION,
|
||||||
|
{
|
||||||
|
"id": f'https://{self.config.domain}{request.path}',
|
||||||
|
"totalItems": len(inboxes),
|
||||||
|
"items": inboxes
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response.new(msg, ctype = 'activity')
|
||||||
|
|
||||||
|
|
||||||
@register_route('/.well-known/webfinger')
|
@register_route('/.well-known/webfinger')
|
||||||
class WebfingerView(View):
|
class WebfingerView(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from __future__ import annotations
|
import traceback
|
||||||
|
|
||||||
import typing
|
from aiohttp.web import Request, middleware
|
||||||
|
|
||||||
from aiohttp import web
|
|
||||||
from argon2.exceptions import VerifyMismatchError
|
from argon2.exceptions import VerifyMismatchError
|
||||||
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from .base import View, register_route
|
from .base import View, register_route
|
||||||
|
@ -12,11 +12,6 @@ from .. import __version__
|
||||||
from ..database import ConfigData
|
from ..database import ConfigData
|
||||||
from ..misc import Message, Response, boolean, get_app
|
from ..misc import Message, Response, boolean, get_app
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from aiohttp.web import Request
|
|
||||||
from collections.abc import Callable, Sequence
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
ALLOWED_HEADERS = {
|
ALLOWED_HEADERS = {
|
||||||
'accept',
|
'accept',
|
||||||
|
@ -26,7 +21,6 @@ ALLOWED_HEADERS = {
|
||||||
|
|
||||||
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
|
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
|
||||||
('GET', '/api/v1/relay'),
|
('GET', '/api/v1/relay'),
|
||||||
('GET', '/api/v1/instance'),
|
|
||||||
('POST', '/api/v1/token')
|
('POST', '/api/v1/token')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,8 +32,10 @@ def check_api_path(method: str, path: str) -> bool:
|
||||||
return path.startswith('/api')
|
return path.startswith('/api')
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@middleware
|
||||||
async def handle_api_path(request: Request, handler: Callable) -> Response:
|
async def handle_api_path(
|
||||||
|
request: Request,
|
||||||
|
handler: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||||
try:
|
try:
|
||||||
if (token := request.cookies.get('user-token')):
|
if (token := request.cookies.get('user-token')):
|
||||||
request['token'] = token
|
request['token'] = token
|
||||||
|
@ -94,10 +90,10 @@ class Login(View):
|
||||||
|
|
||||||
token = conn.put_token(data['username'])
|
token = conn.put_token(data['username'])
|
||||||
|
|
||||||
resp = Response.new({'token': token['code']}, ctype = 'json')
|
resp = Response.new({'token': token.code}, ctype = 'json')
|
||||||
resp.set_cookie(
|
resp.set_cookie(
|
||||||
'user-token',
|
'user-token',
|
||||||
token['code'],
|
token.code,
|
||||||
max_age = 60 * 60 * 24 * 365,
|
max_age = 60 * 60 * 24 * 365,
|
||||||
domain = self.config.domain,
|
domain = self.config.domain,
|
||||||
path = '/',
|
path = '/',
|
||||||
|
@ -121,7 +117,7 @@ class RelayInfo(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
config = conn.get_config_all()
|
config = conn.get_config_all()
|
||||||
inboxes = [row['domain'] for row in conn.get_inboxes()]
|
inboxes = [row.domain for row in conn.get_inboxes()]
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'domain': self.config.domain,
|
'domain': self.config.domain,
|
||||||
|
@ -192,7 +188,7 @@ class Config(View):
|
||||||
class Inbox(View):
|
class Inbox(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
data = conn.get_inboxes()
|
data = tuple(conn.get_inboxes())
|
||||||
|
|
||||||
return Response.new(data, ctype = 'json')
|
return Response.new(data, ctype = 'json')
|
||||||
|
|
||||||
|
@ -206,24 +202,36 @@ class Inbox(View):
|
||||||
data['domain'] = urlparse(data["actor"]).netloc
|
data['domain'] = urlparse(data["actor"]).netloc
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_inbox(data['domain']):
|
if conn.get_inbox(data['domain']) is not None:
|
||||||
return Response.new_error(404, 'Instance already in database', 'json')
|
return Response.new_error(404, 'Instance already in database', 'json')
|
||||||
|
|
||||||
if not data.get('inbox'):
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
actor_data: Message | None = await self.client.get(data['actor'], True, Message)
|
|
||||||
|
|
||||||
if actor_data is None:
|
if not data.get('inbox'):
|
||||||
|
try:
|
||||||
|
actor_data = await self.client.get(data['actor'], True, Message)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
return Response.new_error(500, 'Failed to fetch actor', 'json')
|
return Response.new_error(500, 'Failed to fetch actor', 'json')
|
||||||
|
|
||||||
data['inbox'] = actor_data.shared_inbox
|
data['inbox'] = actor_data.shared_inbox
|
||||||
|
|
||||||
if not data.get('software'):
|
if not data.get('software'):
|
||||||
nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
|
try:
|
||||||
|
nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
|
||||||
if nodeinfo is not None:
|
|
||||||
data['software'] = nodeinfo.sw_name
|
data['software'] = nodeinfo.sw_name
|
||||||
|
|
||||||
row = conn.put_inbox(**data)
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
row = conn.put_inbox(
|
||||||
|
domain = data['domain'],
|
||||||
|
actor = data['actor'],
|
||||||
|
inbox = data.get('inbox'),
|
||||||
|
software = data.get('software'),
|
||||||
|
followid = data.get('followid')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(row, ctype = 'json')
|
return Response.new(row, ctype = 'json')
|
||||||
|
|
||||||
|
@ -235,10 +243,17 @@ class Inbox(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
if not (instance := conn.get_inbox(data['domain'])):
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
|
if (instance := conn.get_inbox(data['domain'])) is None:
|
||||||
return Response.new_error(404, 'Instance with domain not found', 'json')
|
return Response.new_error(404, 'Instance with domain not found', 'json')
|
||||||
|
|
||||||
instance = conn.put_inbox(instance['domain'], **data)
|
instance = conn.put_inbox(
|
||||||
|
instance.domain,
|
||||||
|
actor = data.get('actor'),
|
||||||
|
software = data.get('software'),
|
||||||
|
followid = data.get('followid')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(instance, ctype = 'json')
|
return Response.new(instance, ctype = 'json')
|
||||||
|
|
||||||
|
@ -250,6 +265,8 @@ class Inbox(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
if not conn.get_inbox(data['domain']):
|
if not conn.get_inbox(data['domain']):
|
||||||
return Response.new_error(404, 'Instance with domain not found', 'json')
|
return Response.new_error(404, 'Instance with domain not found', 'json')
|
||||||
|
|
||||||
|
@ -262,14 +279,19 @@ class Inbox(View):
|
||||||
class RequestView(View):
|
class RequestView(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
instances = conn.get_requests()
|
instances = tuple(conn.get_requests())
|
||||||
|
|
||||||
return Response.new(instances, ctype = 'json')
|
return Response.new(instances, ctype = 'json')
|
||||||
|
|
||||||
|
|
||||||
async def post(self, request: Request) -> Response:
|
async def post(self, request: Request) -> Response:
|
||||||
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
|
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
|
||||||
|
|
||||||
|
if isinstance(data, Response):
|
||||||
|
return data
|
||||||
|
|
||||||
data['accept'] = boolean(data['accept'])
|
data['accept'] = boolean(data['accept'])
|
||||||
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with self.database.session(True) as conn:
|
with self.database.session(True) as conn:
|
||||||
|
@ -280,20 +302,20 @@ class RequestView(View):
|
||||||
|
|
||||||
message = Message.new_response(
|
message = Message.new_response(
|
||||||
host = self.config.domain,
|
host = self.config.domain,
|
||||||
actor = instance['actor'],
|
actor = instance.actor,
|
||||||
followid = instance['followid'],
|
followid = instance.followid,
|
||||||
accept = data['accept']
|
accept = data['accept']
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app.push_message(instance['inbox'], message, instance)
|
self.app.push_message(instance.inbox, message, instance)
|
||||||
|
|
||||||
if data['accept'] and instance['software'] != 'mastodon':
|
if data['accept'] and instance.software != 'mastodon':
|
||||||
message = Message.new_follow(
|
message = Message.new_follow(
|
||||||
host = self.config.domain,
|
host = self.config.domain,
|
||||||
actor = instance['actor']
|
actor = instance.actor
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app.push_message(instance['inbox'], message, instance)
|
self.app.push_message(instance.inbox, message, instance)
|
||||||
|
|
||||||
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
|
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
|
||||||
return Response.new(resp_message, ctype = 'json')
|
return Response.new(resp_message, ctype = 'json')
|
||||||
|
@ -303,7 +325,7 @@ class RequestView(View):
|
||||||
class DomainBan(View):
|
class DomainBan(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
bans = tuple(conn.execute('SELECT * FROM domain_bans').all())
|
bans = tuple(conn.get_domain_bans())
|
||||||
|
|
||||||
return Response.new(bans, ctype = 'json')
|
return Response.new(bans, ctype = 'json')
|
||||||
|
|
||||||
|
@ -314,11 +336,17 @@ class DomainBan(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_domain_ban(data['domain']):
|
if conn.get_domain_ban(data['domain']) is not None:
|
||||||
return Response.new_error(400, 'Domain already banned', 'json')
|
return Response.new_error(400, 'Domain already banned', 'json')
|
||||||
|
|
||||||
ban = conn.put_domain_ban(**data)
|
ban = conn.put_domain_ban(
|
||||||
|
domain = data['domain'],
|
||||||
|
reason = data.get('reason'),
|
||||||
|
note = data.get('note')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(ban, ctype = 'json')
|
return Response.new(ban, ctype = 'json')
|
||||||
|
|
||||||
|
@ -330,13 +358,19 @@ class DomainBan(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
if not conn.get_domain_ban(data['domain']):
|
|
||||||
return Response.new_error(404, 'Domain not banned', 'json')
|
|
||||||
|
|
||||||
if not any([data.get('note'), data.get('reason')]):
|
if not any([data.get('note'), data.get('reason')]):
|
||||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||||
|
|
||||||
ban = conn.update_domain_ban(**data)
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
|
if conn.get_domain_ban(data['domain']) is None:
|
||||||
|
return Response.new_error(404, 'Domain not banned', 'json')
|
||||||
|
|
||||||
|
ban = conn.update_domain_ban(
|
||||||
|
domain = data['domain'],
|
||||||
|
reason = data.get('reason'),
|
||||||
|
note = data.get('note')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(ban, ctype = 'json')
|
return Response.new(ban, ctype = 'json')
|
||||||
|
|
||||||
|
@ -348,7 +382,9 @@ class DomainBan(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
if not conn.get_domain_ban(data['domain']):
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
|
if conn.get_domain_ban(data['domain']) is None:
|
||||||
return Response.new_error(404, 'Domain not banned', 'json')
|
return Response.new_error(404, 'Domain not banned', 'json')
|
||||||
|
|
||||||
conn.del_domain_ban(data['domain'])
|
conn.del_domain_ban(data['domain'])
|
||||||
|
@ -360,7 +396,7 @@ class DomainBan(View):
|
||||||
class SoftwareBan(View):
|
class SoftwareBan(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
bans = tuple(conn.execute('SELECT * FROM software_bans').all())
|
bans = tuple(conn.get_software_bans())
|
||||||
|
|
||||||
return Response.new(bans, ctype = 'json')
|
return Response.new(bans, ctype = 'json')
|
||||||
|
|
||||||
|
@ -372,10 +408,14 @@ class SoftwareBan(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_software_ban(data['name']):
|
if conn.get_software_ban(data['name']) is not None:
|
||||||
return Response.new_error(400, 'Domain already banned', 'json')
|
return Response.new_error(400, 'Domain already banned', 'json')
|
||||||
|
|
||||||
ban = conn.put_software_ban(**data)
|
ban = conn.put_software_ban(
|
||||||
|
name = data['name'],
|
||||||
|
reason = data.get('reason'),
|
||||||
|
note = data.get('note')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(ban, ctype = 'json')
|
return Response.new(ban, ctype = 'json')
|
||||||
|
|
||||||
|
@ -386,14 +426,18 @@ class SoftwareBan(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
if not any([data.get('note'), data.get('reason')]):
|
||||||
|
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if not conn.get_software_ban(data['name']):
|
if conn.get_software_ban(data['name']) is None:
|
||||||
return Response.new_error(404, 'Software not banned', 'json')
|
return Response.new_error(404, 'Software not banned', 'json')
|
||||||
|
|
||||||
if not any([data.get('note'), data.get('reason')]):
|
ban = conn.update_software_ban(
|
||||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
name = data['name'],
|
||||||
|
reason = data.get('reason'),
|
||||||
ban = conn.update_software_ban(**data)
|
note = data.get('note')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(ban, ctype = 'json')
|
return Response.new(ban, ctype = 'json')
|
||||||
|
|
||||||
|
@ -405,7 +449,7 @@ class SoftwareBan(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if not conn.get_software_ban(data['name']):
|
if conn.get_software_ban(data['name']) is None:
|
||||||
return Response.new_error(404, 'Software not banned', 'json')
|
return Response.new_error(404, 'Software not banned', 'json')
|
||||||
|
|
||||||
conn.del_software_ban(data['name'])
|
conn.del_software_ban(data['name'])
|
||||||
|
@ -419,7 +463,7 @@ class User(View):
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
items = []
|
items = []
|
||||||
|
|
||||||
for row in conn.execute('SELECT * FROM users'):
|
for row in conn.get_users():
|
||||||
del row['hash']
|
del row['hash']
|
||||||
items.append(row)
|
items.append(row)
|
||||||
|
|
||||||
|
@ -433,12 +477,16 @@ class User(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_user(data['username']):
|
if conn.get_user(data['username']) is not None:
|
||||||
return Response.new_error(404, 'User already exists', 'json')
|
return Response.new_error(404, 'User already exists', 'json')
|
||||||
|
|
||||||
user = conn.put_user(**data)
|
user = conn.put_user(
|
||||||
del user['hash']
|
username = data['username'],
|
||||||
|
password = data['password'],
|
||||||
|
handle = data.get('handle')
|
||||||
|
)
|
||||||
|
|
||||||
|
del user['hash']
|
||||||
return Response.new(user, ctype = 'json')
|
return Response.new(user, ctype = 'json')
|
||||||
|
|
||||||
|
|
||||||
|
@ -449,9 +497,13 @@ class User(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session(True) as conn:
|
with self.database.session(True) as conn:
|
||||||
user = conn.put_user(**data)
|
user = conn.put_user(
|
||||||
del user['hash']
|
username = data['username'],
|
||||||
|
password = data['password'],
|
||||||
|
handle = data.get('handle')
|
||||||
|
)
|
||||||
|
|
||||||
|
del user['hash']
|
||||||
return Response.new(user, ctype = 'json')
|
return Response.new(user, ctype = 'json')
|
||||||
|
|
||||||
|
|
||||||
|
@ -462,7 +514,7 @@ class User(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session(True) as conn:
|
with self.database.session(True) as conn:
|
||||||
if not conn.get_user(data['username']):
|
if conn.get_user(data['username']) is None:
|
||||||
return Response.new_error(404, 'User does not exist', 'json')
|
return Response.new_error(404, 'User does not exist', 'json')
|
||||||
|
|
||||||
conn.del_user(data['username'])
|
conn.del_user(data['username'])
|
||||||
|
@ -474,7 +526,7 @@ class User(View):
|
||||||
class Whitelist(View):
|
class Whitelist(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
items = tuple(conn.execute('SELECT * FROM whitelist').all())
|
items = tuple(conn.get_domains_whitelist())
|
||||||
|
|
||||||
return Response.new(items, ctype = 'json')
|
return Response.new(items, ctype = 'json')
|
||||||
|
|
||||||
|
@ -485,11 +537,13 @@ class Whitelist(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
domain = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_domain_whitelist(data['domain']):
|
if conn.get_domain_whitelist(domain) is not None:
|
||||||
return Response.new_error(400, 'Domain already added to whitelist', 'json')
|
return Response.new_error(400, 'Domain already added to whitelist', 'json')
|
||||||
|
|
||||||
item = conn.put_domain_whitelist(**data)
|
item = conn.put_domain_whitelist(domain)
|
||||||
|
|
||||||
return Response.new(item, ctype = 'json')
|
return Response.new(item, ctype = 'json')
|
||||||
|
|
||||||
|
@ -500,10 +554,12 @@ class Whitelist(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
domain = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if not conn.get_domain_whitelist(data['domain']):
|
if conn.get_domain_whitelist(domain) is None:
|
||||||
return Response.new_error(404, 'Domain not in whitelist', 'json')
|
return Response.new_error(404, 'Domain not in whitelist', 'json')
|
||||||
|
|
||||||
conn.del_domain_whitelist(data['domain'])
|
conn.del_domain_whitelist(domain)
|
||||||
|
|
||||||
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')
|
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')
|
||||||
|
|
|
@ -1,26 +1,24 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from Crypto.Random import get_random_bytes
|
from Crypto.Random import get_random_bytes
|
||||||
from aiohttp.abc import AbstractView
|
from aiohttp.abc import AbstractView
|
||||||
from aiohttp.hdrs import METH_ALL as METHODS
|
from aiohttp.hdrs import METH_ALL as METHODS
|
||||||
from aiohttp.web import HTTPMethodNotAllowed
|
from aiohttp.web import HTTPMethodNotAllowed, Request
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
from bsql import Database
|
||||||
|
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..cache import Cache
|
||||||
|
from ..config import Config
|
||||||
|
from ..database import Connection
|
||||||
|
from ..http_client import HttpClient
|
||||||
from ..misc import Response, get_app
|
from ..misc import Response, get_app
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aiohttp.web import Request
|
|
||||||
from collections.abc import Callable, Generator, Sequence, Mapping
|
|
||||||
from bsql import Database
|
|
||||||
from typing import Any
|
|
||||||
from ..application import Application
|
from ..application import Application
|
||||||
from ..cache import Cache
|
|
||||||
from ..config import Config
|
|
||||||
from ..http_client import HttpClient
|
|
||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -29,6 +27,8 @@ if typing.TYPE_CHECKING:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
HandlerCallback = Callable[[Request], Awaitable[Response]]
|
||||||
|
|
||||||
|
|
||||||
VIEWS: list[tuple[str, type[View]]] = []
|
VIEWS: list[tuple[str, type[View]]] = []
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
|
||||||
return {key: str(value) for key, value in data.items()}
|
return {key: str(value) for key, value in data.items()}
|
||||||
|
|
||||||
|
|
||||||
def register_route(*paths: str) -> Callable:
|
def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
|
||||||
def wrapper(view: type[View]) -> type[View]:
|
def wrapper(view: type[View]) -> type[View]:
|
||||||
for path in paths:
|
for path in paths:
|
||||||
VIEWS.append((path, view))
|
VIEWS.append((path, view))
|
||||||
|
@ -63,7 +63,7 @@ class View(AbstractView):
|
||||||
return await view.handlers[method](request, **kwargs)
|
return await view.handlers[method](request, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response:
|
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
|
||||||
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
|
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
|
||||||
return await handler(self.request, **self.request.match_info, **kwargs)
|
return await handler(self.request, **self.request.match_info, **kwargs)
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ class View(AbstractView):
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def handlers(self) -> dict[str, Callable[..., Any]]:
|
def handlers(self) -> dict[str, HandlerCallback]:
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
for method in METHODS:
|
for method in METHODS:
|
||||||
|
@ -112,13 +112,13 @@ class View(AbstractView):
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database(self) -> Database:
|
def database(self) -> Database[Connection]:
|
||||||
return self.app.database
|
return self.app.database
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def template(self) -> Template:
|
def template(self) -> Template:
|
||||||
return self.app['template']
|
return self.app['template'] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
async def get_api_data(self,
|
async def get_api_data(self,
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from .base import View, register_route
|
from .base import View, register_route
|
||||||
|
|
||||||
|
@ -10,11 +8,6 @@ from ..database import THEMES
|
||||||
from ..logger import LogLevel
|
from ..logger import LogLevel
|
||||||
from ..misc import Response, get_app
|
from ..misc import Response, get_app
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from aiohttp.web import Request
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
UNAUTH_ROUTES = {
|
UNAUTH_ROUTES = {
|
||||||
'/',
|
'/',
|
||||||
|
@ -23,7 +16,10 @@ UNAUTH_ROUTES = {
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def handle_frontend_path(request: web.Request, handler: Callable) -> Response:
|
async def handle_frontend_path(
|
||||||
|
request: web.Request,
|
||||||
|
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
|
||||||
|
|
||||||
app = get_app()
|
app = get_app()
|
||||||
|
|
||||||
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
|
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
|
||||||
|
@ -52,7 +48,7 @@ async def handle_frontend_path(request: web.Request, handler: Callable) -> Respo
|
||||||
|
|
||||||
@register_route('/')
|
@register_route('/')
|
||||||
class HomeView(View):
|
class HomeView(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: web.Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
context: dict[str, Any] = {
|
context: dict[str, Any] = {
|
||||||
'instances': tuple(conn.get_inboxes())
|
'instances': tuple(conn.get_inboxes())
|
||||||
|
@ -64,14 +60,14 @@ class HomeView(View):
|
||||||
|
|
||||||
@register_route('/login')
|
@register_route('/login')
|
||||||
class Login(View):
|
class Login(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: web.Request) -> Response:
|
||||||
data = self.template.render('page/login.haml', self)
|
data = self.template.render('page/login.haml', self)
|
||||||
return Response.new(data, ctype = 'html')
|
return Response.new(data, ctype = 'html')
|
||||||
|
|
||||||
|
|
||||||
@register_route('/logout')
|
@register_route('/logout')
|
||||||
class Logout(View):
|
class Logout(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: web.Request) -> Response:
|
||||||
with self.database.session(True) as conn:
|
with self.database.session(True) as conn:
|
||||||
conn.del_token(request['token'])
|
conn.del_token(request['token'])
|
||||||
|
|
||||||
|
@ -82,14 +78,14 @@ class Logout(View):
|
||||||
|
|
||||||
@register_route('/admin')
|
@register_route('/admin')
|
||||||
class Admin(View):
|
class Admin(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: web.Request) -> Response:
|
||||||
return Response.new('', 302, {'Location': '/admin/instances'})
|
return Response.new('', 302, {'Location': '/admin/instances'})
|
||||||
|
|
||||||
|
|
||||||
@register_route('/admin/instances')
|
@register_route('/admin/instances')
|
||||||
class AdminInstances(View):
|
class AdminInstances(View):
|
||||||
async def get(self,
|
async def get(self,
|
||||||
request: Request,
|
request: web.Request,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
message: str | None = None) -> Response:
|
message: str | None = None) -> Response:
|
||||||
|
|
||||||
|
@ -112,7 +108,7 @@ class AdminInstances(View):
|
||||||
@register_route('/admin/whitelist')
|
@register_route('/admin/whitelist')
|
||||||
class AdminWhitelist(View):
|
class AdminWhitelist(View):
|
||||||
async def get(self,
|
async def get(self,
|
||||||
request: Request,
|
request: web.Request,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
message: str | None = None) -> Response:
|
message: str | None = None) -> Response:
|
||||||
|
|
||||||
|
@ -134,7 +130,7 @@ class AdminWhitelist(View):
|
||||||
@register_route('/admin/domain_bans')
|
@register_route('/admin/domain_bans')
|
||||||
class AdminDomainBans(View):
|
class AdminDomainBans(View):
|
||||||
async def get(self,
|
async def get(self,
|
||||||
request: Request,
|
request: web.Request,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
message: str | None = None) -> Response:
|
message: str | None = None) -> Response:
|
||||||
|
|
||||||
|
@ -156,7 +152,7 @@ class AdminDomainBans(View):
|
||||||
@register_route('/admin/software_bans')
|
@register_route('/admin/software_bans')
|
||||||
class AdminSoftwareBans(View):
|
class AdminSoftwareBans(View):
|
||||||
async def get(self,
|
async def get(self,
|
||||||
request: Request,
|
request: web.Request,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
message: str | None = None) -> Response:
|
message: str | None = None) -> Response:
|
||||||
|
|
||||||
|
@ -178,7 +174,7 @@ class AdminSoftwareBans(View):
|
||||||
@register_route('/admin/users')
|
@register_route('/admin/users')
|
||||||
class AdminUsers(View):
|
class AdminUsers(View):
|
||||||
async def get(self,
|
async def get(self,
|
||||||
request: Request,
|
request: web.Request,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
message: str | None = None) -> Response:
|
message: str | None = None) -> Response:
|
||||||
|
|
||||||
|
@ -199,11 +195,22 @@ class AdminUsers(View):
|
||||||
|
|
||||||
@register_route('/admin/config')
|
@register_route('/admin/config')
|
||||||
class AdminConfig(View):
|
class AdminConfig(View):
|
||||||
async def get(self, request: Request, message: str | None = None) -> Response:
|
async def get(self, request: web.Request, message: str | None = None) -> Response:
|
||||||
context: dict[str, Any] = {
|
context: dict[str, Any] = {
|
||||||
'themes': tuple(THEMES.keys()),
|
'themes': tuple(THEMES.keys()),
|
||||||
'levels': tuple(level.name for level in LogLevel),
|
'levels': tuple(level.name for level in LogLevel),
|
||||||
'message': message
|
'message': message,
|
||||||
|
'desc': {
|
||||||
|
"name": "Name of the relay to be displayed in the header of the pages and in " +
|
||||||
|
"the actor endpoint.", # noqa: E131
|
||||||
|
"note": "Description of the relay to be displayed on the front page and as the " +
|
||||||
|
"bio in the actor endpoint.",
|
||||||
|
"theme": "Color theme to use on the web pages.",
|
||||||
|
"log_level": "Minimum level of logging messages to print to the console.",
|
||||||
|
"whitelist_enabled": "Only allow instances in the whitelist to be able to follow.",
|
||||||
|
"approval_required": "Require instances not on the whitelist to be approved by " +
|
||||||
|
"and admin. The `whitelist-enabled` setting is ignored when this is enabled."
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data = self.template.render('page/admin-config.haml', self, **context)
|
data = self.template.render('page/admin-config.haml', self, **context)
|
||||||
|
@ -212,7 +219,7 @@ class AdminConfig(View):
|
||||||
|
|
||||||
@register_route('/manifest.json')
|
@register_route('/manifest.json')
|
||||||
class ManifestJson(View):
|
class ManifestJson(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: web.Request) -> Response:
|
||||||
with self.database.session(False) as conn:
|
with self.database.session(False) as conn:
|
||||||
config = conn.get_config_all()
|
config = conn.get_config_all()
|
||||||
theme = THEMES[config.theme]
|
theme = THEMES[config.theme]
|
||||||
|
@ -235,7 +242,7 @@ class ManifestJson(View):
|
||||||
|
|
||||||
@register_route('/theme/{theme}.css')
|
@register_route('/theme/{theme}.css')
|
||||||
class ThemeCss(View):
|
class ThemeCss(View):
|
||||||
async def get(self, request: Request, theme: str) -> Response:
|
async def get(self, request: web.Request, theme: str) -> Response:
|
||||||
try:
|
try:
|
||||||
context: dict[str, Any] = {
|
context: dict[str, Any] = {
|
||||||
'theme': THEMES[theme]
|
'theme': THEMES[theme]
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import aputils
|
import aputils
|
||||||
import subprocess
|
import subprocess
|
||||||
import typing
|
|
||||||
|
|
||||||
|
from aiohttp.web import Request
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .base import View, register_route
|
from .base import View, register_route
|
||||||
|
@ -11,9 +9,6 @@ from .base import View, register_route
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from ..misc import Response
|
from ..misc import Response
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from aiohttp.web import Request
|
|
||||||
|
|
||||||
|
|
||||||
VERSION = __version__
|
VERSION = __version__
|
||||||
|
|
||||||
|
|
150
relay/workers.py
Normal file
150
relay/workers.py
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||||
|
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from multiprocessing import Event, Process, Queue, Value
|
||||||
|
from multiprocessing.synchronize import Event as EventType
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty, Queue as QueueType
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from . import application, logger as logging
|
||||||
|
from .database.schema import Instance
|
||||||
|
from .http_client import HttpClient
|
||||||
|
from .misc import IS_WINDOWS, Message, get_app
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from .multiprocessing.synchronize import Syncronized
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QueueItem:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PostItem(QueueItem):
|
||||||
|
inbox: str
|
||||||
|
message: Message
|
||||||
|
instance: Instance | None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain(self) -> str:
|
||||||
|
return urlparse(self.inbox).netloc
|
||||||
|
|
||||||
|
|
||||||
|
class PushWorker(Process):
|
||||||
|
client: HttpClient
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None:
|
||||||
|
Process.__init__(self)
|
||||||
|
|
||||||
|
self.queue: QueueType[QueueItem] = queue
|
||||||
|
self.shutdown: EventType = Event()
|
||||||
|
self.path: Path = get_app().config.path
|
||||||
|
self.log_level: "Syncronized[str]" = log_level
|
||||||
|
self._log_level_changed: EventType = Event()
|
||||||
|
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self.shutdown.set()
|
||||||
|
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
asyncio.run(self.handle_queue())
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_queue(self) -> None:
|
||||||
|
if IS_WINDOWS:
|
||||||
|
app = application.Application(self.path)
|
||||||
|
self.client = app.client
|
||||||
|
|
||||||
|
self.client.open()
|
||||||
|
app.database.connect()
|
||||||
|
app.cache.setup()
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.client = HttpClient()
|
||||||
|
self.client.open()
|
||||||
|
|
||||||
|
logging.verbose("[%i] Starting worker", self.pid)
|
||||||
|
|
||||||
|
while not self.shutdown.is_set():
|
||||||
|
try:
|
||||||
|
if self._log_level_changed.is_set():
|
||||||
|
logging.set_level(logging.LogLevel.parse(self.log_level.value))
|
||||||
|
self._log_level_changed.clear()
|
||||||
|
|
||||||
|
item = self.queue.get(block=True, timeout=0.1)
|
||||||
|
|
||||||
|
if isinstance(item, PostItem):
|
||||||
|
asyncio.create_task(self.handle_post(item))
|
||||||
|
|
||||||
|
except Empty:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if IS_WINDOWS:
|
||||||
|
app.database.disconnect()
|
||||||
|
app.cache.close()
|
||||||
|
|
||||||
|
await self.client.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_post(self, item: PostItem) -> None:
|
||||||
|
try:
|
||||||
|
await self.client.post(item.inbox, item.message, item.instance)
|
||||||
|
|
||||||
|
except AsyncTimeoutError:
|
||||||
|
logging.error('Timeout when pushing to %s', item.domain)
|
||||||
|
|
||||||
|
except ClientConnectionError as e:
|
||||||
|
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
|
||||||
|
|
||||||
|
except ClientSSLError as e:
|
||||||
|
logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
|
||||||
|
|
||||||
|
|
||||||
|
class PushWorkers(list[PushWorker]):
|
||||||
|
def __init__(self, count: int) -> None:
|
||||||
|
self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment]
|
||||||
|
self._log_level: "Syncronized[str]" = Value("i", logging.get_level())
|
||||||
|
self._count: int = count
|
||||||
|
|
||||||
|
|
||||||
|
def push_item(self, item: QueueItem) -> None:
|
||||||
|
self.queue.put(item)
|
||||||
|
|
||||||
|
|
||||||
|
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
|
||||||
|
self.queue.put(PostItem(inbox, message, instance))
|
||||||
|
|
||||||
|
|
||||||
|
def set_log_level(self, value: logging.LogLevel) -> None:
|
||||||
|
self._log_level.value = value
|
||||||
|
|
||||||
|
for worker in self:
|
||||||
|
worker._log_level_changed.set()
|
||||||
|
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
if len(self) > 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for _ in range(self._count):
|
||||||
|
worker = PushWorker(self.queue, self._log_level)
|
||||||
|
worker.start()
|
||||||
|
self.append(worker)
|
||||||
|
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
for worker in self:
|
||||||
|
worker.stop()
|
||||||
|
|
||||||
|
self.clear()
|
Loading…
Reference in a new issue