mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-12-26 04:41:07 +00:00
Merge branch 'dev' into 'main'
version 0.3.3 See merge request pleroma/relay!59
This commit is contained in:
commit
e0ca93ab93
49
dev.py
49
dev.py
|
@ -1,25 +1,38 @@
|
|||
#!/usr/bin/env python3
|
||||
import click
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import tomllib
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from relay import __version__, logger as logging
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Sequence
|
||||
|
||||
try:
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
|
||||
import tomllib
|
||||
|
||||
except ImportError:
|
||||
class PatternMatchingEventHandler: # type: ignore
|
||||
pass
|
||||
if find_spec("toml") is None:
|
||||
subprocess.run([sys.executable, "-m", "pip", "install", "toml"])
|
||||
|
||||
import toml as tomllib # type: ignore[no-redef]
|
||||
|
||||
if None in [find_spec("click"), find_spec("watchdog")]:
|
||||
CMD = [sys.executable, "-m", "pip", "install", "click >= 8.1.0", "watchdog >= 4.0.0"]
|
||||
PROC = subprocess.run(CMD, check = False)
|
||||
|
||||
if PROC.returncode != 0:
|
||||
sys.exit()
|
||||
|
||||
print("Successfully installed dependencies")
|
||||
|
||||
import click
|
||||
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
|
||||
|
||||
|
||||
REPO = Path(__file__).parent
|
||||
|
@ -37,13 +50,11 @@ def cli() -> None:
|
|||
@cli.command('install')
|
||||
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
|
||||
def cli_install(no_dev: bool) -> None:
|
||||
with open('pyproject.toml', 'rb') as fd:
|
||||
data = tomllib.load(fd)
|
||||
with open('pyproject.toml', 'r', encoding = 'utf-8') as fd:
|
||||
data = tomllib.loads(fd.read())
|
||||
|
||||
deps = data['project']['dependencies']
|
||||
|
||||
if not no_dev:
|
||||
deps.extend(data['project']['optional-dependencies']['dev'])
|
||||
deps.extend(data['project']['optional-dependencies']['dev'])
|
||||
|
||||
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
|
||||
|
||||
|
@ -60,7 +71,7 @@ def cli_lint(path: Path, watch: bool) -> None:
|
|||
return
|
||||
|
||||
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
|
||||
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)]
|
||||
mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)]
|
||||
|
||||
click.echo('----- flake8 -----')
|
||||
subprocess.run(flake8)
|
||||
|
@ -89,6 +100,8 @@ def cli_clean() -> None:
|
|||
|
||||
@cli.command('build')
|
||||
def cli_build() -> None:
|
||||
from relay import __version__
|
||||
|
||||
with TemporaryDirectory() as tmp:
|
||||
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
|
||||
cmd = [
|
||||
|
@ -171,7 +184,7 @@ class WatchHandler(PatternMatchingEventHandler):
|
|||
if proc.poll() is not None:
|
||||
continue
|
||||
|
||||
logging.info(f'Terminating process {proc.pid}')
|
||||
print(f'Terminating process {proc.pid}')
|
||||
proc.terminate()
|
||||
sec = 0.0
|
||||
|
||||
|
@ -180,11 +193,11 @@ class WatchHandler(PatternMatchingEventHandler):
|
|||
sec += 0.1
|
||||
|
||||
if sec >= 5:
|
||||
logging.error('Failed to terminate. Killing process...')
|
||||
print('Failed to terminate. Killing process...')
|
||||
proc.kill()
|
||||
break
|
||||
|
||||
logging.info('Process terminated')
|
||||
print('Process terminated')
|
||||
|
||||
|
||||
def run_procs(self, restart: bool = False) -> None:
|
||||
|
@ -200,13 +213,13 @@ class WatchHandler(PatternMatchingEventHandler):
|
|||
self.procs = []
|
||||
|
||||
for cmd in self.commands:
|
||||
logging.info('Running command: %s', ' '.join(cmd))
|
||||
print('Running command:', ' '.join(cmd))
|
||||
subprocess.run(cmd)
|
||||
|
||||
else:
|
||||
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
|
||||
pids = (str(proc.pid) for proc in self.procs)
|
||||
logging.info('Started processes with PIDs: %s', ', '.join(pids))
|
||||
print('Started processes with PIDs:', ', '.join(pids))
|
||||
|
||||
|
||||
def on_any_event(self, event: FileSystemEvent) -> None:
|
||||
|
|
|
@ -16,7 +16,8 @@ Run the relay.
|
|||
|
||||
## Setup
|
||||
|
||||
Run the setup wizard to configure your relay.
|
||||
Run the setup wizard to configure your relay. For the PostgreSQL backend, the database has to be
|
||||
created first.
|
||||
|
||||
activityrelay setup
|
||||
|
||||
|
@ -29,6 +30,16 @@ not specified, the config will get backed up as `relay.backup.yaml` before conve
|
|||
activityrelay convert --old-config relaycfg.yaml
|
||||
|
||||
|
||||
## Switch Backend
|
||||
|
||||
Change the database backend from the current one to the other. The config will be updated after
|
||||
running the command.
|
||||
|
||||
Note: If switching to PostgreSQL, make sure the database exists first.
|
||||
|
||||
activityrelay switch-backend
|
||||
|
||||
|
||||
## Edit Config
|
||||
|
||||
Open the config file in a text editor. If an editor is not specified with `--editor`, the default
|
||||
|
|
|
@ -9,30 +9,27 @@ license = {text = "AGPLv3"}
|
|||
classifiers = [
|
||||
"Environment :: Console",
|
||||
"License :: OSI Approved :: GNU Affero General Public License v3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.12"
|
||||
]
|
||||
dependencies = [
|
||||
"activitypub-utils == 0.3.1",
|
||||
"activitypub-utils >= 0.3.2, < 0.4",
|
||||
"aiohttp >= 3.9.5",
|
||||
"aiohttp-swagger[performance] == 1.0.16",
|
||||
"argon2-cffi == 23.1.0",
|
||||
"barkshark-lib >= 0.1.3-1",
|
||||
"barkshark-sql == 0.1.4-1",
|
||||
"click >= 8.1.2",
|
||||
"barkshark-lib >= 0.2.3, < 0.3.0",
|
||||
"barkshark-sql >= 0.2.0, < 0.3.0",
|
||||
"click == 8.1.2",
|
||||
"hiredis == 2.3.2",
|
||||
"idna == 3.4",
|
||||
"jinja2-haml == 0.3.5",
|
||||
"markdown == 3.6",
|
||||
"platformdirs == 4.2.2",
|
||||
"pyyaml >= 6.0",
|
||||
"redis == 5.0.5",
|
||||
"importlib-resources == 6.4.0; python_version < '3.9'"
|
||||
"pyyaml == 6.0.1",
|
||||
"redis == 5.0.7"
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
requires-python = ">=3.10"
|
||||
dynamic = ["version"]
|
||||
|
||||
[project.readme]
|
||||
|
@ -49,11 +46,10 @@ activityrelay = "relay.manage:main"
|
|||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"flake8 == 7.0.0",
|
||||
"mypy == 1.10.0",
|
||||
"pyinstaller == 6.8.0",
|
||||
"watchdog == 4.0.1",
|
||||
"typing-extensions >= 4.12.2; python_version < '3.11.0'"
|
||||
"flake8 == 7.1.0",
|
||||
"mypy == 1.11.1",
|
||||
"pyinstaller == 6.10.0",
|
||||
"watchdog == 4.0.2"
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
|
@ -104,7 +100,3 @@ implicit_reexport = true
|
|||
[[tool.mypy.overrides]]
|
||||
module = "blib"
|
||||
implicit_reexport = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "bsql"
|
||||
implicit_reexport = true
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = '0.3.2'
|
||||
__version__ = '0.3.3'
|
||||
|
|
|
@ -6,29 +6,33 @@ import signal
|
|||
import time
|
||||
import traceback
|
||||
|
||||
from Crypto.Random import get_random_bytes
|
||||
from aiohttp import web
|
||||
from aiohttp.web import StaticResource
|
||||
from aiohttp.web import HTTPException, StaticResource
|
||||
from aiohttp_swagger import setup_swagger
|
||||
from aputils.signer import Signer
|
||||
from bsql import Database, Row
|
||||
from base64 import b64encode
|
||||
from blib import File, HttpError, port_check
|
||||
from bsql import Database
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from mimetypes import guess_type
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Event, Thread
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from . import logger as logging
|
||||
from .cache import Cache, get_cache
|
||||
from .config import Config
|
||||
from .database import Connection, get_database
|
||||
from .database.schema import Instance
|
||||
from .http_client import HttpClient
|
||||
from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource
|
||||
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response
|
||||
from .template import Template
|
||||
from .views import VIEWS
|
||||
from .views.api import handle_api_path
|
||||
from .views.frontend import handle_frontend_path
|
||||
from .workers import PushWorkers
|
||||
|
||||
|
||||
def get_csp(request: web.Request) -> str:
|
||||
|
@ -54,9 +58,9 @@ class Application(web.Application):
|
|||
def __init__(self, cfgpath: Path | None, dev: bool = False):
|
||||
web.Application.__init__(self,
|
||||
middlewares = [
|
||||
handle_api_path, # type: ignore[list-item]
|
||||
handle_response_headers, # type: ignore[list-item]
|
||||
handle_frontend_path, # type: ignore[list-item]
|
||||
handle_response_headers # type: ignore[list-item]
|
||||
handle_api_path # type: ignore[list-item]
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -75,7 +79,7 @@ class Application(web.Application):
|
|||
self['cache'].setup()
|
||||
self['template'] = Template(self)
|
||||
self['push_queue'] = multiprocessing.Queue()
|
||||
self['workers'] = []
|
||||
self['workers'] = PushWorkers(self.config.workers)
|
||||
|
||||
self.cache.setup()
|
||||
self.on_cleanup.append(handle_cleanup) # type: ignore
|
||||
|
@ -86,33 +90,33 @@ class Application(web.Application):
|
|||
setup_swagger(
|
||||
self,
|
||||
ui_version = 3,
|
||||
swagger_from_file = get_resource('data/swagger.yaml')
|
||||
swagger_from_file = File.from_resource('relay', 'data/swagger.yaml')
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def cache(self) -> Cache:
|
||||
return self['cache'] # type: ignore[no-any-return]
|
||||
return cast(Cache, self['cache'])
|
||||
|
||||
|
||||
@property
|
||||
def client(self) -> HttpClient:
|
||||
return self['client'] # type: ignore[no-any-return]
|
||||
return cast(HttpClient, self['client'])
|
||||
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
return self['config'] # type: ignore[no-any-return]
|
||||
return cast(Config, self['config'])
|
||||
|
||||
|
||||
@property
|
||||
def database(self) -> Database[Connection]:
|
||||
return self['database'] # type: ignore[no-any-return]
|
||||
return cast(Database[Connection], self['database'])
|
||||
|
||||
|
||||
@property
|
||||
def signer(self) -> Signer:
|
||||
return self['signer'] # type: ignore[no-any-return]
|
||||
return cast(Signer, self['signer'])
|
||||
|
||||
|
||||
@signer.setter
|
||||
|
@ -126,7 +130,7 @@ class Application(web.Application):
|
|||
|
||||
@property
|
||||
def template(self) -> Template:
|
||||
return self['template'] # type: ignore[no-any-return]
|
||||
return cast(Template, self['template'])
|
||||
|
||||
|
||||
@property
|
||||
|
@ -139,16 +143,23 @@ class Application(web.Application):
|
|||
return timedelta(seconds=uptime.seconds)
|
||||
|
||||
|
||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
||||
self['push_queue'].put((inbox, message, instance))
|
||||
@property
|
||||
def workers(self) -> PushWorkers:
|
||||
return cast(PushWorkers, self['workers'])
|
||||
|
||||
|
||||
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
|
||||
self['workers'].push_message(inbox, message, instance)
|
||||
|
||||
|
||||
def register_static_routes(self) -> None:
|
||||
if self['dev']:
|
||||
static = StaticResource('/static', get_resource('frontend/static'))
|
||||
static = StaticResource('/static', File.from_resource('relay', 'frontend/static'))
|
||||
|
||||
else:
|
||||
static = CachedStaticResource('/static', get_resource('frontend/static'))
|
||||
static = CachedStaticResource(
|
||||
'/static', Path(File.from_resource('relay', 'frontend/static'))
|
||||
)
|
||||
|
||||
self.router.register_resource(static)
|
||||
|
||||
|
@ -161,7 +172,7 @@ class Application(web.Application):
|
|||
host = self.config.listen
|
||||
port = self.config.port
|
||||
|
||||
if not check_open_port(host, port):
|
||||
if port_check(port, '127.0.0.1' if host == '0.0.0.0' else host):
|
||||
logging.error(f'A server is already running on {host}:{port}')
|
||||
return
|
||||
|
||||
|
@ -195,12 +206,7 @@ class Application(web.Application):
|
|||
self['cache'].setup()
|
||||
self['cleanup_thread'] = CacheCleanupThread(self)
|
||||
self['cleanup_thread'].start()
|
||||
|
||||
for _ in range(self.config.workers):
|
||||
worker = PushWorker(self['push_queue'])
|
||||
worker.start()
|
||||
|
||||
self['workers'].append(worker)
|
||||
self['workers'].start()
|
||||
|
||||
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
|
||||
await runner.setup()
|
||||
|
@ -220,15 +226,13 @@ class Application(web.Application):
|
|||
|
||||
await site.stop()
|
||||
|
||||
for worker in self['workers']:
|
||||
worker.stop()
|
||||
self['workers'].stop()
|
||||
|
||||
self.set_signal_handler(False)
|
||||
|
||||
self['starttime'] = None
|
||||
self['running'] = False
|
||||
self['cleanup_thread'].stop()
|
||||
self['workers'].clear()
|
||||
self['database'].disconnect()
|
||||
self['cache'].close()
|
||||
|
||||
|
@ -290,56 +294,15 @@ class CacheCleanupThread(Thread):
|
|||
self.running.clear()
|
||||
|
||||
|
||||
class PushWorker(multiprocessing.Process):
|
||||
def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None:
|
||||
if Application.DEFAULT is None:
|
||||
raise RuntimeError('Application not setup yet')
|
||||
def format_error(request: web.Request, error: HttpError) -> Response:
|
||||
app: Application = request.app # type: ignore[assignment]
|
||||
|
||||
multiprocessing.Process.__init__(self)
|
||||
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
|
||||
return Response.new({'error': error.message}, error.status, ctype = 'json')
|
||||
|
||||
self.queue = queue
|
||||
self.shutdown = multiprocessing.Event()
|
||||
self.path = Application.DEFAULT.config.path
|
||||
|
||||
|
||||
def stop(self) -> None:
|
||||
self.shutdown.set()
|
||||
|
||||
|
||||
def run(self) -> None:
|
||||
asyncio.run(self.handle_queue())
|
||||
|
||||
|
||||
async def handle_queue(self) -> None:
|
||||
if IS_WINDOWS:
|
||||
app = Application(self.path)
|
||||
client = app.client
|
||||
|
||||
client.open()
|
||||
app.database.connect()
|
||||
app.cache.setup()
|
||||
|
||||
else:
|
||||
client = HttpClient()
|
||||
client.open()
|
||||
|
||||
while not self.shutdown.is_set():
|
||||
try:
|
||||
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
|
||||
asyncio.create_task(client.post(inbox, message, instance))
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# make sure an exception doesn't bring down the worker
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if IS_WINDOWS:
|
||||
app.database.disconnect()
|
||||
app.cache.close()
|
||||
|
||||
await client.close()
|
||||
else:
|
||||
body = app.template.render('page/error.haml', request, e = error)
|
||||
return Response.new(body, error.status, ctype = 'html')
|
||||
|
||||
|
||||
@web.middleware
|
||||
|
@ -347,14 +310,60 @@ async def handle_response_headers(
|
|||
request: web.Request,
|
||||
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
|
||||
|
||||
resp = await handler(request)
|
||||
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
|
||||
request['token'] = None
|
||||
request['user'] = None
|
||||
|
||||
app: Application = request.app # type: ignore[assignment]
|
||||
|
||||
if request.path == "/" or request.path.startswith(TOKEN_PATHS):
|
||||
with app.database.session() as conn:
|
||||
tokens = (
|
||||
request.headers.get('Authorization', '').replace('Bearer', '').strip(),
|
||||
request.cookies.get('user-token')
|
||||
)
|
||||
|
||||
for token in tokens:
|
||||
if not token:
|
||||
continue
|
||||
|
||||
request['token'] = conn.get_app_by_token(token)
|
||||
|
||||
if request['token'] is not None:
|
||||
request['user'] = conn.get_user(request['token'].user)
|
||||
|
||||
break
|
||||
|
||||
try:
|
||||
resp = await handler(request)
|
||||
|
||||
except HttpError as e:
|
||||
resp = format_error(request, e)
|
||||
|
||||
except HTTPException as e:
|
||||
if e.status == 404:
|
||||
try:
|
||||
text = (e.text or "").split(":")[1].strip()
|
||||
|
||||
except IndexError:
|
||||
text = e.text or ""
|
||||
|
||||
resp = format_error(request, HttpError(e.status, text))
|
||||
|
||||
else:
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
resp = format_error(request, HttpError(500, 'Internal server error'))
|
||||
traceback.print_exc()
|
||||
|
||||
resp.headers['Server'] = 'ActivityRelay'
|
||||
|
||||
# Still have to figure out how csp headers work
|
||||
if resp.content_type == 'text/html' and not request.path.startswith("/api"):
|
||||
resp.headers['Content-Security-Policy'] = get_csp(request)
|
||||
|
||||
if not request.app['dev'] and request.path.endswith(('.css', '.js')):
|
||||
if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')):
|
||||
# cache for 2 weeks
|
||||
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'
|
||||
|
||||
|
|
106
relay/cache.py
106
relay/cache.py
|
@ -4,15 +4,16 @@ import json
|
|||
import os
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from bsql import Database
|
||||
from blib import Date, convert_to_boolean
|
||||
from bsql import Database, Row
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import timedelta, timezone
|
||||
from redis import Redis
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from .database import Connection, get_database
|
||||
from .misc import Message, boolean
|
||||
from .misc import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .application import Application
|
||||
|
@ -25,12 +26,20 @@ BACKENDS: dict[str, type[Cache]] = {}
|
|||
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
|
||||
'str': (str, str),
|
||||
'int': (str, int),
|
||||
'bool': (str, boolean),
|
||||
'bool': (str, convert_to_boolean),
|
||||
'json': (json.dumps, json.loads),
|
||||
'message': (lambda x: x.to_json(), Message.parse)
|
||||
}
|
||||
|
||||
|
||||
class RedisConnectType(TypedDict):
|
||||
client_name: str
|
||||
decode_responses: bool
|
||||
username: str | None
|
||||
password: str | None
|
||||
db: int
|
||||
|
||||
|
||||
def get_cache(app: Application) -> Cache:
|
||||
return BACKENDS[app.config.ca_type](app)
|
||||
|
||||
|
@ -57,12 +66,14 @@ class Item:
|
|||
key: str
|
||||
value: Any
|
||||
value_type: str
|
||||
updated: datetime
|
||||
updated: Date
|
||||
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if isinstance(self.updated, str): # type: ignore[unreachable]
|
||||
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
|
||||
self.updated = Date.parse(self.updated)
|
||||
|
||||
if self.updated.tzinfo is None:
|
||||
self.updated = self.updated.replace(tzinfo = timezone.utc)
|
||||
|
||||
|
||||
@classmethod
|
||||
|
@ -70,15 +81,11 @@ class Item:
|
|||
data = cls(*args)
|
||||
data.value = deserialize_value(data.value, data.value_type)
|
||||
|
||||
if not isinstance(data.updated, datetime):
|
||||
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def older_than(self, hours: int) -> bool:
|
||||
delta = datetime.now(tz = timezone.utc) - self.updated
|
||||
return (delta.total_seconds()) > hours * 3600
|
||||
return self.updated + timedelta(hours = hours) < Date.new_utc()
|
||||
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
@ -172,7 +179,7 @@ class SqlCache(Cache):
|
|||
|
||||
with self._db.session(False) as conn:
|
||||
with conn.run('get-cache-item', params) as cur:
|
||||
if not (row := cur.one()):
|
||||
if not (row := cur.one(Row)):
|
||||
raise KeyError(f'{namespace}:{key}')
|
||||
|
||||
row.pop('id', None)
|
||||
|
@ -206,14 +213,16 @@ class SqlCache(Cache):
|
|||
'key': key,
|
||||
'value': serialize_value(value, value_type),
|
||||
'type': value_type,
|
||||
'date': datetime.now(tz = timezone.utc)
|
||||
'date': Date.new_utc()
|
||||
}
|
||||
|
||||
with self._db.session(True) as conn:
|
||||
with conn.run('set-cache-item', params) as cur:
|
||||
row = cur.one()
|
||||
row.pop('id', None) # type: ignore[union-attr]
|
||||
return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
|
||||
if (row := cur.one(Row)) is None:
|
||||
raise RuntimeError("Cache item not set")
|
||||
|
||||
row.pop('id', None)
|
||||
return Item.from_data(*tuple(row.values()))
|
||||
|
||||
|
||||
def delete(self, namespace: str, key: str) -> None:
|
||||
|
@ -234,11 +243,10 @@ class SqlCache(Cache):
|
|||
if self._db is None:
|
||||
raise RuntimeError("Database has not been setup")
|
||||
|
||||
limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
|
||||
params = {"limit": limit.timestamp()}
|
||||
date = Date.new_utc() - timedelta(days = days)
|
||||
|
||||
with self._db.session(True) as conn:
|
||||
with conn.execute("DELETE FROM cache WHERE updated < :limit", params):
|
||||
with conn.execute("DELETE FROM cache WHERE updated < :limit", {"limit": date}):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -278,7 +286,7 @@ class RedisCache(Cache):
|
|||
|
||||
def __init__(self, app: Application):
|
||||
Cache.__init__(self, app)
|
||||
self._rd: Redis = None # type: ignore
|
||||
self._rd: Redis | None = None
|
||||
|
||||
|
||||
@property
|
||||
|
@ -291,28 +299,38 @@ class RedisCache(Cache):
|
|||
|
||||
|
||||
def get(self, namespace: str, key: str) -> Item:
|
||||
if self._rd is None:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
key_name = self.get_key_name(namespace, key)
|
||||
|
||||
if not (raw_value := self._rd.get(key_name)):
|
||||
raise KeyError(f'{namespace}:{key}')
|
||||
|
||||
value_type, updated, value = raw_value.split(':', 2) # type: ignore
|
||||
value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr]
|
||||
|
||||
return Item.from_data(
|
||||
namespace,
|
||||
key,
|
||||
value,
|
||||
value_type,
|
||||
datetime.fromtimestamp(float(updated), tz = timezone.utc)
|
||||
Date.parse(float(updated))
|
||||
)
|
||||
|
||||
|
||||
def get_keys(self, namespace: str) -> Iterator[str]:
|
||||
if self._rd is None:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
for key in self._rd.scan_iter(self.get_key_name(namespace, '*')):
|
||||
*_, key_name = key.split(':', 2)
|
||||
yield key_name
|
||||
|
||||
|
||||
def get_namespaces(self) -> Iterator[str]:
|
||||
if self._rd is None:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
namespaces = []
|
||||
|
||||
for key in self._rd.scan_iter(f'{self.prefix}:*'):
|
||||
|
@ -324,7 +342,10 @@ class RedisCache(Cache):
|
|||
|
||||
|
||||
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
|
||||
date = datetime.now(tz = timezone.utc).timestamp()
|
||||
if self._rd is None:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
date = Date.new_utc().timestamp()
|
||||
value = serialize_value(value, value_type)
|
||||
|
||||
self._rd.set(
|
||||
|
@ -336,11 +357,17 @@ class RedisCache(Cache):
|
|||
|
||||
|
||||
def delete(self, namespace: str, key: str) -> None:
|
||||
if self._rd is None:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
self._rd.delete(self.get_key_name(namespace, key))
|
||||
|
||||
|
||||
def delete_old(self, days: int = 14) -> None:
|
||||
limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
|
||||
if self._rd is None:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
limit = Date.new_utc() - timedelta(days = days)
|
||||
|
||||
for full_key in self._rd.scan_iter(f'{self.prefix}:*'):
|
||||
_, namespace, key = full_key.split(':', 2)
|
||||
|
@ -351,14 +378,17 @@ class RedisCache(Cache):
|
|||
|
||||
|
||||
def clear(self) -> None:
|
||||
if self._rd is None:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
self._rd.delete(f"{self.prefix}:*")
|
||||
|
||||
|
||||
def setup(self) -> None:
|
||||
if self._rd:
|
||||
if self._rd is not None:
|
||||
return
|
||||
|
||||
options = {
|
||||
options: RedisConnectType = {
|
||||
'client_name': f'ActivityRelay_{self.app.config.domain}',
|
||||
'decode_responses': True,
|
||||
'username': self.app.config.rd_user,
|
||||
|
@ -367,18 +397,22 @@ class RedisCache(Cache):
|
|||
}
|
||||
|
||||
if os.path.exists(self.app.config.rd_host):
|
||||
options['unix_socket_path'] = self.app.config.rd_host
|
||||
self._rd = Redis(
|
||||
unix_socket_path = self.app.config.rd_host,
|
||||
**options
|
||||
)
|
||||
return
|
||||
|
||||
else:
|
||||
options['host'] = self.app.config.rd_host
|
||||
options['port'] = self.app.config.rd_port
|
||||
|
||||
self._rd = Redis(**options) # type: ignore
|
||||
self._rd = Redis(
|
||||
host = self.app.config.rd_host,
|
||||
port = self.app.config.rd_port,
|
||||
**options
|
||||
)
|
||||
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._rd:
|
||||
return
|
||||
|
||||
self._rd.close() # type: ignore
|
||||
self._rd = None # type: ignore
|
||||
self._rd.close() # type: ignore[no-untyped-call]
|
||||
self._rd = None
|
||||
|
|
|
@ -2,13 +2,12 @@ import json
|
|||
import os
|
||||
import yaml
|
||||
|
||||
from blib import convert_to_boolean
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .misc import boolean
|
||||
|
||||
|
||||
class RelayConfig(dict[str, Any]):
|
||||
def __init__(self, path: str):
|
||||
|
@ -31,7 +30,7 @@ class RelayConfig(dict[str, Any]):
|
|||
|
||||
elif key == 'whitelist_enabled':
|
||||
if not isinstance(value, bool):
|
||||
value = boolean(value)
|
||||
value = convert_to_boolean(value)
|
||||
|
||||
super().__setitem__(key, value)
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import os
|
||||
import platform
|
||||
|
@ -6,16 +8,13 @@ import yaml
|
|||
from dataclasses import asdict, dataclass, fields
|
||||
from pathlib import Path
|
||||
from platformdirs import user_config_dir
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .misc import IS_DOCKER
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
import multiprocessing
|
||||
|
@ -61,7 +60,7 @@ class Config:
|
|||
|
||||
|
||||
def __init__(self, path: Path | None = None, load: bool = False):
|
||||
self.path = Config.get_config_dir(path)
|
||||
self.path: Path = Config.get_config_dir(path)
|
||||
self.reset()
|
||||
|
||||
if load:
|
||||
|
@ -81,7 +80,7 @@ class Config:
|
|||
def DEFAULT(cls: type[Self], key: str) -> str | int | None:
|
||||
for field in fields(cls):
|
||||
if field.name == key:
|
||||
return field.default # type: ignore
|
||||
return field.default # type: ignore[return-value]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
|
@ -146,7 +145,7 @@ class Config:
|
|||
if not config:
|
||||
raise ValueError('Config is empty')
|
||||
|
||||
pgcfg = config.get('postgresql', {})
|
||||
pgcfg = config.get('postgres', {})
|
||||
rdcfg = config.get('redis', {})
|
||||
|
||||
for key in type(self).KEYS():
|
||||
|
|
|
@ -40,7 +40,7 @@ WHERE domain = :value or inbox = :value or actor = :value;
|
|||
|
||||
|
||||
-- name: get-request
|
||||
SELECT * FROM inboxes WHERE accepted = 0 and domain = :domain;
|
||||
SELECT * FROM inboxes WHERE accepted = false and domain = :domain;
|
||||
|
||||
|
||||
-- name: get-user
|
||||
|
@ -51,8 +51,8 @@ WHERE username = :value or handle = :value;
|
|||
-- name: get-user-by-token
|
||||
SELECT * FROM users
|
||||
WHERE username = (
|
||||
SELECT user FROM tokens
|
||||
WHERE code = :code
|
||||
SELECT user FROM apps
|
||||
WHERE token = :token
|
||||
);
|
||||
|
||||
|
||||
|
@ -64,28 +64,35 @@ RETURNING *;
|
|||
|
||||
-- name: del-user
|
||||
DELETE FROM users
|
||||
WHERE username = :value or handle = :value;
|
||||
WHERE username = :username or handle = :username;
|
||||
|
||||
|
||||
-- name: get-token
|
||||
SELECT * FROM tokens
|
||||
WHERE code = :code;
|
||||
-- name: get-app
|
||||
SELECT * FROM apps
|
||||
WHERE client_id = :id and client_secret = :secret;
|
||||
|
||||
|
||||
-- name: put-token
|
||||
INSERT INTO tokens (code, user, created)
|
||||
VALUES (:code, :user, :created)
|
||||
RETURNING *;
|
||||
-- name: get-app-with-token
|
||||
SELECT * FROM apps
|
||||
WHERE client_id = :id and client_secret = :secret and token = :token;
|
||||
|
||||
|
||||
-- name: del-token
|
||||
DELETE FROM tokens
|
||||
WHERE code = :code;
|
||||
-- name: get-app-by-token
|
||||
SELECT * FROM apps
|
||||
WHERE token = :token;
|
||||
|
||||
-- name: del-app
|
||||
DELETE FROM apps
|
||||
WHERE client_id = :id and client_secret = :secret;
|
||||
|
||||
|
||||
-- name: del-app-with-token
|
||||
DELETE FROM apps
|
||||
WHERE client_id = :id and client_secret = :secret and token = :token;
|
||||
|
||||
|
||||
-- name: del-token-user
|
||||
DELETE FROM tokens
|
||||
WHERE user = :username;
|
||||
DELETE FROM apps WHERE "user" = :username;
|
||||
|
||||
|
||||
-- name: get-software-ban
|
||||
|
|
|
@ -18,10 +18,12 @@ securityDefinitions:
|
|||
in: cookie
|
||||
name: user-token
|
||||
Bearer:
|
||||
type: apiKey
|
||||
type: oauth2
|
||||
name: Authorization
|
||||
in: header
|
||||
description: "Enter the token with the `Bearer ` prefix"
|
||||
flow: accessCode
|
||||
authorizationUrl: /oauth/authorize
|
||||
tokenUrl: /oauth/token
|
||||
|
||||
paths:
|
||||
/:
|
||||
|
@ -35,6 +37,161 @@ paths:
|
|||
schema:
|
||||
$ref: "#/definitions/Error"
|
||||
|
||||
/oauth/authorize:
|
||||
get:
|
||||
tags:
|
||||
- OAuth
|
||||
description: Get an authorization code
|
||||
parameters:
|
||||
- in: query
|
||||
name: response-type
|
||||
required: true
|
||||
type: string
|
||||
- in: query
|
||||
name: client_id
|
||||
required: true
|
||||
type: string
|
||||
- in: query
|
||||
name: redirect_uri
|
||||
required: true
|
||||
type: string
|
||||
|
||||
/oauth/token:
|
||||
post:
|
||||
tags:
|
||||
- OAuth
|
||||
description: Get a token for an authorized app
|
||||
parameters:
|
||||
- in: formData
|
||||
name: grant_type
|
||||
required: true
|
||||
type: string
|
||||
- in: formData
|
||||
name: code
|
||||
required: true
|
||||
type: string
|
||||
- in: formData
|
||||
name: client_id
|
||||
required: true
|
||||
type: string
|
||||
- in: formData
|
||||
name: client_secret
|
||||
required: true
|
||||
type: string
|
||||
- in: formData
|
||||
name: redirect_uri
|
||||
required: true
|
||||
type: string
|
||||
consumes:
|
||||
- application/x-www-form-urlencoded
|
||||
- application/json
|
||||
- multipart/form-data
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Application
|
||||
schema:
|
||||
$ref: "#/definitions/Application"
|
||||
|
||||
/oauth/revoke:
|
||||
post:
|
||||
tags:
|
||||
- OAuth
|
||||
description: Get a token for an authorized app
|
||||
parameters:
|
||||
- in: formData
|
||||
name: client_id
|
||||
required: true
|
||||
type: string
|
||||
- in: formData
|
||||
name: client_secret
|
||||
required: true
|
||||
type: string
|
||||
- in: formData
|
||||
name: token
|
||||
required: true
|
||||
type: string
|
||||
consumes:
|
||||
- application/json
|
||||
- multipart/form-data
|
||||
- application/x-www-form-urlencoded
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Message confirming application deletion
|
||||
schema:
|
||||
$ref: "#/definitions/Message"
|
||||
|
||||
/v1/app:
|
||||
get:
|
||||
tags:
|
||||
- Applications
|
||||
description: Verify the token is valid
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Application with the associated token
|
||||
schema:
|
||||
$ref: "#/definitions/Application"
|
||||
|
||||
post:
|
||||
tags:
|
||||
- Applications
|
||||
description: Create a new application
|
||||
parameters:
|
||||
- in: query
|
||||
name: name
|
||||
required: true
|
||||
type: string
|
||||
- in: query
|
||||
name: redirect_uri
|
||||
required: true
|
||||
type: string
|
||||
- in: query
|
||||
name: website
|
||||
required: false
|
||||
type: string
|
||||
format: url
|
||||
consumes:
|
||||
- application/json
|
||||
- multipart/form-data
|
||||
- application/x-www-form-urlencoded
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Newly created application
|
||||
schema:
|
||||
$ref: "#/definitions/Application"
|
||||
|
||||
delete:
|
||||
tags:
|
||||
- Applications
|
||||
description: Deletes an application
|
||||
parameters:
|
||||
- in: formData
|
||||
name: client_id
|
||||
required: true
|
||||
type: string
|
||||
- in: formData
|
||||
name: client_secret
|
||||
required: true
|
||||
type: string
|
||||
consumes:
|
||||
- application/json
|
||||
- multipart/form-data
|
||||
- application/x-www-form-urlencoded
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Confirmation of application deletion
|
||||
schema:
|
||||
$ref: "#/definitions/Message"
|
||||
|
||||
/v1/relay:
|
||||
get:
|
||||
tags:
|
||||
|
@ -48,23 +205,11 @@ paths:
|
|||
schema:
|
||||
$ref: "#/definitions/Info"
|
||||
|
||||
/v1/token:
|
||||
get:
|
||||
tags:
|
||||
- Token
|
||||
description: Verify API token
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Valid token
|
||||
schema:
|
||||
$ref: "#/definitions/Message"
|
||||
|
||||
/v1/login:
|
||||
post:
|
||||
tags:
|
||||
- Token
|
||||
description: Get a new token
|
||||
- Login
|
||||
description: Login with a username and password
|
||||
parameters:
|
||||
- in: formData
|
||||
name: username
|
||||
|
@ -74,7 +219,6 @@ paths:
|
|||
name: password
|
||||
required: true
|
||||
type: string
|
||||
format: password
|
||||
consumes:
|
||||
- application/json
|
||||
- multipart/form-data
|
||||
|
@ -83,22 +227,9 @@ paths:
|
|||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Created token
|
||||
description: A new Application
|
||||
schema:
|
||||
$ref: "#/definitions/Token"
|
||||
|
||||
|
||||
delete:
|
||||
tags:
|
||||
- Token
|
||||
description: Revoke a token
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Revoked token
|
||||
schema:
|
||||
$ref: "#/definitions/Message"
|
||||
$ref: "#/definitions/Application"
|
||||
|
||||
/v1/config:
|
||||
get:
|
||||
|
@ -731,9 +862,43 @@ definitions:
|
|||
description: Human-readable message text
|
||||
type: string
|
||||
|
||||
Application:
|
||||
type: object
|
||||
properties:
|
||||
client_id:
|
||||
description: Identifier for the application
|
||||
type: string
|
||||
client_secret:
|
||||
description: Secret string for the application
|
||||
type: string
|
||||
name:
|
||||
description: Human-readable name of the application
|
||||
type: string
|
||||
website:
|
||||
description: Website for the application
|
||||
type: string
|
||||
format: url
|
||||
redirect_uri:
|
||||
description: URL to redirect to when authorizing an app
|
||||
type: string
|
||||
token:
|
||||
description: String to use in the Authorization header for client requests
|
||||
type: string
|
||||
created:
|
||||
description: Date the application was created
|
||||
type: string
|
||||
format: date-time
|
||||
accessed:
|
||||
description: Date the application was last used
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
Config:
|
||||
type: object
|
||||
properties:
|
||||
approval-required:
|
||||
description: Require instances to be approved when following
|
||||
type: bool
|
||||
log-level:
|
||||
description: Maximum level of log messages to print to the console
|
||||
type: string
|
||||
|
@ -743,6 +908,9 @@ definitions:
|
|||
note:
|
||||
description: Blurb to display on the home page
|
||||
type: string
|
||||
theme:
|
||||
description: Name of the color scheme to use for the frontend
|
||||
type: string
|
||||
whitelist-enabled:
|
||||
description: Only allow specific instances to join the relay when enabled
|
||||
type: boolean
|
||||
|
@ -843,13 +1011,6 @@ definitions:
|
|||
type: string
|
||||
format: date-time
|
||||
|
||||
Token:
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
description: Character string used for authenticating with the api
|
||||
type: string
|
||||
|
||||
User:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
import sqlite3
|
||||
|
||||
from blib import Date, File
|
||||
from bsql import Database
|
||||
|
||||
from .config import THEMES, ConfigData
|
||||
|
@ -6,7 +9,9 @@ from .schema import TABLES, VERSIONS, migrate_0
|
|||
|
||||
from .. import logger as logging
|
||||
from ..config import Config
|
||||
from ..misc import get_resource
|
||||
|
||||
|
||||
sqlite3.register_adapter(Date, Date.timestamp)
|
||||
|
||||
|
||||
def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
|
||||
|
@ -16,6 +21,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
|
|||
'tables': TABLES
|
||||
}
|
||||
|
||||
db: Database[Connection]
|
||||
|
||||
if config.db_type == 'sqlite':
|
||||
db = Database.sqlite(config.sqlite_path, **options)
|
||||
|
||||
|
@ -29,7 +36,7 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
|
|||
**options
|
||||
)
|
||||
|
||||
db.load_prepared_statements(get_resource('data/statements.sql'))
|
||||
db.load_prepared_statements(File.from_resource('relay', 'data/statements.sql'))
|
||||
db.connect()
|
||||
|
||||
if not migrate:
|
||||
|
|
|
@ -2,20 +2,17 @@ from __future__ import annotations
|
|||
# removing the above line turns annotations into types instead of str objects which messes with
|
||||
# `Field.type`
|
||||
|
||||
from blib import convert_to_boolean
|
||||
from bsql import Row
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import Field, asdict, dataclass, fields
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .. import logger as logging
|
||||
from ..misc import boolean
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
THEMES = {
|
||||
'default': {
|
||||
|
@ -69,14 +66,14 @@ THEMES = {
|
|||
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
|
||||
'str': (str, str),
|
||||
'int': (str, int),
|
||||
'bool': (str, boolean),
|
||||
'bool': (str, convert_to_boolean),
|
||||
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse)
|
||||
}
|
||||
|
||||
|
||||
@dataclass()
|
||||
class ConfigData:
|
||||
schema_version: int = 20240310
|
||||
schema_version: int = 20240625
|
||||
private_key: str = ''
|
||||
approval_required: bool = False
|
||||
log_level: logging.LogLevel = logging.LogLevel.INFO
|
||||
|
@ -114,11 +111,11 @@ class ConfigData:
|
|||
|
||||
@classmethod
|
||||
def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
|
||||
return cls.FIELD(key.replace('-', '_')).default # type: ignore
|
||||
return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value]
|
||||
|
||||
|
||||
@classmethod
|
||||
def FIELD(cls: type[Self], key: str) -> Field[Any]:
|
||||
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
|
||||
for field in fields(cls):
|
||||
if field.name == key.replace('-', '_'):
|
||||
return field
|
||||
|
|
|
@ -1,20 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
|
||||
from argon2 import PasswordHasher
|
||||
from bsql import Connection as SqlConnection, Row, Update
|
||||
from collections.abc import Iterator, Sequence
|
||||
from blib import Date, convert_to_boolean
|
||||
from bsql import BackendType, Connection as SqlConnection, Row, Update
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from . import schema
|
||||
from .config import (
|
||||
THEMES,
|
||||
ConfigData
|
||||
)
|
||||
|
||||
from .. import logger as logging
|
||||
from ..misc import Message, boolean, get_app
|
||||
from ..misc import Message, get_app
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..application import Application
|
||||
|
@ -37,22 +40,63 @@ class Connection(SqlConnection):
|
|||
return get_app()
|
||||
|
||||
|
||||
def distill_inboxes(self, message: Message) -> Iterator[Row]:
|
||||
def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]:
|
||||
src_domains = {
|
||||
message.domain,
|
||||
urlparse(message.object_id).netloc
|
||||
}
|
||||
|
||||
for instance in self.get_inboxes():
|
||||
if instance['domain'] not in src_domains:
|
||||
if instance.domain not in src_domains:
|
||||
yield instance
|
||||
|
||||
|
||||
def drop_tables(self) -> None:
|
||||
with self.cursor() as cur:
|
||||
for table in self.get_tables():
|
||||
query = f"DROP TABLE IF EXISTS {table}"
|
||||
|
||||
if self.database.backend.backend_type == BackendType.POSTGRESQL:
|
||||
query += " CASCADE"
|
||||
|
||||
cur.execute(query)
|
||||
|
||||
|
||||
def fix_timestamps(self) -> None:
|
||||
for app in self.select('apps').all(schema.App):
|
||||
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
|
||||
self.update('apps', data, client_id = app.client_id)
|
||||
|
||||
for item in self.select('cache'):
|
||||
data = {'updated': Date.parse(item['updated']).timestamp()}
|
||||
self.update('cache', data, id = item['id'])
|
||||
|
||||
for dban in self.select('domain_bans').all(schema.DomainBan):
|
||||
data = {'created': dban.created.timestamp()}
|
||||
self.update('domain_bans', data, domain = dban.domain)
|
||||
|
||||
for instance in self.select('inboxes').all(schema.Instance):
|
||||
data = {'created': instance.created.timestamp()}
|
||||
self.update('inboxes', data, domain = instance.domain)
|
||||
|
||||
for sban in self.select('software_bans').all(schema.SoftwareBan):
|
||||
data = {'created': sban.created.timestamp()}
|
||||
self.update('software_bans', data, name = sban.name)
|
||||
|
||||
for user in self.select('users').all(schema.User):
|
||||
data = {'created': user.created.timestamp()}
|
||||
self.update('users', data, username = user.username)
|
||||
|
||||
for wlist in self.select('whitelist').all(schema.Whitelist):
|
||||
data = {'created': wlist.created.timestamp()}
|
||||
self.update('whitelist', data, domain = wlist.domain)
|
||||
|
||||
|
||||
def get_config(self, key: str) -> Any:
|
||||
key = key.replace('_', '-')
|
||||
|
||||
with self.run('get-config', {'key': key}) as cur:
|
||||
if not (row := cur.one()):
|
||||
if (row := cur.one(Row)) is None:
|
||||
return ConfigData.DEFAULT(key)
|
||||
|
||||
data = ConfigData()
|
||||
|
@ -61,8 +105,8 @@ class Connection(SqlConnection):
|
|||
|
||||
|
||||
def get_config_all(self) -> ConfigData:
|
||||
with self.run('get-config-all', None) as cur:
|
||||
return ConfigData.from_rows(tuple(cur.all()))
|
||||
rows = tuple(self.run('get-config-all', None).all(schema.Row))
|
||||
return ConfigData.from_rows(rows)
|
||||
|
||||
|
||||
def put_config(self, key: str, value: Any) -> Any:
|
||||
|
@ -75,9 +119,10 @@ class Connection(SqlConnection):
|
|||
elif key == 'log-level':
|
||||
value = logging.LogLevel.parse(value)
|
||||
logging.set_level(value)
|
||||
self.app['workers'].set_log_level(value)
|
||||
|
||||
elif key in {'approval-required', 'whitelist-enabled'}:
|
||||
value = boolean(value)
|
||||
value = convert_to_boolean(value)
|
||||
|
||||
elif key == 'theme':
|
||||
if value not in THEMES:
|
||||
|
@ -98,23 +143,23 @@ class Connection(SqlConnection):
|
|||
return data.get(key)
|
||||
|
||||
|
||||
def get_inbox(self, value: str) -> Row:
|
||||
def get_inbox(self, value: str) -> schema.Instance | None:
|
||||
with self.run('get-inbox', {'value': value}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.Instance)
|
||||
|
||||
|
||||
def get_inboxes(self) -> Sequence[Row]:
|
||||
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur:
|
||||
return tuple(cur.all())
|
||||
def get_inboxes(self) -> Iterator[schema.Instance]:
|
||||
return self.execute("SELECT * FROM inboxes WHERE accepted = true").all(schema.Instance)
|
||||
|
||||
|
||||
def put_inbox(self,
|
||||
# todo: check if software is different than stored row
|
||||
def put_inbox(self, # noqa: E301
|
||||
domain: str,
|
||||
inbox: str | None = None,
|
||||
actor: str | None = None,
|
||||
followid: str | None = None,
|
||||
software: str | None = None,
|
||||
accepted: bool = True) -> Row:
|
||||
accepted: bool = True) -> schema.Instance:
|
||||
|
||||
params: dict[str, Any] = {
|
||||
'inbox': inbox,
|
||||
|
@ -124,7 +169,7 @@ class Connection(SqlConnection):
|
|||
'accepted': accepted
|
||||
}
|
||||
|
||||
if not self.get_inbox(domain):
|
||||
if self.get_inbox(domain) is None:
|
||||
if not inbox:
|
||||
raise ValueError("Missing inbox")
|
||||
|
||||
|
@ -132,14 +177,20 @@ class Connection(SqlConnection):
|
|||
params['created'] = datetime.now(tz = timezone.utc)
|
||||
|
||||
with self.run('put-inbox', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Instance)) is None:
|
||||
raise RuntimeError(f"Failed to insert instance: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
for key, value in tuple(params.items()):
|
||||
if value is None:
|
||||
del params[key]
|
||||
|
||||
with self.update('inboxes', params, domain = domain) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Instance)) is None:
|
||||
raise RuntimeError(f"Failed to update instance: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_inbox(self, value: str) -> bool:
|
||||
|
@ -150,24 +201,23 @@ class Connection(SqlConnection):
|
|||
return cur.row_count == 1
|
||||
|
||||
|
||||
def get_request(self, domain: str) -> Row:
|
||||
def get_request(self, domain: str) -> schema.Instance | None:
|
||||
with self.run('get-request', {'domain': domain}) as cur:
|
||||
if not (row := cur.one()):
|
||||
raise KeyError(domain)
|
||||
|
||||
return row
|
||||
return cur.one(schema.Instance)
|
||||
|
||||
|
||||
def get_requests(self) -> Sequence[Row]:
|
||||
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
|
||||
return tuple(cur.all())
|
||||
def get_requests(self) -> Iterator[schema.Instance]:
|
||||
return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance)
|
||||
|
||||
|
||||
def put_request_response(self, domain: str, accepted: bool) -> Row:
|
||||
instance = self.get_request(domain)
|
||||
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
|
||||
if (instance := self.get_request(domain)) is None:
|
||||
raise KeyError(domain)
|
||||
|
||||
if not accepted:
|
||||
self.del_inbox(domain)
|
||||
if not self.del_inbox(domain):
|
||||
raise RuntimeError(f'Failed to delete request: {domain}')
|
||||
|
||||
return instance
|
||||
|
||||
params = {
|
||||
|
@ -176,21 +226,28 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-inbox-accept', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Instance)) is None:
|
||||
raise RuntimeError(f"Failed to insert response for domain: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def get_user(self, value: str) -> Row:
|
||||
def get_user(self, value: str) -> schema.User | None:
|
||||
with self.run('get-user', {'value': value}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.User)
|
||||
|
||||
|
||||
def get_user_by_token(self, code: str) -> Row:
|
||||
with self.run('get-user-by-token', {'code': code}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
def get_user_by_token(self, token: str) -> schema.User | None:
|
||||
with self.run('get-user-by-token', {'token': token}) as cur:
|
||||
return cur.one(schema.User)
|
||||
|
||||
|
||||
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row:
|
||||
if self.get_user(username):
|
||||
def get_users(self) -> Iterator[schema.User]:
|
||||
return self.execute("SELECT * FROM users").all(schema.User)
|
||||
|
||||
|
||||
def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User:
|
||||
if self.get_user(username) is not None:
|
||||
data: dict[str, str | datetime | None] = {}
|
||||
|
||||
if password:
|
||||
|
@ -203,7 +260,10 @@ class Connection(SqlConnection):
|
|||
stmt.set_where("username", username)
|
||||
|
||||
with self.query(stmt) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.User)) is None:
|
||||
raise RuntimeError(f"Failed to update user: {username}")
|
||||
|
||||
return row
|
||||
|
||||
if password is None:
|
||||
raise ValueError('Password cannot be empty')
|
||||
|
@ -216,52 +276,149 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-user', data) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.User)) is None:
|
||||
raise RuntimeError(f"Failed to insert user: {username}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_user(self, username: str) -> None:
|
||||
user = self.get_user(username)
|
||||
if (user := self.get_user(username)) is None:
|
||||
raise KeyError(username)
|
||||
|
||||
with self.run('del-user', {'value': user['username']}):
|
||||
with self.run('del-token-user', {'username': user.username}):
|
||||
pass
|
||||
|
||||
with self.run('del-token-user', {'username': user['username']}):
|
||||
with self.run('del-user', {'username': user.username}):
|
||||
pass
|
||||
|
||||
|
||||
def get_token(self, code: str) -> Row:
|
||||
with self.run('get-token', {'code': code}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
def get_app(self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
token: str | None = None) -> schema.App | None:
|
||||
|
||||
|
||||
def put_token(self, username: str) -> Row:
|
||||
data = {
|
||||
'code': uuid4().hex,
|
||||
'user': username,
|
||||
'created': datetime.now(tz = timezone.utc)
|
||||
params = {
|
||||
'id': client_id,
|
||||
'secret': client_secret
|
||||
}
|
||||
|
||||
with self.run('put-token', data) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if token is not None:
|
||||
command = 'get-app-with-token'
|
||||
params['token'] = token
|
||||
|
||||
else:
|
||||
command = 'get-app'
|
||||
|
||||
with self.run(command, params) as cur:
|
||||
return cur.one(schema.App)
|
||||
|
||||
|
||||
def del_token(self, code: str) -> None:
|
||||
with self.run('del-token', {'code': code}):
|
||||
pass
|
||||
def get_app_by_token(self, token: str) -> schema.App | None:
|
||||
with self.run('get-app-by-token', {'token': token}) as cur:
|
||||
return cur.one(schema.App)
|
||||
|
||||
|
||||
def get_domain_ban(self, domain: str) -> Row:
|
||||
def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
|
||||
params = {
|
||||
'name': name,
|
||||
'redirect_uri': redirect_uri,
|
||||
'website': website,
|
||||
'client_id': secrets.token_hex(20),
|
||||
'client_secret': secrets.token_hex(20),
|
||||
'created': Date.new_utc(),
|
||||
'accessed': Date.new_utc()
|
||||
}
|
||||
|
||||
with self.insert('apps', params) as cur:
|
||||
if (row := cur.one(schema.App)) is None:
|
||||
raise RuntimeError(f'Failed to insert app: {name}')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def put_app_login(self, user: schema.User) -> schema.App:
|
||||
params = {
|
||||
'name': 'Web',
|
||||
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob',
|
||||
'website': None,
|
||||
'user': user.username,
|
||||
'client_id': secrets.token_hex(20),
|
||||
'client_secret': secrets.token_hex(20),
|
||||
'auth_code': None,
|
||||
'token': secrets.token_hex(20),
|
||||
'created': Date.new_utc(),
|
||||
'accessed': Date.new_utc()
|
||||
}
|
||||
|
||||
with self.insert('apps', params) as cur:
|
||||
if (row := cur.one(schema.App)) is None:
|
||||
raise RuntimeError(f'Failed to create app for "{user.username}"')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App:
|
||||
data: dict[str, str | None] = {}
|
||||
|
||||
if user is not None:
|
||||
data['user'] = user.username
|
||||
|
||||
if set_auth:
|
||||
data['auth_code'] = secrets.token_hex(20)
|
||||
|
||||
else:
|
||||
data['token'] = secrets.token_hex(20)
|
||||
data['auth_code'] = None
|
||||
|
||||
params = {
|
||||
'client_id': app.client_id,
|
||||
'client_secret': app.client_secret
|
||||
}
|
||||
|
||||
with self.update('apps', data, **params) as cur: # type: ignore[arg-type]
|
||||
if (row := cur.one(schema.App)) is None:
|
||||
raise RuntimeError('Failed to update row')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
|
||||
params = {
|
||||
'id': client_id,
|
||||
'secret': client_secret
|
||||
}
|
||||
|
||||
if token is not None:
|
||||
command = 'del-app-with-token'
|
||||
params['token'] = token
|
||||
|
||||
else:
|
||||
command = 'del-app'
|
||||
|
||||
with self.run(command, params) as cur:
|
||||
if cur.row_count > 1:
|
||||
raise RuntimeError('More than 1 row was deleted')
|
||||
|
||||
return cur.row_count == 0
|
||||
|
||||
|
||||
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
|
||||
if domain.startswith('http'):
|
||||
domain = urlparse(domain).netloc
|
||||
|
||||
with self.run('get-domain-ban', {'domain': domain}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.DomainBan)
|
||||
|
||||
|
||||
def get_domain_bans(self) -> Iterator[schema.DomainBan]:
|
||||
return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan)
|
||||
|
||||
|
||||
def put_domain_ban(self,
|
||||
domain: str,
|
||||
reason: str | None = None,
|
||||
note: str | None = None) -> Row:
|
||||
note: str | None = None) -> schema.DomainBan:
|
||||
|
||||
params = {
|
||||
'domain': domain,
|
||||
|
@ -271,13 +428,16 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-domain-ban', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.DomainBan)) is None:
|
||||
raise RuntimeError(f"Failed to insert domain ban: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def update_domain_ban(self,
|
||||
domain: str,
|
||||
reason: str | None = None,
|
||||
note: str | None = None) -> Row:
|
||||
note: str | None = None) -> schema.DomainBan:
|
||||
|
||||
if not (reason or note):
|
||||
raise ValueError('"reason" and/or "note" must be specified')
|
||||
|
@ -297,7 +457,10 @@ class Connection(SqlConnection):
|
|||
if cur.row_count > 1:
|
||||
raise ValueError('More than one row was modified')
|
||||
|
||||
return self.get_domain_ban(domain)
|
||||
if (row := cur.one(schema.DomainBan)) is None:
|
||||
raise RuntimeError(f"Failed to update domain ban: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_domain_ban(self, domain: str) -> bool:
|
||||
|
@ -308,15 +471,19 @@ class Connection(SqlConnection):
|
|||
return cur.row_count == 1
|
||||
|
||||
|
||||
def get_software_ban(self, name: str) -> Row:
|
||||
def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
|
||||
with self.run('get-software-ban', {'name': name}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.SoftwareBan)
|
||||
|
||||
|
||||
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
|
||||
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
|
||||
|
||||
|
||||
def put_software_ban(self,
|
||||
name: str,
|
||||
reason: str | None = None,
|
||||
note: str | None = None) -> Row:
|
||||
note: str | None = None) -> schema.SoftwareBan:
|
||||
|
||||
params = {
|
||||
'name': name,
|
||||
|
@ -326,13 +493,16 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-software-ban', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.SoftwareBan)) is None:
|
||||
raise RuntimeError(f'Failed to insert software ban: {name}')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def update_software_ban(self,
|
||||
name: str,
|
||||
reason: str | None = None,
|
||||
note: str | None = None) -> Row:
|
||||
note: str | None = None) -> schema.SoftwareBan:
|
||||
|
||||
if not (reason or note):
|
||||
raise ValueError('"reason" and/or "note" must be specified')
|
||||
|
@ -352,7 +522,10 @@ class Connection(SqlConnection):
|
|||
if cur.row_count > 1:
|
||||
raise ValueError('More than one row was modified')
|
||||
|
||||
return self.get_software_ban(name)
|
||||
if (row := cur.one(schema.SoftwareBan)) is None:
|
||||
raise RuntimeError(f'Failed to update software ban: {name}')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_software_ban(self, name: str) -> bool:
|
||||
|
@ -363,19 +536,26 @@ class Connection(SqlConnection):
|
|||
return cur.row_count == 1
|
||||
|
||||
|
||||
def get_domain_whitelist(self, domain: str) -> Row:
|
||||
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
|
||||
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one()
|
||||
|
||||
|
||||
def put_domain_whitelist(self, domain: str) -> Row:
|
||||
def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]:
|
||||
return self.execute("SELECT * FROM whitelist").all(schema.Whitelist)
|
||||
|
||||
|
||||
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
|
||||
params = {
|
||||
'domain': domain,
|
||||
'created': datetime.now(tz = timezone.utc)
|
||||
}
|
||||
|
||||
with self.run('put-domain-whitelist', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Whitelist)) is None:
|
||||
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_domain_whitelist(self, domain: str) -> bool:
|
||||
|
|
|
@ -1,61 +1,133 @@
|
|||
from bsql import Column, Table, Tables
|
||||
from __future__ import annotations
|
||||
|
||||
from blib import Date
|
||||
from bsql import Column, Row, Tables
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from datetime import timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .config import ConfigData
|
||||
from .connection import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import Connection
|
||||
|
||||
|
||||
VERSIONS: dict[int, Callable[[Connection], None]] = {}
|
||||
TABLES: Tables = Tables(
|
||||
Table(
|
||||
'config',
|
||||
Column('key', 'text', primary_key = True, unique = True, nullable = False),
|
||||
Column('value', 'text'),
|
||||
Column('type', 'text', default = 'str')
|
||||
),
|
||||
Table(
|
||||
'inboxes',
|
||||
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
|
||||
Column('actor', 'text', unique = True),
|
||||
Column('inbox', 'text', unique = True, nullable = False),
|
||||
Column('followid', 'text'),
|
||||
Column('software', 'text'),
|
||||
Column('accepted', 'boolean'),
|
||||
Column('created', 'timestamp', nullable = False)
|
||||
),
|
||||
Table(
|
||||
'whitelist',
|
||||
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
|
||||
Column('created', 'timestamp')
|
||||
),
|
||||
Table(
|
||||
'domain_bans',
|
||||
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
|
||||
Column('reason', 'text'),
|
||||
Column('note', 'text'),
|
||||
Column('created', 'timestamp', nullable = False)
|
||||
),
|
||||
Table(
|
||||
'software_bans',
|
||||
Column('name', 'text', primary_key = True, unique = True, nullable = True),
|
||||
Column('reason', 'text'),
|
||||
Column('note', 'text'),
|
||||
Column('created', 'timestamp', nullable = False)
|
||||
),
|
||||
Table(
|
||||
'users',
|
||||
Column('username', 'text', primary_key = True, unique = True, nullable = False),
|
||||
Column('hash', 'text', nullable = False),
|
||||
Column('handle', 'text'),
|
||||
Column('created', 'timestamp', nullable = False)
|
||||
),
|
||||
Table(
|
||||
'tokens',
|
||||
Column('code', 'text', primary_key = True, unique = True, nullable = False),
|
||||
Column('user', 'text', nullable = False),
|
||||
Column('created', 'timestmap', nullable = False)
|
||||
)
|
||||
)
|
||||
TABLES = Tables()
|
||||
|
||||
|
||||
def deserialize_timestamp(value: Any) -> Date:
|
||||
try:
|
||||
date = Date.parse(value)
|
||||
|
||||
except ValueError:
|
||||
date = Date.fromisoformat(value)
|
||||
|
||||
if date.tzinfo is None:
|
||||
date = date.replace(tzinfo = timezone.utc)
|
||||
|
||||
return date
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class Config(Row):
|
||||
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
|
||||
value: Column[str] = Column('value', 'text')
|
||||
type: Column[str] = Column('type', 'text', default = 'str')
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class Instance(Row):
|
||||
table_name: str = 'inboxes'
|
||||
|
||||
|
||||
domain: Column[str] = Column(
|
||||
'domain', 'text', primary_key = True, unique = True, nullable = False)
|
||||
actor: Column[str] = Column('actor', 'text', unique = True)
|
||||
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
|
||||
followid: Column[str] = Column('followid', 'text')
|
||||
software: Column[str] = Column('software', 'text')
|
||||
accepted: Column[Date] = Column('accepted', 'boolean')
|
||||
created: Column[Date] = Column(
|
||||
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class Whitelist(Row):
|
||||
domain: Column[str] = Column(
|
||||
'domain', 'text', primary_key = True, unique = True, nullable = True)
|
||||
created: Column[Date] = Column(
|
||||
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class DomainBan(Row):
|
||||
table_name: str = 'domain_bans'
|
||||
|
||||
|
||||
domain: Column[str] = Column(
|
||||
'domain', 'text', primary_key = True, unique = True, nullable = True)
|
||||
reason: Column[str] = Column('reason', 'text')
|
||||
note: Column[str] = Column('note', 'text')
|
||||
created: Column[Date] = Column(
|
||||
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class SoftwareBan(Row):
|
||||
table_name: str = 'software_bans'
|
||||
|
||||
|
||||
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
|
||||
reason: Column[str] = Column('reason', 'text')
|
||||
note: Column[str] = Column('note', 'text')
|
||||
created: Column[Date] = Column(
|
||||
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class User(Row):
|
||||
table_name: str = 'users'
|
||||
|
||||
|
||||
username: Column[str] = Column(
|
||||
'username', 'text', primary_key = True, unique = True, nullable = False)
|
||||
hash: Column[str] = Column('hash', 'text', nullable = False)
|
||||
handle: Column[str] = Column('handle', 'text')
|
||||
created: Column[Date] = Column(
|
||||
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class App(Row):
|
||||
table_name: str = 'apps'
|
||||
|
||||
|
||||
client_id: Column[str] = Column(
|
||||
'client_id', 'text', primary_key = True, unique = True, nullable = False)
|
||||
client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
|
||||
name: Column[str] = Column('name', 'text')
|
||||
website: Column[str] = Column('website', 'text')
|
||||
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
|
||||
token: Column[str | None] = Column('token', 'text')
|
||||
auth_code: Column[str | None] = Column('auth_code', 'text')
|
||||
user: Column[str | None] = Column('user', 'text')
|
||||
created: Column[Date] = Column(
|
||||
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
|
||||
accessed: Column[Date] = Column(
|
||||
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
|
||||
|
||||
|
||||
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
|
||||
data = deepcopy(self)
|
||||
data.pop('user')
|
||||
data.pop('auth_code')
|
||||
|
||||
if not include_token:
|
||||
data.pop('token')
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
|
||||
|
@ -76,5 +148,11 @@ def migrate_20240206(conn: Connection) -> None:
|
|||
|
||||
@migration
|
||||
def migrate_20240310(conn: Connection) -> None:
|
||||
conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN")
|
||||
conn.execute("UPDATE inboxes SET accepted = 1")
|
||||
conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close()
|
||||
conn.execute('UPDATE "inboxes" SET "accepted" = true').close()
|
||||
|
||||
|
||||
@migration
|
||||
def migrate_20240625(conn: Connection) -> None:
|
||||
conn.create_tables()
|
||||
conn.execute('DROP TABLE "tokens"').close()
|
||||
|
|
2
relay/errors.py
Normal file
2
relay/errors.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
class EmptyBodyError(Exception):
|
||||
pass
|
|
@ -1,5 +1,5 @@
|
|||
-macro menu_item(name, path)
|
||||
-if view.request.path == path or (path != "/" and view.request.path.startswith(path))
|
||||
-if request.path == path or (path != "/" and request.path.startswith(path))
|
||||
%a.button(href="{{path}}" active="true") -> =name
|
||||
|
||||
-else
|
||||
|
@ -11,11 +11,11 @@
|
|||
%title << {{config.name}}: {{page}}
|
||||
%meta(charset="UTF-8")
|
||||
%meta(name="viewport" content="width=device-width, initial-scale=1")
|
||||
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme")
|
||||
%link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}")
|
||||
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css" nonce="{{view.request['hash']}}")
|
||||
%link(rel="manifest" href="/manifest.json")
|
||||
%script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer)
|
||||
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme")
|
||||
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}")
|
||||
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}")
|
||||
%link(rel="manifest" href="/manifest.json?{{version}}")
|
||||
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer)
|
||||
-block head
|
||||
|
||||
%body
|
||||
|
@ -26,7 +26,7 @@
|
|||
|
||||
{{menu_item("Home", "/")}}
|
||||
|
||||
-if view.request["user"]
|
||||
-if request["user"]
|
||||
{{menu_item("Instances", "/admin/instances")}}
|
||||
{{menu_item("Whitelist", "/admin/whitelist")}}
|
||||
{{menu_item("Domain Bans", "/admin/domain_bans")}}
|
||||
|
@ -61,11 +61,11 @@
|
|||
|
||||
#footer.section
|
||||
.col1
|
||||
-if not view.request["user"]
|
||||
-if not request["user"]
|
||||
%a(href="/login") << Login
|
||||
|
||||
-else
|
||||
=view.request["user"]["username"]
|
||||
=request["user"]["username"]
|
||||
(
|
||||
%a(href="/logout") << Logout
|
||||
)
|
||||
|
|
|
@ -1,29 +1,32 @@
|
|||
-extends "base.haml"
|
||||
-set page="Config"
|
||||
|
||||
-block head
|
||||
%script(type="application/javascript" src="/static/config.js" nonce="{{view.request['hash']}}" defer)
|
||||
|
||||
-import "functions.haml" as func
|
||||
|
||||
-block content
|
||||
%fieldset.section
|
||||
%legend << Config
|
||||
|
||||
.grid-2col
|
||||
%label(for="name") << Name
|
||||
%i(class="bi bi-question-circle-fill" title="{{desc.name}}")
|
||||
%input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}")
|
||||
|
||||
%label(for="note") << Description
|
||||
%i(class="bi bi-question-circle-fill" title="{{desc.note}}")
|
||||
%textarea(id="note" value="{{config.note or ''}}") << {{config.note}}
|
||||
|
||||
%label(for="theme") << Color Theme
|
||||
%i(class="bi bi-question-circle-fill" title="{{desc.theme}}")
|
||||
=func.new_select("theme", config.theme, themes)
|
||||
|
||||
%label(for="log-level") << Log Level
|
||||
%i(class="bi bi-question-circle-fill" title="{{desc.log_level}}")
|
||||
=func.new_select("log-level", config.log_level.name, levels)
|
||||
|
||||
%label(for="whitelist-enabled") << Whitelist
|
||||
%i(class="bi bi-question-circle-fill" title="{{desc.whitelist_enabled}}")
|
||||
=func.new_checkbox("whitelist-enabled", config.whitelist_enabled)
|
||||
|
||||
%label(for="approval-required") << Approval Required
|
||||
%i(class="bi bi-question-circle-fill" title="{{desc.approval_required}}")
|
||||
=func.new_checkbox("approval-required", config.approval_required)
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
-extends "base.haml"
|
||||
-set page="Domain Bans"
|
||||
|
||||
-block head
|
||||
%script(type="application/javascript" src="/static/domain_ban.js" nonce="{{view.request['hash']}}" defer)
|
||||
|
||||
-block content
|
||||
%details.section
|
||||
%summary << Ban Domain
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
-extends "base.haml"
|
||||
-set page="Instances"
|
||||
|
||||
-block head
|
||||
%script(type="application/javascript" src="/static/instance.js" nonce="{{view.request['hash']}}" defer)
|
||||
|
||||
-block content
|
||||
%details.section
|
||||
%summary << Add Instance
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
-extends "base.haml"
|
||||
-set page="Software Bans"
|
||||
|
||||
-block head
|
||||
%script(type="application/javascript" src="/static/software_ban.js" nonce="{{view.request['hash']}}" defer)
|
||||
|
||||
-block content
|
||||
%details.section
|
||||
%summary << Ban Software
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
-extends "base.haml"
|
||||
-set page="Users"
|
||||
|
||||
-block head
|
||||
%script(type="application/javascript" src="/static/user.js" nonce="{{view.request['hash']}}" defer)
|
||||
|
||||
-block content
|
||||
%details.section
|
||||
%summary << Add User
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
-extends "base.haml"
|
||||
-set page="Whitelist"
|
||||
|
||||
-block head
|
||||
%script(type="application/javascript" src="/static/whitelist.js" nonce="{{view.request['hash']}}" defer)
|
||||
|
||||
-block content
|
||||
%details.section
|
||||
%summary << Add Domain
|
||||
|
|
31
relay/frontend/page/authorize_new.haml
Normal file
31
relay/frontend/page/authorize_new.haml
Normal file
|
@ -0,0 +1,31 @@
|
|||
-extends "base.haml"
|
||||
-set page="App Authorization"
|
||||
|
||||
-block content
|
||||
%fieldset.section
|
||||
%legend << App Authorization
|
||||
|
||||
-if application.website
|
||||
#title << Application "<a href="{{application.website}}" target="_new">{{application.name}}</a>" wants full API access
|
||||
|
||||
-else
|
||||
#title << Application "{{application.name}}" wants full API access
|
||||
|
||||
#buttons
|
||||
.spacer
|
||||
|
||||
%form(action="/oauth/authorize" method="POST")
|
||||
%input(type="hidden" name="client_id" value="{{application.client_id}}")
|
||||
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
|
||||
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
|
||||
%input(type="hidden" name="response" value="true")
|
||||
%input.button(type="submit" value="Allow")
|
||||
|
||||
%form(action="/oauth/authorize" method="POST")
|
||||
%input(type="hidden" name="client_id" value="{{application.client_id}}")
|
||||
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
|
||||
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
|
||||
%input(type="hidden" name="response" value="false")
|
||||
%input.button(type="submit" value="Deny")
|
||||
|
||||
.spacer
|
18
relay/frontend/page/authorize_show.haml
Normal file
18
relay/frontend/page/authorize_show.haml
Normal file
|
@ -0,0 +1,18 @@
|
|||
-extends "base.haml"
|
||||
-set page="App Authorization"
|
||||
|
||||
-block content
|
||||
%fieldset.section
|
||||
%legend << App Authorization Code
|
||||
|
||||
-if application.website
|
||||
%p
|
||||
Copy the following code into
|
||||
%a(href="{{application.website}}" target="_main") -> %code -> =application.name
|
||||
|
||||
-else
|
||||
%p
|
||||
Copy the following code info
|
||||
%code -> =application.name
|
||||
|
||||
%pre#code -> =application.auth_code
|
7
relay/frontend/page/error.haml
Normal file
7
relay/frontend/page/error.haml
Normal file
|
@ -0,0 +1,7 @@
|
|||
-extends "base.haml"
|
||||
-set page="Error"
|
||||
|
||||
-block content
|
||||
.section.error
|
||||
.title << HTTP Error {{e.status}}
|
||||
.body -> =e.message
|
|
@ -1,5 +1,6 @@
|
|||
-extends "base.haml"
|
||||
-set page = "Home"
|
||||
|
||||
-block content
|
||||
-if config.note
|
||||
.section
|
||||
|
@ -14,9 +15,7 @@
|
|||
%a(href="https://{{domain}}/actor") << https://{{domain}}/actor</a>
|
||||
|
||||
-if config.approval_required
|
||||
%fieldset.section.message
|
||||
%legend << Require Approval
|
||||
|
||||
%div.section.message
|
||||
Follow requests require approval. You will need to wait for an admin to accept or deny
|
||||
your request.
|
||||
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
-extends "base.haml"
|
||||
-set page="Login"
|
||||
|
||||
-block head
|
||||
%script(type="application/javascript" src="/static/login.js" nonce="{{view.request['hash']}}" defer)
|
||||
|
||||
-block content
|
||||
%fieldset.section
|
||||
%legend << Login
|
||||
|
@ -15,4 +12,6 @@
|
|||
%label(for="password") << Password
|
||||
%input(id="password" name="password" placeholder="Password" type="password")
|
||||
|
||||
|
||||
%input#redir(type="hidden" name="redir" value="{{redir}}")
|
||||
%input.submit(type="button" value="Login")
|
||||
|
|
|
@ -1,132 +0,0 @@
|
|||
// toast notifications
|
||||
|
||||
const notifications = document.querySelector("#notifications")
|
||||
|
||||
|
||||
function remove_toast(toast) {
|
||||
toast.classList.add("hide");
|
||||
|
||||
if (toast.timeoutId) {
|
||||
clearTimeout(toast.timeoutId);
|
||||
}
|
||||
|
||||
setTimeout(() => toast.remove(), 300);
|
||||
}
|
||||
|
||||
function toast(text, type="error", timeout=5) {
|
||||
const toast = document.createElement("li");
|
||||
toast.className = `section ${type}`
|
||||
toast.innerHTML = `<span class=".text">${text}</span><a href="#">✖</span>`
|
||||
|
||||
toast.querySelector("a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await remove_toast(toast);
|
||||
});
|
||||
|
||||
notifications.appendChild(toast);
|
||||
toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000);
|
||||
}
|
||||
|
||||
|
||||
// menu
|
||||
|
||||
const body = document.getElementById("container")
|
||||
const menu = document.getElementById("menu");
|
||||
const menu_open = document.querySelector("#menu-open i");
|
||||
const menu_close = document.getElementById("menu-close");
|
||||
|
||||
|
||||
function toggle_menu() {
|
||||
let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
|
||||
menu.attributes.visible.nodeValue = new_value;
|
||||
}
|
||||
|
||||
|
||||
menu_open.addEventListener("click", toggle_menu);
|
||||
menu_close.addEventListener("click", toggle_menu);
|
||||
|
||||
body.addEventListener("click", (event) => {
|
||||
if (event.target === menu_open) {
|
||||
return;
|
||||
}
|
||||
|
||||
menu.attributes.visible.nodeValue = "false";
|
||||
});
|
||||
|
||||
for (const elem of document.querySelectorAll("#menu-open div")) {
|
||||
elem.addEventListener("click", toggle_menu);
|
||||
}
|
||||
|
||||
|
||||
// misc
|
||||
|
||||
function get_date_string(date) {
|
||||
var year = date.getUTCFullYear().toString();
|
||||
var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
|
||||
var day = date.getUTCDate().toString().padStart(2, "0");
|
||||
|
||||
return `${year}-${month}-${day}`;
|
||||
}
|
||||
|
||||
|
||||
function append_table_row(table, row_name, row) {
|
||||
var table_row = table.insertRow(-1);
|
||||
table_row.id = row_name;
|
||||
|
||||
index = 0;
|
||||
|
||||
for (var prop in row) {
|
||||
if (Object.prototype.hasOwnProperty.call(row, prop)) {
|
||||
var cell = table_row.insertCell(index);
|
||||
cell.className = prop;
|
||||
cell.innerHTML = row[prop];
|
||||
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
return table_row;
|
||||
}
|
||||
|
||||
|
||||
async function request(method, path, body = null) {
|
||||
var headers = {
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
if (body !== null) {
|
||||
headers["Content-Type"] = "application/json"
|
||||
body = JSON.stringify(body)
|
||||
}
|
||||
|
||||
const response = await fetch("/api/" + path, {
|
||||
method: method,
|
||||
mode: "cors",
|
||||
cache: "no-store",
|
||||
redirect: "follow",
|
||||
body: body,
|
||||
headers: headers
|
||||
});
|
||||
|
||||
const message = await response.json();
|
||||
|
||||
if (Object.hasOwn(message, "error")) {
|
||||
throw new Error(message.error);
|
||||
}
|
||||
|
||||
if (Array.isArray(message)) {
|
||||
message.forEach((msg) => {
|
||||
if (Object.hasOwn(msg, "created")) {
|
||||
msg.created = new Date(msg.created);
|
||||
}
|
||||
});
|
||||
|
||||
} else {
|
||||
if (Object.hasOwn(message, "created")) {
|
||||
console.log(message.created)
|
||||
message.created = new Date(message.created);
|
||||
}
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
const elems = [
|
||||
document.querySelector("#name"),
|
||||
document.querySelector("#note"),
|
||||
document.querySelector("#theme"),
|
||||
document.querySelector("#log-level"),
|
||||
document.querySelector("#whitelist-enabled"),
|
||||
document.querySelector("#approval-required")
|
||||
]
|
||||
|
||||
|
||||
async function handle_config_change(event) {
|
||||
params = {
|
||||
key: event.target.id,
|
||||
value: event.target.type === "checkbox" ? event.target.checked : event.target.value
|
||||
}
|
||||
|
||||
try {
|
||||
await request("POST", "v1/config", params);
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
if (params.key === "name") {
|
||||
document.querySelector("#header .title").innerHTML = params.value;
|
||||
document.querySelector("title").innerHTML = params.value;
|
||||
}
|
||||
|
||||
if (params.key === "theme") {
|
||||
document.querySelector("link.theme").href = `/theme/${params.value}.css`;
|
||||
}
|
||||
|
||||
toast("Updated config", "message");
|
||||
}
|
||||
|
||||
|
||||
for (const elem of elems) {
|
||||
elem.addEventListener("change", handle_config_change);
|
||||
}
|
|
@ -1,123 +0,0 @@
|
|||
function create_ban_object(domain, reason, note) {
|
||||
var text = '<details>\n';
|
||||
text += `<summary>${domain}</summary>\n`;
|
||||
text += '<div class="grid-2col">\n';
|
||||
text += `<label for="${domain}-reason" class="reason">Reason</label>\n`;
|
||||
text += `<textarea id="${domain}-reason" class="reason">${reason}</textarea>\n`;
|
||||
text += `<label for="${domain}-note" class="note">Note</label>\n`;
|
||||
text += `<textarea id="${domain}-note" class="note">${note}</textarea>\n`;
|
||||
text += `<input class="update-ban" type="button" value="Update">`;
|
||||
text += '</details>';
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
|
||||
function add_row_listeners(row) {
|
||||
row.querySelector(".update-ban").addEventListener("click", async (event) => {
|
||||
await update_ban(row.id);
|
||||
});
|
||||
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await unban(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function ban() {
|
||||
var table = document.querySelector("table");
|
||||
var elems = {
|
||||
domain: document.getElementById("new-domain"),
|
||||
reason: document.getElementById("new-reason"),
|
||||
note: document.getElementById("new-note")
|
||||
}
|
||||
|
||||
var values = {
|
||||
domain: elems.domain.value.trim(),
|
||||
reason: elems.reason.value.trim(),
|
||||
note: elems.note.value.trim()
|
||||
}
|
||||
|
||||
if (values.domain === "") {
|
||||
toast("Domain is required");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var ban = await request("POST", "v1/domain_ban", values);
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return
|
||||
}
|
||||
|
||||
var row = append_table_row(document.querySelector("table"), ban.domain, {
|
||||
domain: create_ban_object(ban.domain, ban.reason, ban.note),
|
||||
date: get_date_string(ban.created),
|
||||
remove: `<a href="#" title="Unban domain">✖</a>`
|
||||
});
|
||||
|
||||
add_row_listeners(row);
|
||||
|
||||
elems.domain.value = null;
|
||||
elems.reason.value = null;
|
||||
elems.note.value = null;
|
||||
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Banned domain", "message");
|
||||
}
|
||||
|
||||
|
||||
async function update_ban(domain) {
|
||||
var row = document.getElementById(domain);
|
||||
|
||||
var elems = {
|
||||
"reason": row.querySelector("textarea.reason"),
|
||||
"note": row.querySelector("textarea.note")
|
||||
}
|
||||
|
||||
var values = {
|
||||
"domain": domain,
|
||||
"reason": elems.reason.value,
|
||||
"note": elems.note.value
|
||||
}
|
||||
|
||||
try {
|
||||
await request("PATCH", "v1/domain_ban", values)
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
row.querySelector("details").open = false;
|
||||
toast("Updated baned domain", "message");
|
||||
}
|
||||
|
||||
|
||||
async function unban(domain) {
|
||||
try {
|
||||
await request("DELETE", "v1/domain_ban", {"domain": domain});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(domain).remove();
|
||||
toast("Unbanned domain", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#new-ban").addEventListener("click", async (event) => {
|
||||
await ban();
|
||||
});
|
||||
|
||||
for (var row of document.querySelector("fieldset.section table").rows) {
|
||||
if (!row.querySelector(".update-ban")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_row_listeners(row);
|
||||
}
|
864
relay/frontend/static/functions.js
Normal file
864
relay/frontend/static/functions.js
Normal file
|
@ -0,0 +1,864 @@
|
|||
// toast notifications
|
||||
|
||||
const notifications = document.querySelector("#notifications")
|
||||
|
||||
|
||||
function remove_toast(toast) {
|
||||
toast.classList.add("hide");
|
||||
|
||||
if (toast.timeoutId) {
|
||||
clearTimeout(toast.timeoutId);
|
||||
}
|
||||
|
||||
setTimeout(() => toast.remove(), 300);
|
||||
}
|
||||
|
||||
function toast(text, type="error", timeout=5) {
|
||||
const toast = document.createElement("li");
|
||||
toast.className = `section ${type}`
|
||||
toast.innerHTML = `<span class=".text">${text}</span><a href="#">✖</span>`
|
||||
|
||||
toast.querySelector("a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await remove_toast(toast);
|
||||
});
|
||||
|
||||
notifications.appendChild(toast);
|
||||
toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000);
|
||||
}
|
||||
|
||||
|
||||
// menu
|
||||
|
||||
const body = document.getElementById("container")
|
||||
const menu = document.getElementById("menu");
|
||||
const menu_open = document.querySelector("#menu-open i");
|
||||
const menu_close = document.getElementById("menu-close");
|
||||
|
||||
|
||||
function toggle_menu() {
|
||||
let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
|
||||
menu.attributes.visible.nodeValue = new_value;
|
||||
}
|
||||
|
||||
|
||||
menu_open.addEventListener("click", toggle_menu);
|
||||
menu_close.addEventListener("click", toggle_menu);
|
||||
|
||||
body.addEventListener("click", (event) => {
|
||||
if (event.target === menu_open) {
|
||||
return;
|
||||
}
|
||||
|
||||
menu.attributes.visible.nodeValue = "false";
|
||||
});
|
||||
|
||||
for (const elem of document.querySelectorAll("#menu-open div")) {
|
||||
elem.addEventListener("click", toggle_menu);
|
||||
}
|
||||
|
||||
|
||||
// misc
|
||||
|
||||
function get_date_string(date) {
|
||||
var year = date.getUTCFullYear().toString();
|
||||
var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
|
||||
var day = date.getUTCDate().toString().padStart(2, "0");
|
||||
|
||||
return `${year}-${month}-${day}`;
|
||||
}
|
||||
|
||||
|
||||
function append_table_row(table, row_name, row) {
|
||||
var table_row = table.insertRow(-1);
|
||||
table_row.id = row_name;
|
||||
|
||||
index = 0;
|
||||
|
||||
for (var prop in row) {
|
||||
if (Object.prototype.hasOwnProperty.call(row, prop)) {
|
||||
var cell = table_row.insertCell(index);
|
||||
cell.className = prop;
|
||||
cell.innerHTML = row[prop];
|
||||
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
return table_row;
|
||||
}
|
||||
|
||||
|
||||
async function request(method, path, body = null) {
|
||||
var headers = {
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
if (body !== null) {
|
||||
headers["Content-Type"] = "application/json"
|
||||
body = JSON.stringify(body)
|
||||
}
|
||||
|
||||
const response = await fetch("/api/" + path, {
|
||||
method: method,
|
||||
mode: "cors",
|
||||
cache: "no-store",
|
||||
redirect: "follow",
|
||||
body: body,
|
||||
headers: headers
|
||||
});
|
||||
|
||||
const message = await response.json();
|
||||
|
||||
if (Object.hasOwn(message, "error")) {
|
||||
throw new Error(message.error);
|
||||
}
|
||||
|
||||
if (Array.isArray(message)) {
|
||||
message.forEach((msg) => {
|
||||
if (Object.hasOwn(msg, "created")) {
|
||||
msg.created = new Date(msg.created);
|
||||
}
|
||||
});
|
||||
|
||||
} else {
|
||||
if (Object.hasOwn(message, "created")) {
|
||||
message.created = new Date(message.created);
|
||||
}
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
// page functions
|
||||
|
||||
function page_config() {
|
||||
const elems = [
|
||||
document.querySelector("#name"),
|
||||
document.querySelector("#note"),
|
||||
document.querySelector("#theme"),
|
||||
document.querySelector("#log-level"),
|
||||
document.querySelector("#whitelist-enabled"),
|
||||
document.querySelector("#approval-required")
|
||||
]
|
||||
|
||||
|
||||
async function handle_config_change(event) {
|
||||
params = {
|
||||
key: event.target.id,
|
||||
value: event.target.type === "checkbox" ? event.target.checked : event.target.value
|
||||
}
|
||||
|
||||
try {
|
||||
await request("POST", "v1/config", params);
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
if (params.key === "name") {
|
||||
document.querySelector("#header .title").innerHTML = params.value;
|
||||
document.querySelector("title").innerHTML = params.value;
|
||||
}
|
||||
|
||||
if (params.key === "theme") {
|
||||
document.querySelector("link.theme").href = `/theme/${params.value}.css`;
|
||||
}
|
||||
|
||||
toast("Updated config", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#name").addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13) {
|
||||
await handle_config_change(event);
|
||||
}
|
||||
});
|
||||
|
||||
document.querySelector("#note").addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13 && event.ctrlKey) {
|
||||
await handle_config_change(event);
|
||||
}
|
||||
});
|
||||
|
||||
for (const elem of elems) {
|
||||
elem.addEventListener("change", handle_config_change);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function page_domain_ban() {
|
||||
function create_ban_object(domain, reason, note) {
|
||||
var text = '<details>\n';
|
||||
text += `<summary>${domain}</summary>\n`;
|
||||
text += '<div class="grid-2col">\n';
|
||||
text += `<label for="${domain}-reason" class="reason">Reason</label>\n`;
|
||||
text += `<textarea id="${domain}-reason" class="reason">${reason}</textarea>\n`;
|
||||
text += `<label for="${domain}-note" class="note">Note</label>\n`;
|
||||
text += `<textarea id="${domain}-note" class="note">${note}</textarea>\n`;
|
||||
text += `<input class="update-ban" type="button" value="Update">`;
|
||||
text += '</details>';
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
|
||||
function add_row_listeners(row) {
|
||||
row.querySelector(".update-ban").addEventListener("click", async (event) => {
|
||||
await update_ban(row.id);
|
||||
});
|
||||
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await unban(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function ban() {
|
||||
var table = document.querySelector("table");
|
||||
var elems = {
|
||||
domain: document.getElementById("new-domain"),
|
||||
reason: document.getElementById("new-reason"),
|
||||
note: document.getElementById("new-note")
|
||||
}
|
||||
|
||||
var values = {
|
||||
domain: elems.domain.value.trim(),
|
||||
reason: elems.reason.value.trim(),
|
||||
note: elems.note.value.trim()
|
||||
}
|
||||
|
||||
if (values.domain === "") {
|
||||
toast("Domain is required");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var ban = await request("POST", "v1/domain_ban", values);
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return
|
||||
}
|
||||
|
||||
var row = append_table_row(document.querySelector("table"), ban.domain, {
|
||||
domain: create_ban_object(ban.domain, ban.reason, ban.note),
|
||||
date: get_date_string(ban.created),
|
||||
remove: `<a href="#" title="Unban domain">✖</a>`
|
||||
});
|
||||
|
||||
add_row_listeners(row);
|
||||
|
||||
elems.domain.value = null;
|
||||
elems.reason.value = null;
|
||||
elems.note.value = null;
|
||||
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Banned domain", "message");
|
||||
}
|
||||
|
||||
|
||||
async function update_ban(domain) {
|
||||
var row = document.getElementById(domain);
|
||||
|
||||
var elems = {
|
||||
"reason": row.querySelector("textarea.reason"),
|
||||
"note": row.querySelector("textarea.note")
|
||||
}
|
||||
|
||||
var values = {
|
||||
"domain": domain,
|
||||
"reason": elems.reason.value,
|
||||
"note": elems.note.value
|
||||
}
|
||||
|
||||
try {
|
||||
await request("PATCH", "v1/domain_ban", values)
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
row.querySelector("details").open = false;
|
||||
toast("Updated baned domain", "message");
|
||||
}
|
||||
|
||||
|
||||
async function unban(domain) {
|
||||
try {
|
||||
await request("DELETE", "v1/domain_ban", {"domain": domain});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(domain).remove();
|
||||
toast("Unbanned domain", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#new-ban").addEventListener("click", async (event) => {
|
||||
await ban();
|
||||
});
|
||||
|
||||
for (var elem of document.querySelectorAll("#add-item input")) {
|
||||
elem.addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13) {
|
||||
await ban();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (var row of document.querySelector("fieldset.section table").rows) {
|
||||
if (!row.querySelector(".update-ban")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_row_listeners(row);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function page_instance() {
|
||||
function add_instance_listeners(row) {
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await del_instance(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
function add_request_listeners(row) {
|
||||
row.querySelector(".approve a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await req_response(row.id, true);
|
||||
});
|
||||
|
||||
row.querySelector(".deny a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await req_response(row.id, false);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function add_instance() {
|
||||
var elems = {
|
||||
actor: document.getElementById("new-actor"),
|
||||
inbox: document.getElementById("new-inbox"),
|
||||
followid: document.getElementById("new-followid"),
|
||||
software: document.getElementById("new-software")
|
||||
}
|
||||
|
||||
var values = {
|
||||
actor: elems.actor.value.trim(),
|
||||
inbox: elems.inbox.value.trim(),
|
||||
followid: elems.followid.value.trim(),
|
||||
software: elems.software.value.trim()
|
||||
}
|
||||
|
||||
if (values.actor === "") {
|
||||
toast("Actor is required");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var instance = await request("POST", "v1/instance", values);
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return
|
||||
}
|
||||
|
||||
row = append_table_row(document.getElementById("instances"), instance.domain, {
|
||||
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
|
||||
software: instance.software,
|
||||
date: get_date_string(instance.created),
|
||||
remove: `<a href="#" title="Remove Instance">✖</a>`
|
||||
});
|
||||
|
||||
add_instance_listeners(row);
|
||||
|
||||
elems.actor.value = null;
|
||||
elems.inbox.value = null;
|
||||
elems.followid.value = null;
|
||||
elems.software.value = null;
|
||||
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Added instance", "message");
|
||||
}
|
||||
|
||||
|
||||
async function del_instance(domain) {
|
||||
try {
|
||||
await request("DELETE", "v1/instance", {"domain": domain});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(domain).remove();
|
||||
}
|
||||
|
||||
|
||||
async function req_response(domain, accept) {
|
||||
params = {
|
||||
"domain": domain,
|
||||
"accept": accept
|
||||
}
|
||||
|
||||
try {
|
||||
await request("POST", "v1/request", params);
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(domain).remove();
|
||||
|
||||
if (document.getElementById("requests").rows.length < 2) {
|
||||
document.querySelector("fieldset.requests").remove()
|
||||
}
|
||||
|
||||
if (!accept) {
|
||||
toast("Denied instance request", "message");
|
||||
return;
|
||||
}
|
||||
|
||||
instances = await request("GET", `v1/instance`, null);
|
||||
instances.forEach((instance) => {
|
||||
if (instance.domain === domain) {
|
||||
row = append_table_row(document.getElementById("instances"), instance.domain, {
|
||||
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
|
||||
software: instance.software,
|
||||
date: get_date_string(instance.created),
|
||||
remove: `<a href="#" title="Remove Instance">✖</a>`
|
||||
});
|
||||
|
||||
add_instance_listeners(row);
|
||||
}
|
||||
});
|
||||
|
||||
toast("Accepted instance request", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#add-instance").addEventListener("click", async (event) => {
|
||||
await add_instance();
|
||||
})
|
||||
|
||||
for (var elem of document.querySelectorAll("#add-item input")) {
|
||||
elem.addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13) {
|
||||
await add_instance();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (var row of document.querySelector("#instances").rows) {
|
||||
if (!row.querySelector(".remove a")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_instance_listeners(row);
|
||||
}
|
||||
|
||||
if (document.querySelector("#requests")) {
|
||||
for (var row of document.querySelector("#requests").rows) {
|
||||
if (!row.querySelector(".approve a")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_request_listeners(row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function page_login() {
|
||||
const fields = {
|
||||
username: document.querySelector("#username"),
|
||||
password: document.querySelector("#password"),
|
||||
redir: document.querySelector("#redir")
|
||||
};
|
||||
|
||||
async function login(event) {
|
||||
const values = {
|
||||
username: fields.username.value.trim(),
|
||||
password: fields.password.value.trim(),
|
||||
redir: fields.redir.value.trim()
|
||||
}
|
||||
|
||||
if (values.username === "" | values.password === "") {
|
||||
toast("Username and/or password field is blank");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await request("POST", "v1/login", values);
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.location = values.redir;
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#username").addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13) {
|
||||
fields.password.focus();
|
||||
fields.password.select();
|
||||
}
|
||||
});
|
||||
|
||||
document.querySelector("#password").addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13) {
|
||||
await login(event);
|
||||
}
|
||||
});
|
||||
|
||||
document.querySelector(".submit").addEventListener("click", login);
|
||||
}
|
||||
|
||||
|
||||
function page_software_ban() {
|
||||
function create_ban_object(name, reason, note) {
|
||||
var text = '<details>\n';
|
||||
text += `<summary>${name}</summary>\n`;
|
||||
text += '<div class="grid-2col">\n';
|
||||
text += `<label for="${name}-reason" class="reason">Reason</label>\n`;
|
||||
text += `<textarea id="${name}-reason" class="reason">${reason}</textarea>\n`;
|
||||
text += `<label for="${name}-note" class="note">Note</label>\n`;
|
||||
text += `<textarea id="${name}-note" class="note">${note}</textarea>\n`;
|
||||
text += `<input class="update-ban" type="button" value="Update">`;
|
||||
text += '</details>';
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
|
||||
function add_row_listeners(row) {
|
||||
row.querySelector(".update-ban").addEventListener("click", async (event) => {
|
||||
await update_ban(row.id);
|
||||
});
|
||||
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await unban(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function ban() {
|
||||
var elems = {
|
||||
name: document.getElementById("new-name"),
|
||||
reason: document.getElementById("new-reason"),
|
||||
note: document.getElementById("new-note")
|
||||
}
|
||||
|
||||
var values = {
|
||||
name: elems.name.value.trim(),
|
||||
reason: elems.reason.value,
|
||||
note: elems.note.value
|
||||
}
|
||||
|
||||
if (values.name === "") {
|
||||
toast("Domain is required");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var ban = await request("POST", "v1/software_ban", values);
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return
|
||||
}
|
||||
|
||||
var row = append_table_row(document.getElementById("bans"), ban.name, {
|
||||
name: create_ban_object(ban.name, ban.reason, ban.note),
|
||||
date: get_date_string(ban.created),
|
||||
remove: `<a href="#" title="Unban software">✖</a>`
|
||||
});
|
||||
|
||||
add_row_listeners(row);
|
||||
|
||||
elems.name.value = null;
|
||||
elems.reason.value = null;
|
||||
elems.note.value = null;
|
||||
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Banned software", "message");
|
||||
}
|
||||
|
||||
|
||||
async function update_ban(name) {
|
||||
var row = document.getElementById(name);
|
||||
|
||||
var elems = {
|
||||
"reason": row.querySelector("textarea.reason"),
|
||||
"note": row.querySelector("textarea.note")
|
||||
}
|
||||
|
||||
var values = {
|
||||
"name": name,
|
||||
"reason": elems.reason.value,
|
||||
"note": elems.note.value
|
||||
}
|
||||
|
||||
try {
|
||||
await request("PATCH", "v1/software_ban", values)
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
row.querySelector("details").open = false;
|
||||
toast("Updated software ban", "message");
|
||||
}
|
||||
|
||||
|
||||
async function unban(name) {
|
||||
try {
|
||||
await request("DELETE", "v1/software_ban", {"name": name});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(name).remove();
|
||||
toast("Unbanned software", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#new-ban").addEventListener("click", async (event) => {
|
||||
await ban();
|
||||
});
|
||||
|
||||
for (var elem of document.querySelectorAll("#add-item input")) {
|
||||
elem.addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13) {
|
||||
await ban();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (var elem of document.querySelectorAll("#add-item textarea")) {
|
||||
elem.addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13 && event.ctrlKey) {
|
||||
await ban();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (var row of document.querySelector("#bans").rows) {
|
||||
if (!row.querySelector(".update-ban")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_row_listeners(row);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function page_user() {
|
||||
function add_row_listeners(row) {
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await del_user(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function add_user() {
|
||||
var elems = {
|
||||
username: document.getElementById("new-username"),
|
||||
password: document.getElementById("new-password"),
|
||||
password2: document.getElementById("new-password2"),
|
||||
handle: document.getElementById("new-handle")
|
||||
}
|
||||
|
||||
var values = {
|
||||
username: elems.username.value.trim(),
|
||||
password: elems.password.value.trim(),
|
||||
password2: elems.password2.value.trim(),
|
||||
handle: elems.handle.value.trim()
|
||||
}
|
||||
|
||||
if (values.username === "" | values.password === "" | values.password2 === "") {
|
||||
toast("Username, password, and password2 are required");
|
||||
return;
|
||||
}
|
||||
|
||||
if (values.password !== values.password2) {
|
||||
toast("Passwords do not match");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var user = await request("POST", "v1/user", values);
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return
|
||||
}
|
||||
|
||||
var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
|
||||
domain: user.username,
|
||||
handle: user.handle ? self.handle : "n/a",
|
||||
date: get_date_string(user.created),
|
||||
remove: `<a href="#" title="Delete User">✖</a>`
|
||||
});
|
||||
|
||||
add_row_listeners(row);
|
||||
|
||||
elems.username.value = null;
|
||||
elems.password.value = null;
|
||||
elems.password2.value = null;
|
||||
elems.handle.value = null;
|
||||
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Created user", "message");
|
||||
}
|
||||
|
||||
|
||||
async function del_user(username) {
|
||||
try {
|
||||
await request("DELETE", "v1/user", {"username": username});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(username).remove();
|
||||
toast("Deleted user", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#new-user").addEventListener("click", async (event) => {
|
||||
await add_user();
|
||||
});
|
||||
|
||||
for (var elem of document.querySelectorAll("#add-item input")) {
|
||||
elem.addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13) {
|
||||
await add_user();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (var row of document.querySelector("#users").rows) {
|
||||
if (!row.querySelector(".remove a")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_row_listeners(row);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function page_whitelist() {
|
||||
function add_row_listeners(row) {
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await del_whitelist(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function add_whitelist() {
|
||||
var domain_elem = document.getElementById("new-domain");
|
||||
var domain = domain_elem.value.trim();
|
||||
|
||||
if (domain === "") {
|
||||
toast("Domain is required");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var item = await request("POST", "v1/whitelist", {"domain": domain});
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return;
|
||||
}
|
||||
|
||||
var row = append_table_row(document.getElementById("whitelist"), item.domain, {
|
||||
domain: item.domain,
|
||||
date: get_date_string(item.created),
|
||||
remove: `<a href="#" title="Remove whitelisted domain">✖</a>`
|
||||
});
|
||||
|
||||
add_row_listeners(row);
|
||||
|
||||
domain_elem.value = null;
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Added domain", "message");
|
||||
}
|
||||
|
||||
|
||||
async function del_whitelist(domain) {
|
||||
try {
|
||||
await request("DELETE", "v1/whitelist", {"domain": domain});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(domain).remove();
|
||||
toast("Removed domain", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#new-item").addEventListener("click", async (event) => {
|
||||
await add_whitelist();
|
||||
});
|
||||
|
||||
document.querySelector("#add-item").addEventListener("keydown", async (event) => {
|
||||
if (event.which === 13) {
|
||||
await add_whitelist();
|
||||
}
|
||||
});
|
||||
|
||||
for (var row of document.querySelector("fieldset.section table").rows) {
|
||||
if (!row.querySelector(".remove a")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_row_listeners(row);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (location.pathname.startsWith("/admin/config")) {
|
||||
page_config();
|
||||
|
||||
} else if (location.pathname.startsWith("/admin/domain_bans")) {
|
||||
page_domain_ban();
|
||||
|
||||
} else if (location.pathname.startsWith("/admin/instances")) {
|
||||
page_instance();
|
||||
|
||||
} else if (location.pathname.startsWith("/admin/software_bans")) {
|
||||
page_software_ban();
|
||||
|
||||
} else if (location.pathname.startsWith("/admin/users")) {
|
||||
page_user();
|
||||
|
||||
} else if (location.pathname.startsWith("/admin/whitelist")) {
|
||||
page_whitelist();
|
||||
|
||||
} else if (location.pathname.startsWith("/login")) {
|
||||
page_login();
|
||||
}
|
|
@ -1,145 +0,0 @@
|
|||
function add_instance_listeners(row) {
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await del_instance(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
function add_request_listeners(row) {
|
||||
row.querySelector(".approve a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await req_response(row.id, true);
|
||||
});
|
||||
|
||||
row.querySelector(".deny a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await req_response(row.id, false);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function add_instance() {
|
||||
var elems = {
|
||||
actor: document.getElementById("new-actor"),
|
||||
inbox: document.getElementById("new-inbox"),
|
||||
followid: document.getElementById("new-followid"),
|
||||
software: document.getElementById("new-software")
|
||||
}
|
||||
|
||||
var values = {
|
||||
actor: elems.actor.value.trim(),
|
||||
inbox: elems.inbox.value.trim(),
|
||||
followid: elems.followid.value.trim(),
|
||||
software: elems.software.value.trim()
|
||||
}
|
||||
|
||||
if (values.actor === "") {
|
||||
toast("Actor is required");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var instance = await request("POST", "v1/instance", values);
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return
|
||||
}
|
||||
|
||||
row = append_table_row(document.getElementById("instances"), instance.domain, {
|
||||
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
|
||||
software: instance.software,
|
||||
date: get_date_string(instance.created),
|
||||
remove: `<a href="#" title="Remove Instance">✖</a>`
|
||||
});
|
||||
|
||||
add_instance_listeners(row);
|
||||
|
||||
elems.actor.value = null;
|
||||
elems.inbox.value = null;
|
||||
elems.followid.value = null;
|
||||
elems.software.value = null;
|
||||
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Added instance", "message");
|
||||
}
|
||||
|
||||
|
||||
async function del_instance(domain) {
|
||||
try {
|
||||
await request("DELETE", "v1/instance", {"domain": domain});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(domain).remove();
|
||||
}
|
||||
|
||||
|
||||
async function req_response(domain, accept) {
|
||||
params = {
|
||||
"domain": domain,
|
||||
"accept": accept
|
||||
}
|
||||
|
||||
try {
|
||||
await request("POST", "v1/request", params);
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(domain).remove();
|
||||
|
||||
if (document.getElementById("requests").rows.length < 2) {
|
||||
document.querySelector("fieldset.requests").remove()
|
||||
}
|
||||
|
||||
if (!accept) {
|
||||
toast("Denied instance request", "message");
|
||||
return;
|
||||
}
|
||||
|
||||
instances = await request("GET", `v1/instance`, null);
|
||||
instances.forEach((instance) => {
|
||||
if (instance.domain === domain) {
|
||||
row = append_table_row(document.getElementById("instances"), instance.domain, {
|
||||
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
|
||||
software: instance.software,
|
||||
date: get_date_string(instance.created),
|
||||
remove: `<a href="#" title="Remove Instance">✖</a>`
|
||||
});
|
||||
|
||||
add_instance_listeners(row);
|
||||
}
|
||||
});
|
||||
|
||||
toast("Accepted instance request", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#add-instance").addEventListener("click", async (event) => {
|
||||
await add_instance();
|
||||
})
|
||||
|
||||
for (var row of document.querySelector("#instances").rows) {
|
||||
if (!row.querySelector(".remove a")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_instance_listeners(row);
|
||||
}
|
||||
|
||||
if (document.querySelector("#requests")) {
|
||||
for (var row of document.querySelector("#requests").rows) {
|
||||
if (!row.querySelector(".approve a")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_request_listeners(row);
|
||||
}
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
async function login(event) {
|
||||
fields = {
|
||||
username: document.querySelector("#username"),
|
||||
password: document.querySelector("#password")
|
||||
}
|
||||
|
||||
values = {
|
||||
username: fields.username.value.trim(),
|
||||
password: fields.password.value.trim()
|
||||
}
|
||||
|
||||
if (values.username === "" | values.password === "") {
|
||||
toast("Username and/or password field is blank");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await request("POST", "v1/token", values);
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.location = "/";
|
||||
}
|
||||
|
||||
|
||||
document.querySelector(".submit").addEventListener("click", login);
|
|
@ -1,122 +0,0 @@
|
|||
function create_ban_object(name, reason, note) {
|
||||
var text = '<details>\n';
|
||||
text += `<summary>${name}</summary>\n`;
|
||||
text += '<div class="grid-2col">\n';
|
||||
text += `<label for="${name}-reason" class="reason">Reason</label>\n`;
|
||||
text += `<textarea id="${name}-reason" class="reason">${reason}</textarea>\n`;
|
||||
text += `<label for="${name}-note" class="note">Note</label>\n`;
|
||||
text += `<textarea id="${name}-note" class="note">${note}</textarea>\n`;
|
||||
text += `<input class="update-ban" type="button" value="Update">`;
|
||||
text += '</details>';
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
|
||||
function add_row_listeners(row) {
|
||||
row.querySelector(".update-ban").addEventListener("click", async (event) => {
|
||||
await update_ban(row.id);
|
||||
});
|
||||
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await unban(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function ban() {
|
||||
var elems = {
|
||||
name: document.getElementById("new-name"),
|
||||
reason: document.getElementById("new-reason"),
|
||||
note: document.getElementById("new-note")
|
||||
}
|
||||
|
||||
var values = {
|
||||
name: elems.name.value.trim(),
|
||||
reason: elems.reason.value,
|
||||
note: elems.note.value
|
||||
}
|
||||
|
||||
if (values.name === "") {
|
||||
toast("Domain is required");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var ban = await request("POST", "v1/software_ban", values);
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return
|
||||
}
|
||||
|
||||
var row = append_table_row(document.getElementById("bans"), ban.name, {
|
||||
name: create_ban_object(ban.name, ban.reason, ban.note),
|
||||
date: get_date_string(ban.created),
|
||||
remove: `<a href="#" title="Unban software">✖</a>`
|
||||
});
|
||||
|
||||
add_row_listeners(row);
|
||||
|
||||
elems.name.value = null;
|
||||
elems.reason.value = null;
|
||||
elems.note.value = null;
|
||||
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Banned software", "message");
|
||||
}
|
||||
|
||||
|
||||
async function update_ban(name) {
|
||||
var row = document.getElementById(name);
|
||||
|
||||
var elems = {
|
||||
"reason": row.querySelector("textarea.reason"),
|
||||
"note": row.querySelector("textarea.note")
|
||||
}
|
||||
|
||||
var values = {
|
||||
"name": name,
|
||||
"reason": elems.reason.value,
|
||||
"note": elems.note.value
|
||||
}
|
||||
|
||||
try {
|
||||
await request("PATCH", "v1/software_ban", values)
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
row.querySelector("details").open = false;
|
||||
toast("Updated software ban", "message");
|
||||
}
|
||||
|
||||
|
||||
async function unban(name) {
|
||||
try {
|
||||
await request("DELETE", "v1/software_ban", {"name": name});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(name).remove();
|
||||
toast("Unbanned software", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#new-ban").addEventListener("click", async (event) => {
|
||||
await ban();
|
||||
});
|
||||
|
||||
for (var row of document.querySelector("#bans").rows) {
|
||||
if (!row.querySelector(".update-ban")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_row_listeners(row);
|
||||
}
|
|
@ -12,7 +12,7 @@ body {
|
|||
color: var(--text);
|
||||
background-color: #222;
|
||||
margin: var(--spacing);
|
||||
font-family: sans serif;
|
||||
font-family: sans-serif;
|
||||
}
|
||||
|
||||
details *:nth-child(2) {
|
||||
|
@ -88,6 +88,7 @@ tbody tr:last-child td:last-child {
|
|||
|
||||
table td {
|
||||
padding: 5px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
table thead td {
|
||||
|
@ -282,8 +283,11 @@ textarea {
|
|||
width: 100%;
|
||||
}
|
||||
|
||||
.data-table .date {
|
||||
.data-table td:not(:first-child) {
|
||||
width: max-content;
|
||||
}
|
||||
|
||||
.data-table .date {
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
|
@ -297,13 +301,13 @@ textarea {
|
|||
border: 1px solid var(--error-border) !important;
|
||||
}
|
||||
|
||||
/* create .grid base class and .2col and 3col classes */
|
||||
.grid-2col {
|
||||
display: grid;
|
||||
grid-template-columns: max-content auto;
|
||||
grid-gap: var(--spacing);
|
||||
margin-bottom: var(--spacing);
|
||||
align-items: center;
|
||||
|
||||
}
|
||||
|
||||
.message {
|
||||
|
@ -333,6 +337,48 @@ textarea {
|
|||
justify-self: left;
|
||||
}
|
||||
|
||||
#content.page-config .grid-2col {
|
||||
grid-template-columns: max-content max-content auto;
|
||||
}
|
||||
|
||||
|
||||
/* error */
|
||||
#content.page-error {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
#content.page-error .title {
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
|
||||
/* auth */
|
||||
#content.page-app_authorization {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
#content.page-app_authorization #code {
|
||||
background: var(--background);
|
||||
border: 1px solid var(--border);
|
||||
font-size: 18px;
|
||||
margin: 0 auto;
|
||||
width: max-content;
|
||||
padding: 5px;
|
||||
}
|
||||
|
||||
#content.page-app_authorization #title {
|
||||
font-size: 24px;
|
||||
}
|
||||
|
||||
#content.page-app_authorization #buttons {
|
||||
display: grid;
|
||||
grid-template-columns: auto max-content max-content auto;
|
||||
grid-gap: var(--spacing);
|
||||
justify-items: center;
|
||||
margin: var(--spacing) 0;
|
||||
}
|
||||
|
||||
|
||||
@keyframes show_toast {
|
||||
0% {
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
function add_row_listeners(row) {
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await del_user(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function add_user() {
|
||||
var elems = {
|
||||
username: document.getElementById("new-username"),
|
||||
password: document.getElementById("new-password"),
|
||||
password2: document.getElementById("new-password2"),
|
||||
handle: document.getElementById("new-handle")
|
||||
}
|
||||
|
||||
var values = {
|
||||
username: elems.username.value.trim(),
|
||||
password: elems.password.value.trim(),
|
||||
password2: elems.password2.value.trim(),
|
||||
handle: elems.handle.value.trim()
|
||||
}
|
||||
|
||||
if (values.username === "" | values.password === "" | values.password2 === "") {
|
||||
toast("Username, password, and password2 are required");
|
||||
return;
|
||||
}
|
||||
|
||||
if (values.password !== values.password2) {
|
||||
toast("Passwords do not match");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var user = await request("POST", "v1/user", values);
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return
|
||||
}
|
||||
|
||||
var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
|
||||
domain: user.username,
|
||||
handle: user.handle ? self.handle : "n/a",
|
||||
date: get_date_string(user.created),
|
||||
remove: `<a href="#" title="Delete User">✖</a>`
|
||||
});
|
||||
|
||||
add_row_listeners(row);
|
||||
|
||||
elems.username.value = null;
|
||||
elems.password.value = null;
|
||||
elems.password2.value = null;
|
||||
elems.handle.value = null;
|
||||
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Created user", "message");
|
||||
}
|
||||
|
||||
|
||||
async function del_user(username) {
|
||||
try {
|
||||
await request("DELETE", "v1/user", {"username": username});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(username).remove();
|
||||
toast("Deleted user", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#new-user").addEventListener("click", async (event) => {
|
||||
await add_user();
|
||||
});
|
||||
|
||||
for (var row of document.querySelector("#users").rows) {
|
||||
if (!row.querySelector(".remove a")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_row_listeners(row);
|
||||
}
|
|
@ -1,64 +0,0 @@
|
|||
function add_row_listeners(row) {
|
||||
row.querySelector(".remove a").addEventListener("click", async (event) => {
|
||||
event.preventDefault();
|
||||
await del_whitelist(row.id);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
async function add_whitelist() {
|
||||
var domain_elem = document.getElementById("new-domain");
|
||||
var domain = domain_elem.value.trim();
|
||||
|
||||
if (domain === "") {
|
||||
toast("Domain is required");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var item = await request("POST", "v1/whitelist", {"domain": domain});
|
||||
|
||||
} catch (err) {
|
||||
toast(err);
|
||||
return;
|
||||
}
|
||||
|
||||
var row = append_table_row(document.getElementById("whitelist"), item.domain, {
|
||||
domain: item.domain,
|
||||
date: get_date_string(item.created),
|
||||
remove: `<a href="#" title="Remove whitelisted domain">✖</a>`
|
||||
});
|
||||
|
||||
add_row_listeners(row);
|
||||
|
||||
domain_elem.value = null;
|
||||
document.querySelector("details.section").open = false;
|
||||
toast("Added domain", "message");
|
||||
}
|
||||
|
||||
|
||||
async function del_whitelist(domain) {
|
||||
try {
|
||||
await request("DELETE", "v1/whitelist", {"domain": domain});
|
||||
|
||||
} catch (error) {
|
||||
toast(error);
|
||||
return;
|
||||
}
|
||||
|
||||
document.getElementById(domain).remove();
|
||||
toast("Removed domain", "message");
|
||||
}
|
||||
|
||||
|
||||
document.querySelector("#new-item").addEventListener("click", async (event) => {
|
||||
await add_whitelist();
|
||||
});
|
||||
|
||||
for (var row of document.querySelector("fieldset.section table").rows) {
|
||||
if (!row.querySelector(".remove a")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
add_row_listeners(row);
|
||||
}
|
|
@ -1,20 +1,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
|
||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||
from blib import JsonBase
|
||||
from bsql import Row
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
from blib import HttpError, JsonBase
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
|
||||
from . import __version__, logger as logging
|
||||
from .cache import Cache
|
||||
from .database.schema import Instance
|
||||
from .errors import EmptyBodyError
|
||||
from .misc import MIMETYPES, Message, get_app
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -36,7 +32,7 @@ SUPPORTS_HS2019 = {
|
|||
'sharkey'
|
||||
}
|
||||
|
||||
T = TypeVar('T', bound = JsonBase)
|
||||
T = TypeVar('T', bound = JsonBase[Any])
|
||||
HEADERS = {
|
||||
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
|
||||
'User-Agent': f'ActivityRelay/{__version__}'
|
||||
|
@ -107,21 +103,17 @@ class HttpClient:
|
|||
url: str,
|
||||
sign_headers: bool,
|
||||
force: bool,
|
||||
old_algo: bool) -> dict[str, Any] | None:
|
||||
old_algo: bool) -> str | None:
|
||||
|
||||
if not self._session:
|
||||
raise RuntimeError('Client not open')
|
||||
|
||||
try:
|
||||
url, _ = url.split('#', 1)
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
url = url.split("#", 1)[0]
|
||||
|
||||
if not force:
|
||||
try:
|
||||
if not (item := self.cache.get('request', url)).older_than(48):
|
||||
return json.loads(item.value) # type: ignore[no-any-return]
|
||||
return item.value # type: ignore [no-any-return]
|
||||
|
||||
except KeyError:
|
||||
logging.verbose('No cached data for url: %s', url)
|
||||
|
@ -132,67 +124,74 @@ class HttpClient:
|
|||
algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
|
||||
headers = self.signer.sign_headers('GET', url, algorithm = algo)
|
||||
|
||||
try:
|
||||
logging.debug('Fetching resource: %s', url)
|
||||
logging.debug('Fetching resource: %s', url)
|
||||
|
||||
async with self._session.get(url, headers = headers) as resp:
|
||||
# Not expecting a response with 202s, so just return
|
||||
if resp.status == 202:
|
||||
return None
|
||||
|
||||
data = await resp.text()
|
||||
|
||||
if resp.status != 200:
|
||||
logging.verbose('Received error when requesting %s: %i', url, resp.status)
|
||||
logging.debug(data)
|
||||
async with self._session.get(url, headers = headers) as resp:
|
||||
# Not expecting a response with 202s, so just return
|
||||
if resp.status == 202:
|
||||
return None
|
||||
|
||||
self.cache.set('request', url, data, 'str')
|
||||
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
|
||||
data = await resp.text()
|
||||
|
||||
return json.loads(data) # type: ignore [no-any-return]
|
||||
if resp.status not in (200, 202):
|
||||
try:
|
||||
error = json.loads(data)["error"]
|
||||
|
||||
except JSONDecodeError:
|
||||
logging.verbose('Failed to parse JSON')
|
||||
logging.debug(data)
|
||||
return None
|
||||
except Exception:
|
||||
error = data
|
||||
|
||||
except ClientSSLError as e:
|
||||
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
|
||||
logging.warning(str(e))
|
||||
raise HttpError(resp.status, error)
|
||||
|
||||
except (AsyncTimeoutError, ClientConnectionError) as e:
|
||||
logging.verbose('Failed to connect to %s', urlparse(url).netloc)
|
||||
logging.warning(str(e))
|
||||
self.cache.set('request', url, data, 'str')
|
||||
return data
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
return None
|
||||
@overload
|
||||
async def get(self,
|
||||
url: str,
|
||||
sign_headers: bool,
|
||||
cls: None = None,
|
||||
force: bool = False,
|
||||
old_algo: bool = True) -> str | None: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def get(self,
|
||||
url: str,
|
||||
sign_headers: bool,
|
||||
cls: type[T] = JsonBase, # type: ignore[assignment]
|
||||
force: bool = False,
|
||||
old_algo: bool = True) -> T: ...
|
||||
|
||||
|
||||
async def get(self,
|
||||
url: str,
|
||||
sign_headers: bool,
|
||||
cls: type[T],
|
||||
cls: type[T] | None = None,
|
||||
force: bool = False,
|
||||
old_algo: bool = True) -> T | None:
|
||||
old_algo: bool = True) -> T | str | None:
|
||||
|
||||
if not issubclass(cls, JsonBase):
|
||||
if cls is not None and not issubclass(cls, JsonBase):
|
||||
raise TypeError('cls must be a sub-class of "blib.JsonBase"')
|
||||
|
||||
if (data := (await self._get(url, sign_headers, force, old_algo))) is None:
|
||||
return None
|
||||
data = await self._get(url, sign_headers, force, old_algo)
|
||||
|
||||
return cls.parse(data)
|
||||
if cls is not None:
|
||||
if data is None:
|
||||
# this shouldn't actually get raised, but keeping just in case
|
||||
raise EmptyBodyError(f"GET {url}")
|
||||
|
||||
return cls.parse(data)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
|
||||
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
|
||||
if not self._session:
|
||||
raise RuntimeError('Client not open')
|
||||
|
||||
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
|
||||
if instance and instance['software'] in SUPPORTS_HS2019:
|
||||
if instance is not None and instance.software in SUPPORTS_HS2019:
|
||||
algorithm = AlgorithmType.HS2019
|
||||
|
||||
else:
|
||||
|
@ -218,46 +217,23 @@ class HttpClient:
|
|||
algorithm = algorithm
|
||||
)
|
||||
|
||||
try:
|
||||
logging.verbose('Sending "%s" to %s', mtype, url)
|
||||
logging.verbose('Sending "%s" to %s', mtype, url)
|
||||
|
||||
async with self._session.post(url, headers = headers, data = body) as resp:
|
||||
# Not expecting a response, so just return
|
||||
if resp.status in {200, 202}:
|
||||
logging.verbose('Successfully sent "%s" to %s', mtype, url)
|
||||
return
|
||||
|
||||
logging.verbose('Received error when pushing to %s: %i', url, resp.status)
|
||||
logging.debug(await resp.read())
|
||||
logging.debug("message: %s", body.decode("utf-8"))
|
||||
logging.debug("headers: %s", json.dumps(headers, indent = 4))
|
||||
return
|
||||
|
||||
except ClientSSLError as e:
|
||||
logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
|
||||
logging.warning(str(e))
|
||||
|
||||
except (AsyncTimeoutError, ClientConnectionError) as e:
|
||||
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
|
||||
logging.warning(str(e))
|
||||
|
||||
# prevent workers from being brought down
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
async with self._session.post(url, headers = headers, data = body) as resp:
|
||||
if resp.status not in (200, 202):
|
||||
raise HttpError(
|
||||
resp.status,
|
||||
await resp.text(),
|
||||
headers = {k: v for k, v in resp.headers.items()}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None:
|
||||
async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:
|
||||
nodeinfo_url = None
|
||||
wk_nodeinfo = await self.get(
|
||||
f'https://{domain}/.well-known/nodeinfo',
|
||||
False,
|
||||
WellKnownNodeinfo
|
||||
f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force
|
||||
)
|
||||
|
||||
if wk_nodeinfo is None:
|
||||
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
|
||||
return None
|
||||
|
||||
for version in ('20', '21'):
|
||||
try:
|
||||
nodeinfo_url = wk_nodeinfo.get_url(version)
|
||||
|
@ -266,10 +242,9 @@ class HttpClient:
|
|||
pass
|
||||
|
||||
if nodeinfo_url is None:
|
||||
logging.verbose('Failed to fetch nodeinfo url for %s', domain)
|
||||
return None
|
||||
raise ValueError(f'Failed to fetch nodeinfo url for {domain}')
|
||||
|
||||
return await self.get(nodeinfo_url, False, Nodeinfo)
|
||||
return await self.get(nodeinfo_url, False, Nodeinfo, force)
|
||||
|
||||
|
||||
async def get(*args: Any, **kwargs: Any) -> Any:
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class LoggingMethod(Protocol):
|
||||
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
|
||||
|
|
268
relay/manage.py
268
relay/manage.py
|
@ -6,7 +6,6 @@ import click
|
|||
import json
|
||||
import os
|
||||
|
||||
from bsql import Row
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any
|
||||
|
@ -17,7 +16,8 @@ from . import http_client as http
|
|||
from . import logger as logging
|
||||
from .application import Application
|
||||
from .compat import RelayConfig, RelayDatabase
|
||||
from .database import RELAY_SOFTWARE, get_database
|
||||
from .config import Config
|
||||
from .database import RELAY_SOFTWARE, get_database, schema
|
||||
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
|
||||
|
||||
|
||||
|
@ -102,34 +102,7 @@ def cli_setup(ctx: click.Context, skip_questions: bool) -> None:
|
|||
)
|
||||
|
||||
elif ctx.obj.config.db_type == 'postgres':
|
||||
ctx.obj.config.pg_name = click.prompt(
|
||||
'What is the name of the database?',
|
||||
default = ctx.obj.config.pg_name
|
||||
)
|
||||
|
||||
ctx.obj.config.pg_host = click.prompt(
|
||||
'What IP address, hostname, or unix socket does the server listen on?',
|
||||
default = ctx.obj.config.pg_host,
|
||||
type = int
|
||||
)
|
||||
|
||||
ctx.obj.config.pg_port = click.prompt(
|
||||
'What port does the server listen on?',
|
||||
default = ctx.obj.config.pg_port,
|
||||
type = int
|
||||
)
|
||||
|
||||
ctx.obj.config.pg_user = click.prompt(
|
||||
'Which user will authenticate with the server?',
|
||||
default = ctx.obj.config.pg_user
|
||||
)
|
||||
|
||||
ctx.obj.config.pg_pass = click.prompt(
|
||||
'User password',
|
||||
hide_input = True,
|
||||
show_default = False,
|
||||
default = ctx.obj.config.pg_pass or ""
|
||||
) or None
|
||||
config_postgresql(ctx.obj.config)
|
||||
|
||||
ctx.obj.config.ca_type = click.prompt(
|
||||
'Which caching backend?',
|
||||
|
@ -213,6 +186,21 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
|
|||
os._exit(0)
|
||||
|
||||
|
||||
@cli.command('db-maintenance')
|
||||
@click.pass_context
|
||||
def cli_db_maintenance(ctx: click.Context) -> None:
|
||||
'Perform maintenance tasks on the database'
|
||||
|
||||
if ctx.obj.config.db_type == "postgres":
|
||||
return
|
||||
|
||||
with ctx.obj.database.session(False) as s:
|
||||
with s.transaction():
|
||||
s.fix_timestamps()
|
||||
|
||||
with s.execute("VACUUM"):
|
||||
pass
|
||||
|
||||
|
||||
@cli.command('convert')
|
||||
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
|
||||
|
@ -240,18 +228,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
|
|||
ctx.obj.config.set('domain', config['host'])
|
||||
ctx.obj.config.save()
|
||||
|
||||
# fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7
|
||||
with get_database(ctx.obj.config) as db:
|
||||
with db.session(True) as conn:
|
||||
conn.put_config('private-key', database['private-key'])
|
||||
conn.put_config('note', config['note'])
|
||||
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
|
||||
|
||||
with click.progressbar( # type: ignore
|
||||
with click.progressbar(
|
||||
database['relay-list'].values(),
|
||||
label = 'Inboxes'.ljust(15),
|
||||
width = 0
|
||||
) as inboxes:
|
||||
|
||||
for inbox in inboxes:
|
||||
if inbox['software'] in {'akkoma', 'pleroma'}:
|
||||
actor = f'https://{inbox["domain"]}/relay'
|
||||
|
@ -270,7 +258,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
|
|||
software = inbox['software']
|
||||
)
|
||||
|
||||
with click.progressbar( # type: ignore
|
||||
with click.progressbar(
|
||||
config['blocked_software'],
|
||||
label = 'Banned software'.ljust(15),
|
||||
width = 0
|
||||
|
@ -282,7 +270,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
|
|||
reason = 'relay' if software in RELAY_SOFTWARE else None
|
||||
)
|
||||
|
||||
with click.progressbar( # type: ignore
|
||||
with click.progressbar(
|
||||
config['blocked_instances'],
|
||||
label = 'Banned domains'.ljust(15),
|
||||
width = 0
|
||||
|
@ -291,7 +279,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
|
|||
for domain in banned_software:
|
||||
conn.put_domain_ban(domain)
|
||||
|
||||
with click.progressbar( # type: ignore
|
||||
with click.progressbar(
|
||||
config['whitelist'],
|
||||
label = 'Whitelist'.ljust(15),
|
||||
width = 0
|
||||
|
@ -315,6 +303,48 @@ def cli_editconfig(ctx: click.Context, editor: str) -> None:
|
|||
)
|
||||
|
||||
|
||||
@cli.command('switch-backend')
|
||||
@click.pass_context
|
||||
def cli_switchbackend(ctx: click.Context) -> None:
|
||||
"""
|
||||
Copy the database from one backend to the other
|
||||
|
||||
Be sure to set the database type to the backend you want to convert from. For instance, set
|
||||
the database type to `sqlite`, fill out the connection details for postgresql, and the
|
||||
data from the sqlite database will be copied to the postgresql database. This only works if
|
||||
the database in postgresql already exists.
|
||||
"""
|
||||
|
||||
config = Config(ctx.obj.config.path, load = True)
|
||||
config.db_type = "sqlite" if config.db_type == "postgres" else "postgres"
|
||||
|
||||
if config.db_type == "postgres":
|
||||
if click.confirm("Setup PostgreSQL configuration?"):
|
||||
config_postgresql(config)
|
||||
|
||||
order = ("SQLite", "PostgreSQL")
|
||||
click.pause("Make sure the database and user already exist before continuing")
|
||||
|
||||
else:
|
||||
order = ("PostgreSQL", "SQLite")
|
||||
|
||||
click.echo(f"About to convert from {order[0]} to {order[1]}...")
|
||||
database = get_database(config, migrate = False)
|
||||
|
||||
with database.session(True) as new, ctx.obj.database.session(False) as old:
|
||||
if click.confirm("All tables in the destination database will be dropped. Continue?"):
|
||||
new.drop_tables()
|
||||
|
||||
new.create_tables()
|
||||
|
||||
for table in schema.TABLES.keys():
|
||||
for row in old.execute(f"SELECT * FROM {table}"):
|
||||
new.insert(table, row).close()
|
||||
|
||||
config.save()
|
||||
click.echo("Done!")
|
||||
|
||||
|
||||
@cli.group('config')
|
||||
def cli_config() -> None:
|
||||
'Manage the relay settings stored in the database'
|
||||
|
@ -348,10 +378,15 @@ def cli_config_list(ctx: click.Context) -> None:
|
|||
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
|
||||
'Set a config value'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
new_value = conn.put_config(key, value)
|
||||
try:
|
||||
with ctx.obj.database.session() as conn:
|
||||
new_value = conn.put_config(key, value)
|
||||
|
||||
print(f'{key}: {repr(new_value)}')
|
||||
except Exception:
|
||||
click.echo(f'Invalid config name: {key}')
|
||||
return
|
||||
|
||||
click.echo(f'{key}: {repr(new_value)}')
|
||||
|
||||
|
||||
@cli.group('user')
|
||||
|
@ -367,8 +402,8 @@ def cli_user_list(ctx: click.Context) -> None:
|
|||
click.echo('Users:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for user in conn.execute('SELECT * FROM users'):
|
||||
click.echo(f'- {user["username"]}')
|
||||
for row in conn.get_users():
|
||||
click.echo(f'- {row.username}')
|
||||
|
||||
|
||||
@cli_user.command('create')
|
||||
|
@ -379,7 +414,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
|
|||
'Create a new local user'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if conn.get_user(username):
|
||||
if conn.get_user(username) is not None:
|
||||
click.echo(f'User already exists: {username}')
|
||||
return
|
||||
|
||||
|
@ -406,7 +441,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
|
|||
'Delete a local user'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if not conn.get_user(username):
|
||||
if conn.get_user(username) is None:
|
||||
click.echo(f'User does not exist: {username}')
|
||||
return
|
||||
|
||||
|
@ -424,8 +459,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
|
|||
click.echo(f'Tokens for "{username}":')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
|
||||
click.echo(f'- {token["code"]}')
|
||||
for row in conn.get_tokens(username):
|
||||
click.echo(f'- {row.code}')
|
||||
|
||||
|
||||
@cli_user.command('create-token')
|
||||
|
@ -435,13 +470,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
|
|||
'Create a new API token for a user'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if not (user := conn.get_user(username)):
|
||||
if (user := conn.get_user(username)) is None:
|
||||
click.echo(f'User does not exist: {username}')
|
||||
return
|
||||
|
||||
token = conn.put_token(user['username'])
|
||||
token = conn.put_token(user.username)
|
||||
|
||||
click.echo(f'New token for "{username}": {token["code"]}')
|
||||
click.echo(f'New token for "{username}": {token.code}')
|
||||
|
||||
|
||||
@cli_user.command('delete-token')
|
||||
|
@ -451,7 +486,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
|
|||
'Delete an API token'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if not conn.get_token(code):
|
||||
if conn.get_token(code) is None:
|
||||
click.echo('Token does not exist')
|
||||
return
|
||||
|
||||
|
@ -473,8 +508,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
|
|||
click.echo('Connected to the following instances or relays:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for inbox in conn.get_inboxes():
|
||||
click.echo(f'- {inbox["inbox"]}')
|
||||
for row in conn.get_inboxes():
|
||||
click.echo(f'- {row.inbox}')
|
||||
|
||||
|
||||
@cli_inbox.command('follow')
|
||||
|
@ -483,19 +518,21 @@ def cli_inbox_list(ctx: click.Context) -> None:
|
|||
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||
'Follow an actor (Relay must be running)'
|
||||
|
||||
instance: schema.Instance | None = None
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if conn.get_domain_ban(actor):
|
||||
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
||||
return
|
||||
|
||||
if (inbox_data := conn.get_inbox(actor)):
|
||||
inbox = inbox_data['inbox']
|
||||
if (instance := conn.get_inbox(actor)) is not None:
|
||||
inbox = instance.inbox
|
||||
|
||||
else:
|
||||
if not actor.startswith('http'):
|
||||
actor = f'https://{actor}/actor'
|
||||
|
||||
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
|
||||
if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None:
|
||||
click.echo(f'Failed to fetch actor: {actor}')
|
||||
return
|
||||
|
||||
|
@ -506,7 +543,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
|||
actor = actor
|
||||
)
|
||||
|
||||
asyncio.run(http.post(inbox, message, inbox_data))
|
||||
asyncio.run(http.post(inbox, message, instance))
|
||||
click.echo(f'Sent follow message to actor: {actor}')
|
||||
|
||||
|
||||
|
@ -516,19 +553,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
|||
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
||||
'Unfollow an actor (Relay must be running)'
|
||||
|
||||
inbox_data: Row | None = None
|
||||
instance: schema.Instance | None = None
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if conn.get_domain_ban(actor):
|
||||
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
||||
return
|
||||
|
||||
if (inbox_data := conn.get_inbox(actor)):
|
||||
inbox = inbox_data['inbox']
|
||||
if (instance := conn.get_inbox(actor)):
|
||||
inbox = instance.inbox
|
||||
message = Message.new_unfollow(
|
||||
host = ctx.obj.config.domain,
|
||||
actor = actor,
|
||||
follow = inbox_data['followid']
|
||||
follow = instance.followid
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -552,7 +589,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
|||
}
|
||||
)
|
||||
|
||||
asyncio.run(http.post(inbox, message, inbox_data))
|
||||
asyncio.run(http.post(inbox, message, instance))
|
||||
click.echo(f'Sent unfollow message to: {actor}')
|
||||
|
||||
|
||||
|
@ -632,9 +669,9 @@ def cli_request_list(ctx: click.Context) -> None:
|
|||
click.echo('Follow requests:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for instance in conn.get_requests():
|
||||
date = instance['created'].strftime('%Y-%m-%d')
|
||||
click.echo(f'- [{date}] {instance["domain"]}')
|
||||
for row in conn.get_requests():
|
||||
date = row.created.strftime('%Y-%m-%d')
|
||||
click.echo(f'- [{date}] {row.domain}')
|
||||
|
||||
|
||||
@cli_request.command('accept')
|
||||
|
@ -653,20 +690,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
|
|||
|
||||
message = Message.new_response(
|
||||
host = ctx.obj.config.domain,
|
||||
actor = instance['actor'],
|
||||
followid = instance['followid'],
|
||||
actor = instance.actor,
|
||||
followid = instance.followid,
|
||||
accept = True
|
||||
)
|
||||
|
||||
asyncio.run(http.post(instance['inbox'], message, instance))
|
||||
asyncio.run(http.post(instance.inbox, message, instance))
|
||||
|
||||
if instance['software'] != 'mastodon':
|
||||
if instance.software != 'mastodon':
|
||||
message = Message.new_follow(
|
||||
host = ctx.obj.config.domain,
|
||||
actor = instance['actor']
|
||||
actor = instance.actor
|
||||
)
|
||||
|
||||
asyncio.run(http.post(instance['inbox'], message, instance))
|
||||
asyncio.run(http.post(instance.inbox, message, instance))
|
||||
|
||||
|
||||
@cli_request.command('deny')
|
||||
|
@ -685,12 +722,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
|
|||
|
||||
response = Message.new_response(
|
||||
host = ctx.obj.config.domain,
|
||||
actor = instance['actor'],
|
||||
followid = instance['followid'],
|
||||
actor = instance.actor,
|
||||
followid = instance.followid,
|
||||
accept = False
|
||||
)
|
||||
|
||||
asyncio.run(http.post(instance['inbox'], response, instance))
|
||||
asyncio.run(http.post(instance.inbox, response, instance))
|
||||
|
||||
|
||||
@cli.group('instance')
|
||||
|
@ -706,12 +743,12 @@ def cli_instance_list(ctx: click.Context) -> None:
|
|||
click.echo('Banned domains:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for instance in conn.execute('SELECT * FROM domain_bans'):
|
||||
if instance['reason']:
|
||||
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
|
||||
for row in conn.get_domain_bans():
|
||||
if row.reason is not None:
|
||||
click.echo(f'- {row.domain} ({row.reason})')
|
||||
|
||||
else:
|
||||
click.echo(f'- {instance["domain"]}')
|
||||
click.echo(f'- {row.domain}')
|
||||
|
||||
|
||||
@cli_instance.command('ban')
|
||||
|
@ -723,7 +760,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
|
|||
'Ban an instance and remove the associated inbox if it exists'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if conn.get_domain_ban(domain):
|
||||
if conn.get_domain_ban(domain) is not None:
|
||||
click.echo(f'Domain already banned: {domain}')
|
||||
return
|
||||
|
||||
|
@ -739,7 +776,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
|
|||
'Unban an instance'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if not conn.del_domain_ban(domain):
|
||||
if conn.del_domain_ban(domain) is None:
|
||||
click.echo(f'Instance wasn\'t banned: {domain}')
|
||||
return
|
||||
|
||||
|
@ -764,11 +801,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
|
|||
|
||||
click.echo(f'Updated domain ban: {domain}')
|
||||
|
||||
if row['reason']:
|
||||
click.echo(f'- {row["domain"]} ({row["reason"]})')
|
||||
if row.reason:
|
||||
click.echo(f'- {row.domain} ({row.reason})')
|
||||
|
||||
else:
|
||||
click.echo(f'- {row["domain"]}')
|
||||
click.echo(f'- {row.domain}')
|
||||
|
||||
|
||||
@cli.group('software')
|
||||
|
@ -784,12 +821,12 @@ def cli_software_list(ctx: click.Context) -> None:
|
|||
click.echo('Banned software:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for software in conn.execute('SELECT * FROM software_bans'):
|
||||
if software['reason']:
|
||||
click.echo(f'- {software["name"]} ({software["reason"]})')
|
||||
for row in conn.get_software_bans():
|
||||
if row.reason:
|
||||
click.echo(f'- {row.name} ({row.reason})')
|
||||
|
||||
else:
|
||||
click.echo(f'- {software["name"]}')
|
||||
click.echo(f'- {row.name}')
|
||||
|
||||
|
||||
@cli_software.command('ban')
|
||||
|
@ -811,12 +848,12 @@ def cli_software_ban(ctx: click.Context,
|
|||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if name == 'RELAYS':
|
||||
for software in RELAY_SOFTWARE:
|
||||
if conn.get_software_ban(software):
|
||||
click.echo(f'Relay already banned: {software}')
|
||||
for item in RELAY_SOFTWARE:
|
||||
if conn.get_software_ban(item):
|
||||
click.echo(f'Relay already banned: {item}')
|
||||
continue
|
||||
|
||||
conn.put_software_ban(software, reason or 'relay', note)
|
||||
conn.put_software_ban(item, reason or 'relay', note)
|
||||
|
||||
click.echo('Banned all relay software')
|
||||
return
|
||||
|
@ -893,11 +930,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
|
|||
|
||||
click.echo(f'Updated software ban: {name}')
|
||||
|
||||
if row['reason']:
|
||||
click.echo(f'- {row["name"]} ({row["reason"]})')
|
||||
if row.reason:
|
||||
click.echo(f'- {row.name} ({row.reason})')
|
||||
|
||||
else:
|
||||
click.echo(f'- {row["name"]}')
|
||||
click.echo(f'- {row.name}')
|
||||
|
||||
|
||||
@cli.group('whitelist')
|
||||
|
@ -913,8 +950,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
|
|||
click.echo('Current whitelisted domains:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for domain in conn.execute('SELECT * FROM whitelist'):
|
||||
click.echo(f'- {domain["domain"]}')
|
||||
for row in conn.get_domain_whitelist():
|
||||
click.echo(f'- {row.domain}')
|
||||
|
||||
|
||||
@cli_whitelist.command('add')
|
||||
|
@ -953,23 +990,48 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
|
|||
@cli_whitelist.command('import')
|
||||
@click.pass_context
|
||||
def cli_whitelist_import(ctx: click.Context) -> None:
|
||||
'Add all current inboxes to the whitelist'
|
||||
'Add all current instances to the whitelist'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for inbox in conn.execute('SELECT * FROM inboxes').all():
|
||||
if conn.get_domain_whitelist(inbox['domain']):
|
||||
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
|
||||
for row in conn.get_inboxes():
|
||||
if conn.get_domain_whitelist(row.domain) is not None:
|
||||
click.echo(f'Domain already in whitelist: {row.domain}')
|
||||
continue
|
||||
|
||||
conn.put_domain_whitelist(inbox['domain'])
|
||||
conn.put_domain_whitelist(row.domain)
|
||||
|
||||
click.echo('Imported whitelist from inboxes')
|
||||
|
||||
|
||||
def config_postgresql(config: Config) -> None:
|
||||
config.pg_name = click.prompt(
|
||||
'What is the name of the database?',
|
||||
default = config.pg_name
|
||||
)
|
||||
|
||||
config.pg_host = click.prompt(
|
||||
'What IP address, hostname, or unix socket does the server listen on?',
|
||||
default = config.pg_host,
|
||||
)
|
||||
|
||||
config.pg_port = click.prompt(
|
||||
'What port does the server listen on?',
|
||||
default = config.pg_port,
|
||||
type = int
|
||||
)
|
||||
|
||||
config.pg_user = click.prompt(
|
||||
'Which user will authenticate with the server?',
|
||||
default = config.pg_user
|
||||
)
|
||||
|
||||
config.pg_pass = click.prompt(
|
||||
'User password',
|
||||
hide_input = True,
|
||||
show_default = False,
|
||||
default = config.pg_pass or ""
|
||||
) or None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
cli(prog_name='relay')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')
|
||||
cli(prog_name='activityrelay')
|
||||
|
|
|
@ -4,28 +4,15 @@ import aputils
|
|||
import json
|
||||
import os
|
||||
import platform
|
||||
import socket
|
||||
|
||||
from aiohttp.web import Response as AiohttpResponse
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
|
||||
from uuid import uuid4
|
||||
|
||||
try:
|
||||
from importlib.resources import files as pkgfiles
|
||||
|
||||
except ImportError:
|
||||
from importlib_resources import files as pkgfiles # type: ignore
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Self
|
||||
from .application import Application
|
||||
|
||||
|
||||
|
@ -50,11 +37,6 @@ MIMETYPES = {
|
|||
'webmanifest': 'application/manifest+json'
|
||||
}
|
||||
|
||||
NODEINFO_NS = {
|
||||
'20': 'http://nodeinfo.diaspora.software/ns/schema/2.0',
|
||||
'21': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
|
||||
}
|
||||
|
||||
ACTOR_FORMATS = {
|
||||
'mastodon': 'https://{domain}/actor',
|
||||
'akkoma': 'https://{domain}/relay',
|
||||
|
@ -72,42 +54,26 @@ SOFTWARE = (
|
|||
'gotosocial'
|
||||
)
|
||||
|
||||
JSON_PATHS: tuple[str, ...] = (
|
||||
'/api/v1',
|
||||
'/actor',
|
||||
'/inbox',
|
||||
'/outbox',
|
||||
'/following',
|
||||
'/followers',
|
||||
'/.well-known',
|
||||
'/nodeinfo',
|
||||
'/oauth/token',
|
||||
'/oauth/revoke'
|
||||
)
|
||||
|
||||
def boolean(value: Any) -> bool:
|
||||
if isinstance(value, str):
|
||||
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}:
|
||||
return True
|
||||
|
||||
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}:
|
||||
return False
|
||||
|
||||
raise TypeError(f'Cannot parse string "{value}" as a boolean')
|
||||
|
||||
if isinstance(value, int):
|
||||
if value == 1:
|
||||
return True
|
||||
|
||||
if value == 0:
|
||||
return False
|
||||
|
||||
raise ValueError('Integer value must be 1 or 0')
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
return bool(value)
|
||||
|
||||
|
||||
def check_open_port(host: str, port: int) -> bool:
|
||||
if host == '0.0.0.0':
|
||||
host = '127.0.0.1'
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
return s.connect_ex((host, port)) != 0
|
||||
|
||||
except socket.error:
|
||||
return False
|
||||
TOKEN_PATHS: tuple[str, ...] = (
|
||||
'/logout',
|
||||
'/admin',
|
||||
'/api',
|
||||
'/oauth/authorize',
|
||||
'/oauth/revoke'
|
||||
)
|
||||
|
||||
|
||||
def get_app() -> Application:
|
||||
|
@ -119,10 +85,6 @@ def get_app() -> Application:
|
|||
return Application.DEFAULT
|
||||
|
||||
|
||||
def get_resource(path: str) -> Path:
|
||||
return Path(str(pkgfiles('relay'))).joinpath(path)
|
||||
|
||||
|
||||
class JsonEncoder(json.JSONEncoder):
|
||||
def default(self, o: Any) -> str:
|
||||
if isinstance(o, datetime):
|
||||
|
@ -240,21 +202,9 @@ class Response(AiohttpResponse):
|
|||
|
||||
|
||||
@classmethod
|
||||
def new_error(cls: type[Self],
|
||||
status: int,
|
||||
body: str | bytes | dict[str, Any],
|
||||
ctype: str = 'text') -> Self:
|
||||
|
||||
if ctype == 'json':
|
||||
body = {'error': body}
|
||||
|
||||
return cls.new(body=body, status=status, ctype=ctype)
|
||||
|
||||
|
||||
@classmethod
|
||||
def new_redir(cls: type[Self], path: str) -> Self:
|
||||
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
|
||||
body = f'Redirect to <a href="{path}">{path}</a>'
|
||||
return cls.new(body, 302, {'Location': path})
|
||||
return cls.new(body, status, {'Location': path}, ctype = 'html')
|
||||
|
||||
|
||||
@property
|
||||
|
|
|
@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
|
|||
logging.debug('>> relay: %s', message)
|
||||
|
||||
for instance in conn.distill_inboxes(view.message):
|
||||
view.app.push_message(instance["inbox"], message, instance)
|
||||
view.app.push_message(instance.inbox, message, instance)
|
||||
|
||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||
|
||||
|
@ -52,13 +52,13 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
|
|||
logging.debug('>> forward: %s', message)
|
||||
|
||||
for instance in conn.distill_inboxes(view.message):
|
||||
view.app.push_message(instance["inbox"], view.message, instance)
|
||||
view.app.push_message(instance.inbox, view.message, instance)
|
||||
|
||||
view.cache.set('handle-relay', view.message.id, message.id, 'str')
|
||||
|
||||
|
||||
async def handle_follow(view: ActorView, conn: Connection) -> None:
|
||||
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
|
||||
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain, force = True)
|
||||
software = nodeinfo.sw_name if nodeinfo else None
|
||||
config = conn.get_config_all()
|
||||
|
||||
|
@ -171,13 +171,13 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
|
|||
|
||||
|
||||
async def handle_undo(view: ActorView, conn: Connection) -> None:
|
||||
# If the object is not a Follow, forward it
|
||||
if view.message.object['type'] != 'Follow':
|
||||
await handle_forward(view, conn)
|
||||
# forwarding deletes does not work, so don't bother
|
||||
# await handle_forward(view, conn)
|
||||
return
|
||||
|
||||
# prevent past unfollows from removing an instance
|
||||
if view.instance['followid'] and view.instance['followid'] != view.message.object_id:
|
||||
if view.instance.followid and view.instance.followid != view.message.object_id:
|
||||
return
|
||||
|
||||
with conn.transaction():
|
||||
|
@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None:
|
|||
|
||||
with view.database.session() as conn:
|
||||
if view.instance:
|
||||
if not view.instance['software']:
|
||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
||||
if not view.instance.software:
|
||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)):
|
||||
with conn.transaction():
|
||||
view.instance = conn.put_inbox(
|
||||
domain = view.instance['domain'],
|
||||
domain = view.instance.domain,
|
||||
software = nodeinfo.sw_name
|
||||
)
|
||||
|
||||
if not view.instance['actor']:
|
||||
if not view.instance.actor:
|
||||
with conn.transaction():
|
||||
view.instance = conn.put_inbox(
|
||||
domain = view.instance['domain'],
|
||||
domain = view.instance.domain,
|
||||
actor = view.actor.id
|
||||
)
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
|||
|
||||
import textwrap
|
||||
|
||||
from aiohttp.web import Request
|
||||
from blib import File
|
||||
from collections.abc import Callable
|
||||
from hamlish_jinja import HamlishExtension
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
@ -12,14 +14,15 @@ from markdown import Markdown
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from . import __version__
|
||||
from .misc import get_resource
|
||||
from .views.base import View
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .application import Application
|
||||
|
||||
|
||||
class Template(Environment):
|
||||
_render_markdown: Callable[[str], str]
|
||||
|
||||
|
||||
def __init__(self, app: Application):
|
||||
Environment.__init__(self,
|
||||
autoescape = True,
|
||||
|
@ -30,7 +33,7 @@ class Template(Environment):
|
|||
MarkdownExtension
|
||||
],
|
||||
loader = FileSystemLoader([
|
||||
get_resource('frontend'),
|
||||
File.from_resource('relay', 'frontend'),
|
||||
app.config.path.parent.joinpath('template')
|
||||
])
|
||||
)
|
||||
|
@ -40,12 +43,12 @@ class Template(Environment):
|
|||
self.hamlish_mode = 'indented'
|
||||
|
||||
|
||||
def render(self, path: str, view: View | None = None, **context: Any) -> str:
|
||||
def render(self, path: str, request: Request, **context: Any) -> str:
|
||||
with self.app.database.session(False) as conn:
|
||||
config = conn.get_config_all()
|
||||
|
||||
new_context = {
|
||||
'view': view,
|
||||
'request': request,
|
||||
'domain': self.app.config.domain,
|
||||
'version': __version__,
|
||||
'config': config,
|
||||
|
@ -56,7 +59,7 @@ class Template(Environment):
|
|||
|
||||
|
||||
def render_markdown(self, text: str) -> str:
|
||||
return self._render_markdown(text) # type: ignore
|
||||
return self._render_markdown(text)
|
||||
|
||||
|
||||
class MarkdownExtension(Extension):
|
||||
|
|
|
@ -1,26 +1,24 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import aputils
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from aiohttp import ClientConnectorError
|
||||
from aiohttp.web import Request
|
||||
from blib import HttpError
|
||||
|
||||
from .base import View, register_route
|
||||
|
||||
from .. import logger as logging
|
||||
from ..database import schema
|
||||
from ..misc import Message, Response
|
||||
from ..processors import run_processor
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from aiohttp.web import Request
|
||||
from bsql import Row
|
||||
|
||||
|
||||
@register_route('/actor', '/inbox')
|
||||
class ActorView(View):
|
||||
signature: aputils.Signature
|
||||
message: Message
|
||||
actor: Message
|
||||
instancce: Row
|
||||
instance: schema.Instance
|
||||
signer: aputils.Signer
|
||||
|
||||
|
||||
|
@ -43,16 +41,15 @@ class ActorView(View):
|
|||
|
||||
|
||||
async def post(self, request: Request) -> Response:
|
||||
if response := await self.get_post_data():
|
||||
return response
|
||||
await self.get_post_data()
|
||||
|
||||
with self.database.session() as conn:
|
||||
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
||||
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
|
||||
|
||||
# reject if actor is banned
|
||||
if conn.get_domain_ban(self.actor.domain):
|
||||
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
|
||||
return Response.new_error(403, 'access denied', 'json')
|
||||
raise HttpError(403, 'access denied')
|
||||
|
||||
# reject if activity type isn't 'Follow' and the actor isn't following
|
||||
if self.message.type != 'Follow' and not self.instance:
|
||||
|
@ -61,7 +58,7 @@ class ActorView(View):
|
|||
self.actor.id
|
||||
)
|
||||
|
||||
return Response.new_error(401, 'access denied', 'json')
|
||||
raise HttpError(401, 'access denied')
|
||||
|
||||
logging.debug('>> payload %s', self.message.to_json(4))
|
||||
|
||||
|
@ -69,60 +66,66 @@ class ActorView(View):
|
|||
return Response.new(status = 202)
|
||||
|
||||
|
||||
async def get_post_data(self) -> Response | None:
|
||||
async def get_post_data(self) -> None:
|
||||
try:
|
||||
self.signature = aputils.Signature.parse(self.request.headers['signature'])
|
||||
|
||||
except KeyError:
|
||||
logging.verbose('Missing signature header')
|
||||
return Response.new_error(400, 'missing signature header', 'json')
|
||||
raise HttpError(400, 'missing signature header')
|
||||
|
||||
try:
|
||||
message: Message | None = await self.request.json(loads = Message.parse)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
logging.verbose('Failed to parse inbox message')
|
||||
return Response.new_error(400, 'failed to parse message', 'json')
|
||||
logging.verbose('Failed to parse message from actor: %s', self.signature.keyid)
|
||||
raise HttpError(400, 'failed to parse message')
|
||||
|
||||
if message is None:
|
||||
logging.verbose('empty message')
|
||||
return Response.new_error(400, 'missing message', 'json')
|
||||
raise HttpError(400, 'missing message')
|
||||
|
||||
self.message = message
|
||||
|
||||
if 'actor' not in self.message:
|
||||
logging.verbose('actor not in message')
|
||||
return Response.new_error(400, 'no actor in message', 'json')
|
||||
raise HttpError(400, 'no actor in message')
|
||||
|
||||
actor: Message | None = await self.client.get(self.signature.keyid, True, Message)
|
||||
try:
|
||||
self.actor = await self.client.get(self.signature.keyid, True, Message)
|
||||
|
||||
if actor is None:
|
||||
except HttpError as e:
|
||||
# ld signatures aren't handled atm, so just ignore it
|
||||
if self.message.type == 'Delete':
|
||||
logging.verbose('Instance sent a delete which cannot be handled')
|
||||
return Response.new(status=202)
|
||||
raise HttpError(202, '')
|
||||
|
||||
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
|
||||
return Response.new_error(400, 'failed to fetch actor', 'json')
|
||||
logging.verbose('Failed to fetch actor: %s', self.signature.keyid)
|
||||
logging.debug('HTTP Status %i: %s', e.status, e.message)
|
||||
raise HttpError(400, 'failed to fetch actor')
|
||||
|
||||
self.actor = actor
|
||||
except ClientConnectorError as e:
|
||||
logging.warning('Error when trying to fetch actor: %s, %s', self.signature.keyid, str(e))
|
||||
raise HttpError(400, 'failed to fetch actor')
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
raise HttpError(500, 'unexpected error when fetching actor')
|
||||
|
||||
try:
|
||||
self.signer = self.actor.signer
|
||||
|
||||
except KeyError:
|
||||
logging.verbose('Actor missing public key: %s', self.signature.keyid)
|
||||
return Response.new_error(400, 'actor missing public key', 'json')
|
||||
raise HttpError(400, 'actor missing public key')
|
||||
|
||||
try:
|
||||
await self.signer.validate_request_async(self.request)
|
||||
|
||||
except aputils.SignatureFailureError as e:
|
||||
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
|
||||
return Response.new_error(401, str(e), 'json')
|
||||
|
||||
return None
|
||||
raise HttpError(401, str(e))
|
||||
|
||||
|
||||
@register_route('/outbox')
|
||||
|
@ -165,10 +168,10 @@ class WebfingerView(View):
|
|||
subject = request.query['resource']
|
||||
|
||||
except KeyError:
|
||||
return Response.new_error(400, 'missing "resource" query key', 'json')
|
||||
raise HttpError(400, 'missing "resource" query key')
|
||||
|
||||
if subject != f'acct:relay@{self.config.domain}':
|
||||
return Response.new_error(404, 'user not found', 'json')
|
||||
raise HttpError(404, 'user not found')
|
||||
|
||||
data = aputils.Webfinger.new(
|
||||
handle = 'relay',
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
import traceback
|
||||
|
||||
from aiohttp.web import Request, middleware
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
from blib import HttpError, convert_to_boolean
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .base import View, register_route
|
||||
|
||||
from .. import __version__
|
||||
from ..database import ConfigData
|
||||
from ..misc import Message, Response, boolean, get_app
|
||||
from ..database import ConfigData, schema
|
||||
from ..misc import Message, Response
|
||||
|
||||
|
||||
ALLOWED_HEADERS = {
|
||||
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
|
||||
ALLOWED_HEADERS: set[str] = {
|
||||
'accept',
|
||||
'authorization',
|
||||
'content-type'
|
||||
|
@ -19,6 +22,8 @@ ALLOWED_HEADERS = {
|
|||
|
||||
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
|
||||
('GET', '/api/v1/relay'),
|
||||
('POST', '/api/v1/app'),
|
||||
('POST', '/api/v1/login'),
|
||||
('POST', '/api/v1/token')
|
||||
)
|
||||
|
||||
|
@ -34,64 +39,184 @@ def check_api_path(method: str, path: str) -> bool:
|
|||
async def handle_api_path(
|
||||
request: Request,
|
||||
handler: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
try:
|
||||
if (token := request.cookies.get('user-token')):
|
||||
request['token'] = token
|
||||
|
||||
else:
|
||||
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
|
||||
|
||||
with get_app().database.session() as conn:
|
||||
request['user'] = conn.get_user_by_token(request['token'])
|
||||
|
||||
except (KeyError, ValueError):
|
||||
request['token'] = None
|
||||
request['user'] = None
|
||||
if not request.path.startswith('/api') or request.path == '/api/doc':
|
||||
return await handler(request)
|
||||
|
||||
if request.method != "OPTIONS" and check_api_path(request.method, request.path):
|
||||
if not request['token']:
|
||||
return Response.new_error(401, 'Missing token', 'json')
|
||||
if request['token'] is None:
|
||||
raise HttpError(401, 'Missing token')
|
||||
|
||||
if not request['user']:
|
||||
return Response.new_error(401, 'Invalid token', 'json')
|
||||
if request['user'] is None:
|
||||
raise HttpError(401, 'Invalid token')
|
||||
|
||||
response = await handler(request)
|
||||
|
||||
if request.path.startswith('/api'):
|
||||
response.headers['Access-Control-Allow-Origin'] = '*'
|
||||
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
|
||||
response.headers['Access-Control-Allow-Origin'] = '*'
|
||||
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@register_route('/api/v1/token')
|
||||
class Login(View):
|
||||
@register_route('/oauth/authorize')
|
||||
@register_route('/api/oauth/authorize')
|
||||
class OauthAuthorize(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
return Response.new({'message': 'Token valid'}, ctype = 'json')
|
||||
data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], [])
|
||||
|
||||
if data['response_type'] != 'code':
|
||||
raise HttpError(400, 'Response type is not "code"')
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
with conn.select('apps', client_id = data['client_id']) as cur:
|
||||
if (app := cur.one(schema.App)) is None:
|
||||
raise HttpError(404, 'Could not find app')
|
||||
|
||||
if app.token is not None:
|
||||
raise HttpError(400, 'Application has already been authorized')
|
||||
|
||||
if app.auth_code is not None:
|
||||
context = {'application': app}
|
||||
html = self.template.render(
|
||||
'page/authorize_show.haml', self.request, **context
|
||||
)
|
||||
|
||||
return Response.new(html, ctype = 'html')
|
||||
|
||||
if data['redirect_uri'] != app.redirect_uri:
|
||||
raise HttpError(400, 'redirect_uri does not match application')
|
||||
|
||||
context = {'application': app}
|
||||
html = self.template.render('page/authorize_new.haml', self.request, **context)
|
||||
return Response.new(html, ctype = 'html')
|
||||
|
||||
|
||||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['username', 'password'], [])
|
||||
data = await self.get_api_data(
|
||||
['client_id', 'client_secret', 'redirect_uri', 'response'], []
|
||||
)
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
with self.database.session(True) as conn:
|
||||
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
|
||||
raise HttpError(404, 'Could not find app')
|
||||
|
||||
if convert_to_boolean(data['response']):
|
||||
if app.token is not None:
|
||||
raise HttpError(400, 'Application has already been authorized')
|
||||
|
||||
if app.auth_code is None:
|
||||
app = conn.update_app(app, request['user'], True)
|
||||
|
||||
if app.redirect_uri == DEFAULT_REDIRECT:
|
||||
context = {'application': app}
|
||||
html = self.template.render(
|
||||
'page/authorize_show.haml', self.request, **context
|
||||
)
|
||||
|
||||
return Response.new(html, ctype = 'html')
|
||||
|
||||
return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}')
|
||||
|
||||
if not conn.del_app(app.client_id, app.client_secret):
|
||||
raise HttpError(404, 'App not found')
|
||||
|
||||
return Response.new_redir('/')
|
||||
|
||||
|
||||
@register_route('/oauth/token')
|
||||
@register_route('/api/oauth/token')
|
||||
class OauthToken(View):
|
||||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(
|
||||
['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], []
|
||||
)
|
||||
|
||||
if data['grant_type'] != 'authorization_code':
|
||||
raise HttpError(400, 'Invalid grant type')
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
|
||||
raise HttpError(404, 'Application not found')
|
||||
|
||||
if app.auth_code != data['code']:
|
||||
raise HttpError(400, 'Invalid authentication code')
|
||||
|
||||
if app.redirect_uri != data['redirect_uri']:
|
||||
raise HttpError(400, 'Invalid redirect uri')
|
||||
|
||||
app = conn.update_app(app, request['user'], False)
|
||||
|
||||
return Response.new(app.get_api_data(True), ctype = 'json')
|
||||
|
||||
|
||||
@register_route('/oauth/revoke')
|
||||
@register_route('/api/oauth/revoke')
|
||||
class OauthRevoke(View):
|
||||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['client_id', 'client_secret', 'token'], [])
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
if (app := conn.get_app(**data)) is None:
|
||||
raise HttpError(404, 'Could not find token')
|
||||
|
||||
if app.user != request['token'].username:
|
||||
raise HttpError(403, 'Invalid token')
|
||||
|
||||
if not conn.del_app(**data):
|
||||
raise HttpError(400, 'Token not removed')
|
||||
|
||||
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
|
||||
|
||||
|
||||
@register_route('/api/v1/app')
|
||||
class App(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
return Response.new(request['token'].get_api_data(), ctype = 'json')
|
||||
|
||||
|
||||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['name', 'redirect_uri'], ['website'])
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
app = conn.put_app(
|
||||
name = data['name'],
|
||||
redirect_uri = data['redirect_uri'],
|
||||
website = data.get('website')
|
||||
)
|
||||
|
||||
return Response.new(app.get_api_data(), ctype = 'json')
|
||||
|
||||
|
||||
async def delete(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['client_id', 'client_secret'], [])
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code):
|
||||
raise HttpError(400, 'Token not removed')
|
||||
|
||||
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
|
||||
|
||||
|
||||
@register_route('/api/v1/login')
|
||||
class Login(View):
|
||||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['username', 'password'], [])
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
if not (user := conn.get_user(data['username'])):
|
||||
return Response.new_error(401, 'User not found', 'json')
|
||||
raise HttpError(401, 'User not found')
|
||||
|
||||
try:
|
||||
conn.hasher.verify(user['hash'], data['password'])
|
||||
|
||||
except VerifyMismatchError:
|
||||
return Response.new_error(401, 'Invalid password', 'json')
|
||||
raise HttpError(401, 'Invalid password')
|
||||
|
||||
token = conn.put_token(data['username'])
|
||||
app = conn.put_app_login(user)
|
||||
|
||||
resp = Response.new({'token': token['code']}, ctype = 'json')
|
||||
resp = Response.new(app.get_api_data(True), ctype = 'json')
|
||||
resp.set_cookie(
|
||||
'user-token',
|
||||
token['code'],
|
||||
app.token, # type: ignore[arg-type]
|
||||
max_age = 60 * 60 * 24 * 365,
|
||||
domain = self.config.domain,
|
||||
path = '/',
|
||||
|
@ -103,19 +228,12 @@ class Login(View):
|
|||
return resp
|
||||
|
||||
|
||||
async def delete(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
conn.del_token(request['token'])
|
||||
|
||||
return Response.new({'message': 'Token revoked'}, ctype = 'json')
|
||||
|
||||
|
||||
@register_route('/api/v1/relay')
|
||||
class RelayInfo(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
config = conn.get_config_all()
|
||||
inboxes = [row['domain'] for row in conn.get_inboxes()]
|
||||
inboxes = [row.domain for row in conn.get_inboxes()]
|
||||
|
||||
data = {
|
||||
'domain': self.config.domain,
|
||||
|
@ -152,17 +270,16 @@ class Config(View):
|
|||
|
||||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['key', 'value'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['key'] = data['key'].replace('-', '_')
|
||||
|
||||
if data['key'] not in ConfigData.USER_KEYS():
|
||||
return Response.new_error(400, 'Invalid key', 'json')
|
||||
raise HttpError(400, 'Invalid key')
|
||||
|
||||
with self.database.session() as conn:
|
||||
conn.put_config(data['key'], data['value'])
|
||||
value = conn.put_config(data['key'], data['value'])
|
||||
|
||||
if data['key'] == 'log-level':
|
||||
self.app.workers.set_log_level(value)
|
||||
|
||||
return Response.new({'message': 'Updated config'}, ctype = 'json')
|
||||
|
||||
|
@ -170,14 +287,14 @@ class Config(View):
|
|||
async def delete(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['key'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
if data['key'] not in ConfigData.USER_KEYS():
|
||||
return Response.new_error(400, 'Invalid key', 'json')
|
||||
raise HttpError(400, 'Invalid key')
|
||||
|
||||
with self.database.session() as conn:
|
||||
conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
|
||||
value = conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
|
||||
|
||||
if data['key'] == 'log-level':
|
||||
self.app.workers.set_log_level(value)
|
||||
|
||||
return Response.new({'message': 'Updated config'}, ctype = 'json')
|
||||
|
||||
|
@ -186,40 +303,46 @@ class Config(View):
|
|||
class Inbox(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
data = conn.get_inboxes()
|
||||
data = tuple(conn.get_inboxes())
|
||||
|
||||
return Response.new(data, ctype = 'json')
|
||||
|
||||
|
||||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = urlparse(data["actor"]).netloc
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_inbox(data['domain']):
|
||||
return Response.new_error(404, 'Instance already in database', 'json')
|
||||
if conn.get_inbox(data['domain']) is not None:
|
||||
raise HttpError(404, 'Instance already in database')
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if not data.get('inbox'):
|
||||
actor_data: Message | None = await self.client.get(data['actor'], True, Message)
|
||||
try:
|
||||
actor_data = await self.client.get(data['actor'], True, Message)
|
||||
|
||||
if actor_data is None:
|
||||
return Response.new_error(500, 'Failed to fetch actor', 'json')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
raise HttpError(500, 'Failed to fetch actor') from None
|
||||
|
||||
data['inbox'] = actor_data.shared_inbox
|
||||
|
||||
if not data.get('software'):
|
||||
nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
|
||||
|
||||
if nodeinfo is not None:
|
||||
try:
|
||||
nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
|
||||
data['software'] = nodeinfo.sw_name
|
||||
|
||||
row = conn.put_inbox(**data) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
row = conn.put_inbox(
|
||||
domain = data['domain'],
|
||||
actor = data['actor'],
|
||||
inbox = data.get('inbox'),
|
||||
software = data.get('software'),
|
||||
followid = data.get('followid')
|
||||
)
|
||||
|
||||
return Response.new(row, ctype = 'json')
|
||||
|
||||
|
@ -227,16 +350,17 @@ class Inbox(View):
|
|||
async def patch(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if not (instance := conn.get_inbox(data['domain'])):
|
||||
return Response.new_error(404, 'Instance with domain not found', 'json')
|
||||
if (instance := conn.get_inbox(data['domain'])) is None:
|
||||
raise HttpError(404, 'Instance with domain not found')
|
||||
|
||||
instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type]
|
||||
instance = conn.put_inbox(
|
||||
instance.domain,
|
||||
actor = data.get('actor'),
|
||||
software = data.get('software'),
|
||||
followid = data.get('followid')
|
||||
)
|
||||
|
||||
return Response.new(instance, ctype = 'json')
|
||||
|
||||
|
@ -244,14 +368,10 @@ class Inbox(View):
|
|||
async def delete(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
data = await self.get_api_data(['domain'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if not conn.get_inbox(data['domain']):
|
||||
return Response.new_error(404, 'Instance with domain not found', 'json')
|
||||
raise HttpError(404, 'Instance with domain not found')
|
||||
|
||||
conn.del_inbox(data['domain'])
|
||||
|
||||
|
@ -262,43 +382,41 @@ class Inbox(View):
|
|||
class RequestView(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
instances = conn.get_requests()
|
||||
instances = tuple(conn.get_requests())
|
||||
|
||||
return Response.new(instances, ctype = 'json')
|
||||
|
||||
|
||||
async def post(self, request: Request) -> Response:
|
||||
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['accept'] = boolean(data['accept'])
|
||||
data = await self.get_api_data(['domain', 'accept'], [])
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
try:
|
||||
with self.database.session(True) as conn:
|
||||
instance = conn.put_request_response(data['domain'], data['accept'])
|
||||
instance = conn.put_request_response(
|
||||
data['domain'],
|
||||
convert_to_boolean(data['accept'])
|
||||
)
|
||||
|
||||
except KeyError:
|
||||
return Response.new_error(404, 'Request not found', 'json')
|
||||
raise HttpError(404, 'Request not found') from None
|
||||
|
||||
message = Message.new_response(
|
||||
host = self.config.domain,
|
||||
actor = instance['actor'],
|
||||
followid = instance['followid'],
|
||||
accept = data['accept']
|
||||
actor = instance.actor,
|
||||
followid = instance.followid,
|
||||
accept = convert_to_boolean(data['accept'])
|
||||
)
|
||||
|
||||
self.app.push_message(instance['inbox'], message, instance)
|
||||
self.app.push_message(instance.inbox, message, instance)
|
||||
|
||||
if data['accept'] and instance['software'] != 'mastodon':
|
||||
if data['accept'] and instance.software != 'mastodon':
|
||||
message = Message.new_follow(
|
||||
host = self.config.domain,
|
||||
actor = instance['actor']
|
||||
actor = instance.actor
|
||||
)
|
||||
|
||||
self.app.push_message(instance['inbox'], message, instance)
|
||||
self.app.push_message(instance.inbox, message, instance)
|
||||
|
||||
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
|
||||
return Response.new(resp_message, ctype = 'json')
|
||||
|
@ -308,24 +426,24 @@ class RequestView(View):
|
|||
class DomainBan(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
bans = tuple(conn.execute('SELECT * FROM domain_bans').all())
|
||||
bans = tuple(conn.get_domain_bans())
|
||||
|
||||
return Response.new(bans, ctype = 'json')
|
||||
|
||||
|
||||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['domain'], ['note', 'reason'])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_domain_ban(data['domain']):
|
||||
return Response.new_error(400, 'Domain already banned', 'json')
|
||||
if conn.get_domain_ban(data['domain']) is not None:
|
||||
raise HttpError(400, 'Domain already banned')
|
||||
|
||||
ban = conn.put_domain_ban(**data)
|
||||
ban = conn.put_domain_ban(
|
||||
domain = data['domain'],
|
||||
reason = data.get('reason'),
|
||||
note = data.get('note')
|
||||
)
|
||||
|
||||
return Response.new(ban, ctype = 'json')
|
||||
|
||||
|
@ -334,18 +452,19 @@ class DomainBan(View):
|
|||
with self.database.session() as conn:
|
||||
data = await self.get_api_data(['domain'], ['note', 'reason'])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
if not any([data.get('note'), data.get('reason')]):
|
||||
raise HttpError(400, 'Must include note and/or reason parameters')
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if not conn.get_domain_ban(data['domain']):
|
||||
return Response.new_error(404, 'Domain not banned', 'json')
|
||||
if conn.get_domain_ban(data['domain']) is None:
|
||||
raise HttpError(404, 'Domain not banned')
|
||||
|
||||
if not any([data.get('note'), data.get('reason')]):
|
||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||
|
||||
ban = conn.update_domain_ban(**data)
|
||||
ban = conn.update_domain_ban(
|
||||
domain = data['domain'],
|
||||
reason = data.get('reason'),
|
||||
note = data.get('note')
|
||||
)
|
||||
|
||||
return Response.new(ban, ctype = 'json')
|
||||
|
||||
|
@ -353,14 +472,10 @@ class DomainBan(View):
|
|||
async def delete(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
data = await self.get_api_data(['domain'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if not conn.get_domain_ban(data['domain']):
|
||||
return Response.new_error(404, 'Domain not banned', 'json')
|
||||
if conn.get_domain_ban(data['domain']) is None:
|
||||
raise HttpError(404, 'Domain not banned')
|
||||
|
||||
conn.del_domain_ban(data['domain'])
|
||||
|
||||
|
@ -371,7 +486,7 @@ class DomainBan(View):
|
|||
class SoftwareBan(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
bans = tuple(conn.execute('SELECT * FROM software_bans').all())
|
||||
bans = tuple(conn.get_software_bans())
|
||||
|
||||
return Response.new(bans, ctype = 'json')
|
||||
|
||||
|
@ -379,14 +494,15 @@ class SoftwareBan(View):
|
|||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['name'], ['note', 'reason'])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_software_ban(data['name']):
|
||||
return Response.new_error(400, 'Domain already banned', 'json')
|
||||
if conn.get_software_ban(data['name']) is not None:
|
||||
raise HttpError(400, 'Domain already banned')
|
||||
|
||||
ban = conn.put_software_ban(**data)
|
||||
ban = conn.put_software_ban(
|
||||
name = data['name'],
|
||||
reason = data.get('reason'),
|
||||
note = data.get('note')
|
||||
)
|
||||
|
||||
return Response.new(ban, ctype = 'json')
|
||||
|
||||
|
@ -394,17 +510,18 @@ class SoftwareBan(View):
|
|||
async def patch(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['name'], ['note', 'reason'])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
if not any([data.get('note'), data.get('reason')]):
|
||||
raise HttpError(400, 'Must include note and/or reason parameters')
|
||||
|
||||
with self.database.session() as conn:
|
||||
if not conn.get_software_ban(data['name']):
|
||||
return Response.new_error(404, 'Software not banned', 'json')
|
||||
if conn.get_software_ban(data['name']) is None:
|
||||
raise HttpError(404, 'Software not banned')
|
||||
|
||||
if not any([data.get('note'), data.get('reason')]):
|
||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||
|
||||
ban = conn.update_software_ban(**data)
|
||||
ban = conn.update_software_ban(
|
||||
name = data['name'],
|
||||
reason = data.get('reason'),
|
||||
note = data.get('note')
|
||||
)
|
||||
|
||||
return Response.new(ban, ctype = 'json')
|
||||
|
||||
|
@ -412,12 +529,9 @@ class SoftwareBan(View):
|
|||
async def delete(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['name'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
with self.database.session() as conn:
|
||||
if not conn.get_software_ban(data['name']):
|
||||
return Response.new_error(404, 'Software not banned', 'json')
|
||||
if conn.get_software_ban(data['name']) is None:
|
||||
raise HttpError(404, 'Software not banned')
|
||||
|
||||
conn.del_software_ban(data['name'])
|
||||
|
||||
|
@ -430,7 +544,7 @@ class User(View):
|
|||
with self.database.session() as conn:
|
||||
items = []
|
||||
|
||||
for row in conn.execute('SELECT * FROM users'):
|
||||
for row in conn.get_users():
|
||||
del row['hash']
|
||||
items.append(row)
|
||||
|
||||
|
@ -440,41 +554,40 @@ class User(View):
|
|||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['username', 'password'], ['handle'])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_user(data['username']):
|
||||
return Response.new_error(404, 'User already exists', 'json')
|
||||
if conn.get_user(data['username']) is not None:
|
||||
raise HttpError(404, 'User already exists')
|
||||
|
||||
user = conn.put_user(**data)
|
||||
del user['hash']
|
||||
user = conn.put_user(
|
||||
username = data['username'],
|
||||
password = data['password'],
|
||||
handle = data.get('handle')
|
||||
)
|
||||
|
||||
del user['hash']
|
||||
return Response.new(user, ctype = 'json')
|
||||
|
||||
|
||||
async def patch(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['username'], ['password', 'handle'])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
user = conn.put_user(**data)
|
||||
del user['hash']
|
||||
user = conn.put_user(
|
||||
username = data['username'],
|
||||
password = data['password'],
|
||||
handle = data.get('handle')
|
||||
)
|
||||
|
||||
del user['hash']
|
||||
return Response.new(user, ctype = 'json')
|
||||
|
||||
|
||||
async def delete(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['username'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
if not conn.get_user(data['username']):
|
||||
return Response.new_error(404, 'User does not exist', 'json')
|
||||
if conn.get_user(data['username']) is None:
|
||||
raise HttpError(404, 'User does not exist')
|
||||
|
||||
conn.del_user(data['username'])
|
||||
|
||||
|
@ -485,7 +598,7 @@ class User(View):
|
|||
class Whitelist(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
items = tuple(conn.execute('SELECT * FROM whitelist').all())
|
||||
items = tuple(conn.get_domains_whitelist())
|
||||
|
||||
return Response.new(items, ctype = 'json')
|
||||
|
||||
|
@ -493,16 +606,13 @@ class Whitelist(View):
|
|||
async def post(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['domain'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
domain = data['domain'].encode('idna').decode()
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_domain_whitelist(data['domain']):
|
||||
return Response.new_error(400, 'Domain already added to whitelist', 'json')
|
||||
if conn.get_domain_whitelist(domain) is not None:
|
||||
raise HttpError(400, 'Domain already added to whitelist')
|
||||
|
||||
item = conn.put_domain_whitelist(**data)
|
||||
item = conn.put_domain_whitelist(domain)
|
||||
|
||||
return Response.new(item, ctype = 'json')
|
||||
|
||||
|
@ -510,15 +620,12 @@ class Whitelist(View):
|
|||
async def delete(self, request: Request) -> Response:
|
||||
data = await self.get_api_data(['domain'], [])
|
||||
|
||||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
domain = data['domain'].encode('idna').decode()
|
||||
|
||||
with self.database.session() as conn:
|
||||
if not conn.get_domain_whitelist(data['domain']):
|
||||
return Response.new_error(404, 'Domain not in whitelist', 'json')
|
||||
if conn.get_domain_whitelist(domain) is None:
|
||||
raise HttpError(404, 'Domain not in whitelist')
|
||||
|
||||
conn.del_domain_whitelist(data['domain'])
|
||||
conn.del_domain_whitelist(domain)
|
||||
|
||||
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from Crypto.Random import get_random_bytes
|
||||
from aiohttp.abc import AbstractView
|
||||
from aiohttp.hdrs import METH_ALL as METHODS
|
||||
from aiohttp.web import HTTPMethodNotAllowed, Request
|
||||
from base64 import b64encode
|
||||
from aiohttp.web import Request
|
||||
from blib import HttpError
|
||||
from bsql import Database
|
||||
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
|
||||
from functools import cached_property
|
||||
|
@ -18,18 +17,12 @@ from ..http_client import HttpClient
|
|||
from ..misc import Response, get_app
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Self
|
||||
from ..application import Application
|
||||
from ..template import Template
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
HandlerCallback = Callable[[Request], Awaitable[Response]]
|
||||
|
||||
|
||||
VIEWS: list[tuple[str, type[View]]] = []
|
||||
|
||||
|
||||
|
@ -49,10 +42,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
|
|||
class View(AbstractView):
|
||||
def __await__(self) -> Generator[Any, None, Response]:
|
||||
if self.request.method not in METHODS:
|
||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
||||
raise HttpError(405, f'"{self.request.method}" method not allowed')
|
||||
|
||||
if not (handler := self.handlers.get(self.request.method)):
|
||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
||||
raise HttpError(405, f'"{self.request.method}" method not allowed')
|
||||
|
||||
return self._run_handler(handler).__await__()
|
||||
|
||||
|
@ -64,7 +57,6 @@ class View(AbstractView):
|
|||
|
||||
|
||||
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
|
||||
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
|
||||
return await handler(self.request, **self.request.match_info, **kwargs)
|
||||
|
||||
|
||||
|
@ -123,17 +115,18 @@ class View(AbstractView):
|
|||
|
||||
async def get_api_data(self,
|
||||
required: list[str],
|
||||
optional: list[str]) -> dict[str, str] | Response:
|
||||
optional: list[str]) -> dict[str, str]:
|
||||
|
||||
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
|
||||
if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
|
||||
post_data = convert_data(await self.request.post())
|
||||
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
|
||||
|
||||
elif self.request.content_type == 'application/json':
|
||||
try:
|
||||
post_data = convert_data(await self.request.json())
|
||||
|
||||
except JSONDecodeError:
|
||||
return Response.new_error(400, 'Invalid JSON data', 'json')
|
||||
raise HttpError(400, 'Invalid JSON data')
|
||||
|
||||
else:
|
||||
post_data = convert_data(self.request.query)
|
||||
|
@ -145,9 +138,9 @@ class View(AbstractView):
|
|||
data[key] = post_data[key]
|
||||
|
||||
except KeyError as e:
|
||||
return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
|
||||
raise HttpError(400, f'Missing {str(e)} pararmeter') from None
|
||||
|
||||
for key in optional:
|
||||
data[key] = post_data.get(key, '')
|
||||
data[key] = post_data.get(key) # type: ignore[assignment]
|
||||
|
||||
return data
|
||||
|
|
|
@ -1,18 +1,13 @@
|
|||
from aiohttp import web
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
from .base import View, register_route
|
||||
|
||||
from ..database import THEMES
|
||||
from ..logger import LogLevel
|
||||
from ..misc import Response, get_app
|
||||
|
||||
|
||||
UNAUTH_ROUTES = {
|
||||
'/',
|
||||
'/login'
|
||||
}
|
||||
from ..misc import TOKEN_PATHS, Response
|
||||
|
||||
|
||||
@web.middleware
|
||||
|
@ -20,28 +15,25 @@ async def handle_frontend_path(
|
|||
request: web.Request,
|
||||
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
|
||||
|
||||
app = get_app()
|
||||
if request['user'] is not None and request.path == '/login':
|
||||
return Response.new_redir('/')
|
||||
|
||||
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
|
||||
request['token'] = request.cookies.get('user-token')
|
||||
request['user'] = None
|
||||
if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None:
|
||||
if request.path == '/logout':
|
||||
return Response.new_redir('/')
|
||||
|
||||
if request['token']:
|
||||
with app.database.session(False) as conn:
|
||||
request['user'] = conn.get_user_by_token(request['token'])
|
||||
response = Response.new_redir(f'/login?redir={request.path}')
|
||||
|
||||
if request['user'] and request.path == '/login':
|
||||
return Response.new('', 302, {'Location': '/'})
|
||||
|
||||
if not request['user'] and request.path.startswith('/admin'):
|
||||
response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
|
||||
if request['token'] is not None:
|
||||
response.del_cookie('user-token')
|
||||
return response
|
||||
|
||||
return response
|
||||
|
||||
response = await handler(request)
|
||||
|
||||
if not request.path.startswith('/api') and not request['user'] and request['token']:
|
||||
response.del_cookie('user-token')
|
||||
if not request.path.startswith('/api'):
|
||||
if request['user'] is None and request['token'] is not None:
|
||||
response.del_cookie('user-token')
|
||||
|
||||
return response
|
||||
|
||||
|
@ -54,14 +46,15 @@ class HomeView(View):
|
|||
'instances': tuple(conn.get_inboxes())
|
||||
}
|
||||
|
||||
data = self.template.render('page/home.haml', self, **context)
|
||||
data = self.template.render('page/home.haml', self.request, **context)
|
||||
return Response.new(data, ctype='html')
|
||||
|
||||
|
||||
@register_route('/login')
|
||||
class Login(View):
|
||||
async def get(self, request: web.Request) -> Response:
|
||||
data = self.template.render('page/login.haml', self)
|
||||
redir = unquote(request.query.get('redir', '/'))
|
||||
data = self.template.render('page/login.haml', self.request, redir = redir)
|
||||
return Response.new(data, ctype = 'html')
|
||||
|
||||
|
||||
|
@ -69,7 +62,7 @@ class Login(View):
|
|||
class Logout(View):
|
||||
async def get(self, request: web.Request) -> Response:
|
||||
with self.database.session(True) as conn:
|
||||
conn.del_token(request['token'])
|
||||
conn.del_app(request['token'].client_id, request['token'].client_secret)
|
||||
|
||||
resp = Response.new_redir('/')
|
||||
resp.del_cookie('user-token', domain = self.config.domain, path = '/')
|
||||
|
@ -79,7 +72,7 @@ class Logout(View):
|
|||
@register_route('/admin')
|
||||
class Admin(View):
|
||||
async def get(self, request: web.Request) -> Response:
|
||||
return Response.new('', 302, {'Location': '/admin/instances'})
|
||||
return Response.new_redir(f'/login?redir={request.path}', 301)
|
||||
|
||||
|
||||
@register_route('/admin/instances')
|
||||
|
@ -101,7 +94,7 @@ class AdminInstances(View):
|
|||
if message:
|
||||
context['message'] = message
|
||||
|
||||
data = self.template.render('page/admin-instances.haml', self, **context)
|
||||
data = self.template.render('page/admin-instances.haml', self.request, **context)
|
||||
return Response.new(data, ctype = 'html')
|
||||
|
||||
|
||||
|
@ -123,7 +116,7 @@ class AdminWhitelist(View):
|
|||
if message:
|
||||
context['message'] = message
|
||||
|
||||
data = self.template.render('page/admin-whitelist.haml', self, **context)
|
||||
data = self.template.render('page/admin-whitelist.haml', self.request, **context)
|
||||
return Response.new(data, ctype = 'html')
|
||||
|
||||
|
||||
|
@ -145,7 +138,7 @@ class AdminDomainBans(View):
|
|||
if message:
|
||||
context['message'] = message
|
||||
|
||||
data = self.template.render('page/admin-domain_bans.haml', self, **context)
|
||||
data = self.template.render('page/admin-domain_bans.haml', self.request, **context)
|
||||
return Response.new(data, ctype = 'html')
|
||||
|
||||
|
||||
|
@ -167,7 +160,7 @@ class AdminSoftwareBans(View):
|
|||
if message:
|
||||
context['message'] = message
|
||||
|
||||
data = self.template.render('page/admin-software_bans.haml', self, **context)
|
||||
data = self.template.render('page/admin-software_bans.haml', self.request, **context)
|
||||
return Response.new(data, ctype = 'html')
|
||||
|
||||
|
||||
|
@ -189,7 +182,7 @@ class AdminUsers(View):
|
|||
if message:
|
||||
context['message'] = message
|
||||
|
||||
data = self.template.render('page/admin-users.haml', self, **context)
|
||||
data = self.template.render('page/admin-users.haml', self.request, **context)
|
||||
return Response.new(data, ctype = 'html')
|
||||
|
||||
|
||||
|
@ -199,10 +192,21 @@ class AdminConfig(View):
|
|||
context: dict[str, Any] = {
|
||||
'themes': tuple(THEMES.keys()),
|
||||
'levels': tuple(level.name for level in LogLevel),
|
||||
'message': message
|
||||
'message': message,
|
||||
'desc': {
|
||||
"name": "Name of the relay to be displayed in the header of the pages and in " +
|
||||
"the actor endpoint.", # noqa: E131
|
||||
"note": "Description of the relay to be displayed on the front page and as the " +
|
||||
"bio in the actor endpoint.",
|
||||
"theme": "Color theme to use on the web pages.",
|
||||
"log_level": "Minimum level of logging messages to print to the console.",
|
||||
"whitelist_enabled": "Only allow instances in the whitelist to be able to follow.",
|
||||
"approval_required": "Require instances not on the whitelist to be approved by " +
|
||||
"and admin. The `whitelist-enabled` setting is ignored when this is enabled."
|
||||
}
|
||||
}
|
||||
|
||||
data = self.template.render('page/admin-config.haml', self, **context)
|
||||
data = self.template.render('page/admin-config.haml', self.request, **context)
|
||||
return Response.new(data, ctype = 'html')
|
||||
|
||||
|
||||
|
@ -240,5 +244,5 @@ class ThemeCss(View):
|
|||
except KeyError:
|
||||
return Response.new('Invalid theme', 404)
|
||||
|
||||
data = self.template.render('variables.css', self, **context)
|
||||
data = self.template.render('variables.css', self.request, **context)
|
||||
return Response.new(data, ctype = 'css')
|
||||
|
|
143
relay/workers.py
Normal file
143
relay/workers.py
Normal file
|
@ -0,0 +1,143 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||
from blib import HttpError
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Event, Process, Queue, Value
|
||||
from multiprocessing.queues import Queue as QueueType
|
||||
from multiprocessing.sharedctypes import Synchronized
|
||||
from multiprocessing.synchronize import Event as EventType
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import application, logger as logging
|
||||
from .database.schema import Instance
|
||||
from .http_client import HttpClient
|
||||
from .misc import IS_WINDOWS, Message, get_app
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostItem:
|
||||
inbox: str
|
||||
message: Message
|
||||
instance: Instance | None
|
||||
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
return urlparse(self.inbox).netloc
|
||||
|
||||
|
||||
class PushWorker(Process):
|
||||
client: HttpClient
|
||||
|
||||
|
||||
def __init__(self, queue: QueueType[PostItem], log_level: Synchronized[int]) -> None:
|
||||
Process.__init__(self)
|
||||
|
||||
self.queue: QueueType[PostItem] = queue
|
||||
self.shutdown: EventType = Event()
|
||||
self.path: Path = get_app().config.path
|
||||
self.log_level: Synchronized[int] = log_level
|
||||
self._log_level_changed: EventType = Event()
|
||||
|
||||
|
||||
def stop(self) -> None:
|
||||
self.shutdown.set()
|
||||
|
||||
|
||||
def run(self) -> None:
|
||||
asyncio.run(self.handle_queue())
|
||||
|
||||
|
||||
async def handle_queue(self) -> None:
|
||||
if IS_WINDOWS:
|
||||
app = application.Application(self.path)
|
||||
self.client = app.client
|
||||
|
||||
self.client.open()
|
||||
app.database.connect()
|
||||
app.cache.setup()
|
||||
|
||||
else:
|
||||
self.client = HttpClient()
|
||||
self.client.open()
|
||||
|
||||
logging.verbose("[%i] Starting worker", self.pid)
|
||||
|
||||
while not self.shutdown.is_set():
|
||||
try:
|
||||
if self._log_level_changed.is_set():
|
||||
logging.set_level(logging.LogLevel.parse(self.log_level.value))
|
||||
self._log_level_changed.clear()
|
||||
|
||||
item = self.queue.get(block=True, timeout=0.1)
|
||||
asyncio.create_task(self.handle_post(item))
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if IS_WINDOWS:
|
||||
app.database.disconnect()
|
||||
app.cache.close()
|
||||
|
||||
await self.client.close()
|
||||
|
||||
|
||||
async def handle_post(self, item: PostItem) -> None:
|
||||
try:
|
||||
await self.client.post(item.inbox, item.message, item.instance)
|
||||
|
||||
except HttpError as e:
|
||||
logging.error('HTTP Error when pushing to %s: %i %s', item.inbox, e.status, e.message)
|
||||
|
||||
except AsyncTimeoutError:
|
||||
logging.error('Timeout when pushing to %s', item.domain)
|
||||
|
||||
except ClientConnectionError as e:
|
||||
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
|
||||
|
||||
except ClientSSLError as e:
|
||||
logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
|
||||
|
||||
|
||||
class PushWorkers(list[PushWorker]):
|
||||
def __init__(self, count: int) -> None:
|
||||
self.queue: QueueType[PostItem] = Queue()
|
||||
self._log_level: Synchronized[int] = Value("i", logging.get_level())
|
||||
self._count: int = count
|
||||
|
||||
|
||||
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
|
||||
self.queue.put(PostItem(inbox, message, instance))
|
||||
|
||||
|
||||
def set_log_level(self, value: logging.LogLevel) -> None:
|
||||
self._log_level.value = value
|
||||
|
||||
for worker in self:
|
||||
worker._log_level_changed.set()
|
||||
|
||||
|
||||
def start(self) -> None:
|
||||
if len(self) > 0:
|
||||
return
|
||||
|
||||
for _ in range(self._count):
|
||||
worker = PushWorker(self.queue, self._log_level)
|
||||
worker.start()
|
||||
self.append(worker)
|
||||
|
||||
|
||||
def stop(self) -> None:
|
||||
for worker in self:
|
||||
worker.stop()
|
||||
|
||||
self.clear()
|
Loading…
Reference in a new issue