\n';
- text += `
\n`;
- text += `
\n`;
- text += `
\n`;
- text += `
\n`;
- text += `
`;
- text += '';
-
- return text;
-}
-
-
-function add_row_listeners(row) {
- row.querySelector(".update-ban").addEventListener("click", async (event) => {
- await update_ban(row.id);
- });
-
- row.querySelector(".remove a").addEventListener("click", async (event) => {
- event.preventDefault();
- await unban(row.id);
- });
-}
-
-
-async function ban() {
- var elems = {
- name: document.getElementById("new-name"),
- reason: document.getElementById("new-reason"),
- note: document.getElementById("new-note")
- }
-
- var values = {
- name: elems.name.value.trim(),
- reason: elems.reason.value,
- note: elems.note.value
- }
-
- if (values.name === "") {
- toast("Domain is required");
- return;
- }
-
- try {
- var ban = await request("POST", "v1/software_ban", values);
-
- } catch (err) {
- toast(err);
- return
- }
-
- var row = append_table_row(document.getElementById("bans"), ban.name, {
- name: create_ban_object(ban.name, ban.reason, ban.note),
- date: get_date_string(ban.created),
- remove: `
✖`
- });
-
- add_row_listeners(row);
-
- elems.name.value = null;
- elems.reason.value = null;
- elems.note.value = null;
-
- document.querySelector("details.section").open = false;
- toast("Banned software", "message");
-}
-
-
-async function update_ban(name) {
- var row = document.getElementById(name);
-
- var elems = {
- "reason": row.querySelector("textarea.reason"),
- "note": row.querySelector("textarea.note")
- }
-
- var values = {
- "name": name,
- "reason": elems.reason.value,
- "note": elems.note.value
- }
-
- try {
- await request("PATCH", "v1/software_ban", values)
-
- } catch (error) {
- toast(error);
- return;
- }
-
- row.querySelector("details").open = false;
- toast("Updated software ban", "message");
-}
-
-
-async function unban(name) {
- try {
- await request("DELETE", "v1/software_ban", {"name": name});
-
- } catch (error) {
- toast(error);
- return;
- }
-
- document.getElementById(name).remove();
- toast("Unbanned software", "message");
-}
-
-
-document.querySelector("#new-ban").addEventListener("click", async (event) => {
- await ban();
-});
-
-for (var row of document.querySelector("#bans").rows) {
- if (!row.querySelector(".update-ban")) {
- continue;
- }
-
- add_row_listeners(row);
-}
diff --git a/relay/frontend/static/style.css b/relay/frontend/static/style.css
index 6ec9316..ac4eaf5 100644
--- a/relay/frontend/static/style.css
+++ b/relay/frontend/static/style.css
@@ -12,7 +12,7 @@ body {
color: var(--text);
background-color: #222;
margin: var(--spacing);
- font-family: sans serif;
+ font-family: sans-serif;
}
details *:nth-child(2) {
@@ -88,6 +88,7 @@ tbody tr:last-child td:last-child {
table td {
padding: 5px;
+ white-space: nowrap;
}
table thead td {
@@ -282,8 +283,11 @@ textarea {
width: 100%;
}
-.data-table .date {
+.data-table td:not(:first-child) {
width: max-content;
+}
+
+.data-table .date {
text-align: right;
}
@@ -297,13 +301,13 @@ textarea {
border: 1px solid var(--error-border) !important;
}
+/* create .grid base class and .2col and 3col classes */
.grid-2col {
display: grid;
grid-template-columns: max-content auto;
grid-gap: var(--spacing);
margin-bottom: var(--spacing);
align-items: center;
-
}
.message {
@@ -333,6 +337,48 @@ textarea {
justify-self: left;
}
+#content.page-config .grid-2col {
+ grid-template-columns: max-content max-content auto;
+}
+
+
+/* error */
+#content.page-error {
+ text-align: center;
+}
+
+#content.page-error .title {
+ font-size: 24px;
+ font-weight: bold;
+}
+
+
+/* auth */
+#content.page-app_authorization {
+ text-align: center;
+}
+
+#content.page-app_authorization #code {
+ background: var(--background);
+ border: 1px solid var(--border);
+ font-size: 18px;
+ margin: 0 auto;
+ width: max-content;
+ padding: 5px;
+}
+
+#content.page-app_authorization #title {
+ font-size: 24px;
+}
+
+#content.page-app_authorization #buttons {
+ display: grid;
+ grid-template-columns: auto max-content max-content auto;
+ grid-gap: var(--spacing);
+ justify-items: center;
+ margin: var(--spacing) 0;
+}
+
@keyframes show_toast {
0% {
diff --git a/relay/frontend/static/user.js b/relay/frontend/static/user.js
deleted file mode 100644
index 9c74359..0000000
--- a/relay/frontend/static/user.js
+++ /dev/null
@@ -1,85 +0,0 @@
-function add_row_listeners(row) {
- row.querySelector(".remove a").addEventListener("click", async (event) => {
- event.preventDefault();
- await del_user(row.id);
- });
-}
-
-
-async function add_user() {
- var elems = {
- username: document.getElementById("new-username"),
- password: document.getElementById("new-password"),
- password2: document.getElementById("new-password2"),
- handle: document.getElementById("new-handle")
- }
-
- var values = {
- username: elems.username.value.trim(),
- password: elems.password.value.trim(),
- password2: elems.password2.value.trim(),
- handle: elems.handle.value.trim()
- }
-
- if (values.username === "" | values.password === "" | values.password2 === "") {
- toast("Username, password, and password2 are required");
- return;
- }
-
- if (values.password !== values.password2) {
- toast("Passwords do not match");
- return;
- }
-
- try {
- var user = await request("POST", "v1/user", values);
-
- } catch (err) {
- toast(err);
- return
- }
-
- var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
- domain: user.username,
- handle: user.handle ? self.handle : "n/a",
- date: get_date_string(user.created),
- remove: `
✖`
- });
-
- add_row_listeners(row);
-
- elems.username.value = null;
- elems.password.value = null;
- elems.password2.value = null;
- elems.handle.value = null;
-
- document.querySelector("details.section").open = false;
- toast("Created user", "message");
-}
-
-
-async function del_user(username) {
- try {
- await request("DELETE", "v1/user", {"username": username});
-
- } catch (error) {
- toast(error);
- return;
- }
-
- document.getElementById(username).remove();
- toast("Deleted user", "message");
-}
-
-
-document.querySelector("#new-user").addEventListener("click", async (event) => {
- await add_user();
-});
-
-for (var row of document.querySelector("#users").rows) {
- if (!row.querySelector(".remove a")) {
- continue;
- }
-
- add_row_listeners(row);
-}
diff --git a/relay/frontend/static/whitelist.js b/relay/frontend/static/whitelist.js
deleted file mode 100644
index 70d4db1..0000000
--- a/relay/frontend/static/whitelist.js
+++ /dev/null
@@ -1,64 +0,0 @@
-function add_row_listeners(row) {
- row.querySelector(".remove a").addEventListener("click", async (event) => {
- event.preventDefault();
- await del_whitelist(row.id);
- });
-}
-
-
-async function add_whitelist() {
- var domain_elem = document.getElementById("new-domain");
- var domain = domain_elem.value.trim();
-
- if (domain === "") {
- toast("Domain is required");
- return;
- }
-
- try {
- var item = await request("POST", "v1/whitelist", {"domain": domain});
-
- } catch (err) {
- toast(err);
- return;
- }
-
- var row = append_table_row(document.getElementById("whitelist"), item.domain, {
- domain: item.domain,
- date: get_date_string(item.created),
- remove: `
✖`
- });
-
- add_row_listeners(row);
-
- domain_elem.value = null;
- document.querySelector("details.section").open = false;
- toast("Added domain", "message");
-}
-
-
-async function del_whitelist(domain) {
- try {
- await request("DELETE", "v1/whitelist", {"domain": domain});
-
- } catch (error) {
- toast(error);
- return;
- }
-
- document.getElementById(domain).remove();
- toast("Removed domain", "message");
-}
-
-
-document.querySelector("#new-item").addEventListener("click", async (event) => {
- await add_whitelist();
-});
-
-for (var row of document.querySelector("fieldset.section table").rows) {
- if (!row.querySelector(".remove a")) {
- continue;
- }
-
- add_row_listeners(row);
-}
diff --git a/relay/http_client.py b/relay/http_client.py
index 54cea3c..ef25881 100644
--- a/relay/http_client.py
+++ b/relay/http_client.py
@@ -1,20 +1,16 @@
from __future__ import annotations
import json
-import traceback
from aiohttp import ClientSession, ClientTimeout, TCPConnector
-from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
-from asyncio.exceptions import TimeoutError as AsyncTimeoutError
-from blib import JsonBase
-from bsql import Row
-from json.decoder import JSONDecodeError
-from typing import TYPE_CHECKING, Any, TypeVar
-from urllib.parse import urlparse
+from blib import HttpError, JsonBase
+from typing import TYPE_CHECKING, Any, TypeVar, overload
from . import __version__, logger as logging
from .cache import Cache
+from .database.schema import Instance
+from .errors import EmptyBodyError
from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING:
@@ -36,7 +32,7 @@ SUPPORTS_HS2019 = {
'sharkey'
}
-T = TypeVar('T', bound = JsonBase)
+T = TypeVar('T', bound = JsonBase[Any])
HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}'
@@ -107,21 +103,17 @@ class HttpClient:
url: str,
sign_headers: bool,
force: bool,
- old_algo: bool) -> dict[str, Any] | None:
+ old_algo: bool) -> str | None:
if not self._session:
raise RuntimeError('Client not open')
- try:
- url, _ = url.split('#', 1)
-
- except ValueError:
- pass
+ url = url.split("#", 1)[0]
if not force:
try:
if not (item := self.cache.get('request', url)).older_than(48):
- return json.loads(item.value) # type: ignore[no-any-return]
+ return item.value # type: ignore [no-any-return]
except KeyError:
logging.verbose('No cached data for url: %s', url)
@@ -132,67 +124,77 @@ class HttpClient:
algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
headers = self.signer.sign_headers('GET', url, algorithm = algo)
- try:
- logging.debug('Fetching resource: %s', url)
+ logging.debug('Fetching resource: %s', url)
- async with self._session.get(url, headers = headers) as resp:
- # Not expecting a response with 202s, so just return
- if resp.status == 202:
- return None
-
- data = await resp.text()
-
- if resp.status != 200:
- logging.verbose('Received error when requesting %s: %i', url, resp.status)
- logging.debug(data)
+ async with self._session.get(url, headers = headers) as resp:
+ # Not expecting a response with 202s, so just return
+ if resp.status == 202:
return None
- self.cache.set('request', url, data, 'str')
- logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
+ data = await resp.text()
- return json.loads(data) # type: ignore [no-any-return]
-
- except JSONDecodeError:
- logging.verbose('Failed to parse JSON')
+ if resp.status not in (200, 202):
+ logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(data)
- return None
- except ClientSSLError as e:
- logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
- logging.warning(str(e))
+ try:
+ error = json.loads(data)["error"]
- except (AsyncTimeoutError, ClientConnectionError) as e:
- logging.verbose('Failed to connect to %s', urlparse(url).netloc)
- logging.warning(str(e))
+ except Exception:
+ error = data
- except Exception:
- traceback.print_exc()
+ raise HttpError(resp.status, error)
- return None
+ self.cache.set('request', url, data, 'str')
+ return data
+
+
+ @overload
+ async def get(self,
+ url: str,
+ sign_headers: bool,
+ cls: None = None,
+ force: bool = False,
+ old_algo: bool = True) -> str | None: ...
+
+
+ @overload
+ async def get(self,
+ url: str,
+ sign_headers: bool,
+ cls: type[T] = JsonBase, # type: ignore[assignment]
+ force: bool = False,
+ old_algo: bool = True) -> T: ...
async def get(self,
url: str,
sign_headers: bool,
- cls: type[T],
+ cls: type[T] | None = None,
force: bool = False,
- old_algo: bool = True) -> T | None:
+ old_algo: bool = True) -> T | str | None:
- if not issubclass(cls, JsonBase):
+ if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"')
- if (data := (await self._get(url, sign_headers, force, old_algo))) is None:
- return None
+ data = await self._get(url, sign_headers, force, old_algo)
- return cls.parse(data)
+ if cls is not None:
+ if data is None:
+ # this shouldn't actually get raised, but keeping just in case
+ raise EmptyBodyError(f"GET {url}")
+
+ return cls.parse(data)
+
+ return data
- async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
+ async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
if not self._session:
raise RuntimeError('Client not open')
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
- if instance and instance['software'] in SUPPORTS_HS2019:
+ if instance is not None and instance.software in SUPPORTS_HS2019:
algorithm = AlgorithmType.HS2019
else:
@@ -218,46 +220,27 @@ class HttpClient:
algorithm = algorithm
)
- try:
- logging.verbose('Sending "%s" to %s', mtype, url)
+ logging.verbose('Sending "%s" to %s', mtype, url)
- async with self._session.post(url, headers = headers, data = body) as resp:
- # Not expecting a response, so just return
- if resp.status in {200, 202}:
- logging.verbose('Successfully sent "%s" to %s', mtype, url)
- return
-
- logging.verbose('Received error when pushing to %s: %i', url, resp.status)
- logging.debug(await resp.read())
- logging.debug("message: %s", body.decode("utf-8"))
- logging.debug("headers: %s", json.dumps(headers, indent = 4))
+ async with self._session.post(url, headers = headers, data = body) as resp:
+ # Not expecting a response, so just return
+ if resp.status in {200, 202}:
+ logging.verbose('Successfully sent "%s" to %s', mtype, url)
return
- except ClientSSLError as e:
- logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
- logging.warning(str(e))
-
- except (AsyncTimeoutError, ClientConnectionError) as e:
- logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
- logging.warning(str(e))
-
- # prevent workers from being brought down
- except Exception:
- traceback.print_exc()
+ logging.error('Received error when pushing to %s: %i', url, resp.status)
+ logging.debug(await resp.read())
+ logging.debug("message: %s", body.decode("utf-8"))
+ logging.debug("headers: %s", json.dumps(headers, indent = 4))
+ return
- async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None:
+ async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:
nodeinfo_url = None
wk_nodeinfo = await self.get(
- f'https://{domain}/.well-known/nodeinfo',
- False,
- WellKnownNodeinfo
+ f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force
)
- if wk_nodeinfo is None:
- logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
- return None
-
for version in ('20', '21'):
try:
nodeinfo_url = wk_nodeinfo.get_url(version)
@@ -266,10 +249,9 @@ class HttpClient:
pass
if nodeinfo_url is None:
- logging.verbose('Failed to fetch nodeinfo url for %s', domain)
- return None
+ raise ValueError(f'Failed to fetch nodeinfo url for {domain}')
- return await self.get(nodeinfo_url, False, Nodeinfo)
+ return await self.get(nodeinfo_url, False, Nodeinfo, force)
async def get(*args: Any, **kwargs: Any) -> Any:
diff --git a/relay/logger.py b/relay/logger.py
index f1a1bd7..7caac9f 100644
--- a/relay/logger.py
+++ b/relay/logger.py
@@ -1,16 +1,15 @@
+from __future__ import annotations
+
import logging
import os
from enum import IntEnum
from pathlib import Path
-from typing import Any, Protocol
+from typing import TYPE_CHECKING, Any, Protocol
-try:
+if TYPE_CHECKING:
from typing import Self
-except ImportError:
- from typing_extensions import Self
-
class LoggingMethod(Protocol):
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
diff --git a/relay/manage.py b/relay/manage.py
index cb2b099..b76443d 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -6,7 +6,6 @@ import click
import json
import os
-from bsql import Row
from pathlib import Path
from shutil import copyfile
from typing import Any
@@ -17,7 +16,8 @@ from . import http_client as http
from . import logger as logging
from .application import Application
from .compat import RelayConfig, RelayDatabase
-from .database import RELAY_SOFTWARE, get_database
+from .config import Config
+from .database import RELAY_SOFTWARE, get_database, schema
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
@@ -213,6 +213,24 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
os._exit(0)
+@cli.command('db-maintenance')
+@click.option('--fix-timestamps', '-t', is_flag = True,
+ help = 'Make sure timestamps in the database are float values')
+@click.pass_context
+def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
+ 'Perform maintenance tasks on the database'
+
+ if fix_timestamps:
+ with ctx.obj.database.session(True) as conn:
+ conn.fix_timestamps()
+
+ if ctx.obj.config.db_type == "postgres":
+ return
+
+ with ctx.obj.database.session(False) as conn:
+ with conn.execute("VACUUM"):
+ pass
+
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@@ -240,18 +258,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save()
+ # fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7
with get_database(ctx.obj.config) as db:
with db.session(True) as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
- with click.progressbar( # type: ignore
+ with click.progressbar(
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
-
for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay'
@@ -270,7 +288,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software']
)
- with click.progressbar( # type: ignore
+ with click.progressbar(
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
@@ -282,7 +300,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None
)
- with click.progressbar( # type: ignore
+ with click.progressbar(
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
@@ -291,7 +309,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software:
conn.put_domain_ban(domain)
- with click.progressbar( # type: ignore
+ with click.progressbar(
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
@@ -315,6 +333,33 @@ def cli_editconfig(ctx: click.Context, editor: str) -> None:
)
+@cli.command('switch-backend')
+@click.pass_context
+def cli_switchbackend(ctx: click.Context) -> None:
+ """
+ Copy the database from one backend to the other
+
+ Be sure to set the database type to the backend you want to convert from. For instance, set
+ the database type to `sqlite`, fill out the connection details for postgresql, and the
+ data from the sqlite database will be copied to the postgresql database. This only works if
+ the database in postgresql already exists.
+ """
+
+ config = Config(ctx.obj.config.path, load = True)
+ config.db_type = "sqlite" if config.db_type == "postgres" else "postgres"
+ database = get_database(config, migrate = False)
+
+ with database.session(True) as new, ctx.obj.database.session(False) as old:
+ new.create_tables()
+
+ for table in schema.TABLES.keys():
+ for row in old.execute(f"SELECT * FROM {table}"):
+ new.insert(table, row).close()
+
+ config.save()
+ click.echo(f"Converted database to {repr(config.db_type)}")
+
+
@cli.group('config')
def cli_config() -> None:
'Manage the relay settings stored in the database'
@@ -348,10 +393,15 @@ def cli_config_list(ctx: click.Context) -> None:
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value'
- with ctx.obj.database.session() as conn:
- new_value = conn.put_config(key, value)
+ try:
+ with ctx.obj.database.session() as conn:
+ new_value = conn.put_config(key, value)
- print(f'{key}: {repr(new_value)}')
+ except Exception:
+ click.echo(f'Invalid config name: {key}')
+ return
+
+ click.echo(f'{key}: {repr(new_value)}')
@cli.group('user')
@@ -367,8 +417,8 @@ def cli_user_list(ctx: click.Context) -> None:
click.echo('Users:')
with ctx.obj.database.session() as conn:
- for user in conn.execute('SELECT * FROM users'):
- click.echo(f'- {user["username"]}')
+ for row in conn.get_users():
+ click.echo(f'- {row.username}')
@cli_user.command('create')
@@ -379,7 +429,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
'Create a new local user'
with ctx.obj.database.session() as conn:
- if conn.get_user(username):
+ if conn.get_user(username) is not None:
click.echo(f'User already exists: {username}')
return
@@ -406,7 +456,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
'Delete a local user'
with ctx.obj.database.session() as conn:
- if not conn.get_user(username):
+ if conn.get_user(username) is None:
click.echo(f'User does not exist: {username}')
return
@@ -424,8 +474,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
click.echo(f'Tokens for "{username}":')
with ctx.obj.database.session() as conn:
- for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
- click.echo(f'- {token["code"]}')
+ for row in conn.get_tokens(username):
+ click.echo(f'- {row.code}')
@cli_user.command('create-token')
@@ -435,13 +485,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
'Create a new API token for a user'
with ctx.obj.database.session() as conn:
- if not (user := conn.get_user(username)):
+ if (user := conn.get_user(username)) is None:
click.echo(f'User does not exist: {username}')
return
- token = conn.put_token(user['username'])
+ token = conn.put_token(user.username)
- click.echo(f'New token for "{username}": {token["code"]}')
+ click.echo(f'New token for "{username}": {token.code}')
@cli_user.command('delete-token')
@@ -451,7 +501,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
'Delete an API token'
with ctx.obj.database.session() as conn:
- if not conn.get_token(code):
+ if conn.get_token(code) is None:
click.echo('Token does not exist')
return
@@ -473,8 +523,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:')
with ctx.obj.database.session() as conn:
- for inbox in conn.get_inboxes():
- click.echo(f'- {inbox["inbox"]}')
+ for row in conn.get_inboxes():
+ click.echo(f'- {row.inbox}')
@cli_inbox.command('follow')
@@ -483,19 +533,21 @@ def cli_inbox_list(ctx: click.Context) -> None:
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)'
+ instance: schema.Instance | None = None
+
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
- if (inbox_data := conn.get_inbox(actor)):
- inbox = inbox_data['inbox']
+ if (instance := conn.get_inbox(actor)) is not None:
+ inbox = instance.inbox
else:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
- if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
+ if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None:
click.echo(f'Failed to fetch actor: {actor}')
return
@@ -506,7 +558,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
actor = actor
)
- asyncio.run(http.post(inbox, message, inbox_data))
+ asyncio.run(http.post(inbox, message, instance))
click.echo(f'Sent follow message to actor: {actor}')
@@ -516,19 +568,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)'
- inbox_data: Row | None = None
+ instance: schema.Instance | None = None
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
- if (inbox_data := conn.get_inbox(actor)):
- inbox = inbox_data['inbox']
+ if (instance := conn.get_inbox(actor)):
+ inbox = instance.inbox
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
- follow = inbox_data['followid']
+ follow = instance.followid
)
else:
@@ -552,7 +604,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
}
)
- asyncio.run(http.post(inbox, message, inbox_data))
+ asyncio.run(http.post(inbox, message, instance))
click.echo(f'Sent unfollow message to: {actor}')
@@ -632,9 +684,9 @@ def cli_request_list(ctx: click.Context) -> None:
click.echo('Follow requests:')
with ctx.obj.database.session() as conn:
- for instance in conn.get_requests():
- date = instance['created'].strftime('%Y-%m-%d')
- click.echo(f'- [{date}] {instance["domain"]}')
+ for row in conn.get_requests():
+ date = row.created.strftime('%Y-%m-%d')
+ click.echo(f'- [{date}] {row.domain}')
@cli_request.command('accept')
@@ -653,20 +705,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
message = Message.new_response(
host = ctx.obj.config.domain,
- actor = instance['actor'],
- followid = instance['followid'],
+ actor = instance.actor,
+ followid = instance.followid,
accept = True
)
- asyncio.run(http.post(instance['inbox'], message, instance))
+ asyncio.run(http.post(instance.inbox, message, instance))
- if instance['software'] != 'mastodon':
+ if instance.software != 'mastodon':
message = Message.new_follow(
host = ctx.obj.config.domain,
- actor = instance['actor']
+ actor = instance.actor
)
- asyncio.run(http.post(instance['inbox'], message, instance))
+ asyncio.run(http.post(instance.inbox, message, instance))
@cli_request.command('deny')
@@ -685,12 +737,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
response = Message.new_response(
host = ctx.obj.config.domain,
- actor = instance['actor'],
- followid = instance['followid'],
+ actor = instance.actor,
+ followid = instance.followid,
accept = False
)
- asyncio.run(http.post(instance['inbox'], response, instance))
+ asyncio.run(http.post(instance.inbox, response, instance))
@cli.group('instance')
@@ -706,12 +758,12 @@ def cli_instance_list(ctx: click.Context) -> None:
click.echo('Banned domains:')
with ctx.obj.database.session() as conn:
- for instance in conn.execute('SELECT * FROM domain_bans'):
- if instance['reason']:
- click.echo(f'- {instance["domain"]} ({instance["reason"]})')
+ for row in conn.get_domain_bans():
+ if row.reason is not None:
+ click.echo(f'- {row.domain} ({row.reason})')
else:
- click.echo(f'- {instance["domain"]}')
+ click.echo(f'- {row.domain}')
@cli_instance.command('ban')
@@ -723,7 +775,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.session() as conn:
- if conn.get_domain_ban(domain):
+ if conn.get_domain_ban(domain) is not None:
click.echo(f'Domain already banned: {domain}')
return
@@ -739,7 +791,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance'
with ctx.obj.database.session() as conn:
- if not conn.del_domain_ban(domain):
+ if conn.del_domain_ban(domain) is None:
click.echo(f'Instance wasn\'t banned: {domain}')
return
@@ -764,11 +816,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
click.echo(f'Updated domain ban: {domain}')
- if row['reason']:
- click.echo(f'- {row["domain"]} ({row["reason"]})')
+ if row.reason:
+ click.echo(f'- {row.domain} ({row.reason})')
else:
- click.echo(f'- {row["domain"]}')
+ click.echo(f'- {row.domain}')
@cli.group('software')
@@ -784,12 +836,12 @@ def cli_software_list(ctx: click.Context) -> None:
click.echo('Banned software:')
with ctx.obj.database.session() as conn:
- for software in conn.execute('SELECT * FROM software_bans'):
- if software['reason']:
- click.echo(f'- {software["name"]} ({software["reason"]})')
+ for row in conn.get_software_bans():
+ if row.reason:
+ click.echo(f'- {row.name} ({row.reason})')
else:
- click.echo(f'- {software["name"]}')
+ click.echo(f'- {row.name}')
@cli_software.command('ban')
@@ -811,12 +863,12 @@ def cli_software_ban(ctx: click.Context,
with ctx.obj.database.session() as conn:
if name == 'RELAYS':
- for software in RELAY_SOFTWARE:
- if conn.get_software_ban(software):
- click.echo(f'Relay already banned: {software}')
+ for item in RELAY_SOFTWARE:
+ if conn.get_software_ban(item):
+ click.echo(f'Relay already banned: {item}')
continue
- conn.put_software_ban(software, reason or 'relay', note)
+ conn.put_software_ban(item, reason or 'relay', note)
click.echo('Banned all relay software')
return
@@ -893,11 +945,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
click.echo(f'Updated software ban: {name}')
- if row['reason']:
- click.echo(f'- {row["name"]} ({row["reason"]})')
+ if row.reason:
+ click.echo(f'- {row.name} ({row.reason})')
else:
- click.echo(f'- {row["name"]}')
+ click.echo(f'- {row.name}')
@cli.group('whitelist')
@@ -913,8 +965,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
click.echo('Current whitelisted domains:')
with ctx.obj.database.session() as conn:
- for domain in conn.execute('SELECT * FROM whitelist'):
- click.echo(f'- {domain["domain"]}')
+ for row in conn.get_domain_whitelist():
+ click.echo(f'- {row.domain}')
@cli_whitelist.command('add')
@@ -953,23 +1005,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
@cli_whitelist.command('import')
@click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None:
- 'Add all current inboxes to the whitelist'
+ 'Add all current instances to the whitelist'
with ctx.obj.database.session() as conn:
- for inbox in conn.execute('SELECT * FROM inboxes').all():
- if conn.get_domain_whitelist(inbox['domain']):
- click.echo(f'Domain already in whitelist: {inbox["domain"]}')
+ for row in conn.get_inboxes():
+ if conn.get_domain_whitelist(row.domain) is not None:
+ click.echo(f'Domain already in whitelist: {row.domain}')
continue
- conn.put_domain_whitelist(inbox['domain'])
+ conn.put_domain_whitelist(row.domain)
click.echo('Imported whitelist from inboxes')
def main() -> None:
- cli(prog_name='relay')
-
-
-if __name__ == '__main__':
- click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')
+ cli(prog_name='activityrelay')
diff --git a/relay/misc.py b/relay/misc.py
index 9e8f035..aa44956 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -9,23 +9,13 @@ import socket
from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence
from datetime import datetime
+from importlib.resources import files as pkgfiles
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4
-try:
- from importlib.resources import files as pkgfiles
-
-except ImportError:
- from importlib_resources import files as pkgfiles # type: ignore
-
-try:
- from typing import Self
-
-except ImportError:
- from typing_extensions import Self
-
if TYPE_CHECKING:
+ from typing import Self
from .application import Application
@@ -72,6 +62,27 @@ SOFTWARE = (
'gotosocial'
)
+JSON_PATHS: tuple[str, ...] = (
+ '/api/v1',
+ '/actor',
+ '/inbox',
+ '/outbox',
+ '/following',
+ '/followers',
+ '/.well-known',
+ '/nodeinfo',
+ '/oauth/token',
+ '/oauth/revoke'
+)
+
+TOKEN_PATHS: tuple[str, ...] = (
+ '/logout',
+ '/admin',
+ '/api',
+ '/oauth/authorize',
+ '/oauth/revoke'
+)
+
def boolean(value: Any) -> bool:
if isinstance(value, str):
@@ -252,9 +263,9 @@ class Response(AiohttpResponse):
@classmethod
- def new_redir(cls: type[Self], path: str) -> Self:
+ def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to
{path}'
- return cls.new(body, 302, {'Location': path})
+ return cls.new(body, status, {'Location': path}, ctype = 'html')
@property
diff --git a/relay/processors.py b/relay/processors.py
index cd742ec..57e9222 100644
--- a/relay/processors.py
+++ b/relay/processors.py
@@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
logging.debug('>> relay: %s', message)
for instance in conn.distill_inboxes(view.message):
- view.app.push_message(instance["inbox"], message, instance)
+ view.app.push_message(instance.inbox, message, instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
@@ -52,13 +52,13 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
logging.debug('>> forward: %s', message)
for instance in conn.distill_inboxes(view.message):
- view.app.push_message(instance["inbox"], view.message, instance)
+ view.app.push_message(instance.inbox, view.message, instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str')
async def handle_follow(view: ActorView, conn: Connection) -> None:
- nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
+ nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain, force = True)
software = nodeinfo.sw_name if nodeinfo else None
config = conn.get_config_all()
@@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
return
# prevent past unfollows from removing an instance
- if view.instance['followid'] and view.instance['followid'] != view.message.object_id:
+ if view.instance.followid and view.instance.followid != view.message.object_id:
return
with conn.transaction():
@@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None:
with view.database.session() as conn:
if view.instance:
- if not view.instance['software']:
- if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
+ if not view.instance.software:
+ if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)):
with conn.transaction():
view.instance = conn.put_inbox(
- domain = view.instance['domain'],
+ domain = view.instance.domain,
software = nodeinfo.sw_name
)
- if not view.instance['actor']:
+ if not view.instance.actor:
with conn.transaction():
view.instance = conn.put_inbox(
- domain = view.instance['domain'],
+ domain = view.instance.domain,
actor = view.actor.id
)
diff --git a/relay/template.py b/relay/template.py
index ef25f92..3ee2855 100644
--- a/relay/template.py
+++ b/relay/template.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import textwrap
+from aiohttp.web import Request
from collections.abc import Callable
from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader
@@ -13,13 +14,15 @@ from typing import TYPE_CHECKING, Any
from . import __version__
from .misc import get_resource
-from .views.base import View
if TYPE_CHECKING:
from .application import Application
class Template(Environment):
+ _render_markdown: Callable[[str], str]
+
+
def __init__(self, app: Application):
Environment.__init__(self,
autoescape = True,
@@ -40,12 +43,12 @@ class Template(Environment):
self.hamlish_mode = 'indented'
- def render(self, path: str, view: View | None = None, **context: Any) -> str:
+ def render(self, path: str, request: Request, **context: Any) -> str:
with self.app.database.session(False) as conn:
config = conn.get_config_all()
new_context = {
- 'view': view,
+ 'request': request,
'domain': self.app.config.domain,
'version': __version__,
'config': config,
@@ -56,7 +59,7 @@ class Template(Environment):
def render_markdown(self, text: str) -> str:
- return self._render_markdown(text) # type: ignore
+ return self._render_markdown(text)
class MarkdownExtension(Extension):
diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py
index b19b7e1..4551c88 100644
--- a/relay/views/activitypub.py
+++ b/relay/views/activitypub.py
@@ -1,26 +1,23 @@
-from __future__ import annotations
-
import aputils
import traceback
-import typing
+
+from aiohttp.web import Request
+from blib import HttpError
from .base import View, register_route
from .. import logger as logging
+from ..database import schema
from ..misc import Message, Response
from ..processors import run_processor
-if typing.TYPE_CHECKING:
- from aiohttp.web import Request
- from bsql import Row
-
@register_route('/actor', '/inbox')
class ActorView(View):
signature: aputils.Signature
message: Message
actor: Message
- instancce: Row
+ instance: schema.Instance
signer: aputils.Signer
@@ -43,11 +40,10 @@ class ActorView(View):
async def post(self, request: Request) -> Response:
- if response := await self.get_post_data():
- return response
+ await self.get_post_data()
with self.database.session() as conn:
- self.instance = conn.get_inbox(self.actor.shared_inbox)
+ self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
# reject if actor is banned
if conn.get_domain_ban(self.actor.domain):
@@ -69,13 +65,13 @@ class ActorView(View):
return Response.new(status = 202)
- async def get_post_data(self) -> Response | None:
+ async def get_post_data(self) -> None:
try:
self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
- return Response.new_error(400, 'missing signature header', 'json')
+ raise HttpError(400, 'missing signature header')
try:
message: Message | None = await self.request.json(loads = Message.parse)
@@ -83,46 +79,47 @@ class ActorView(View):
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
- return Response.new_error(400, 'failed to parse message', 'json')
+ raise HttpError(400, 'failed to parse message')
if message is None:
logging.verbose('empty message')
- return Response.new_error(400, 'missing message', 'json')
+ raise HttpError(400, 'missing message')
self.message = message
if 'actor' not in self.message:
logging.verbose('actor not in message')
- return Response.new_error(400, 'no actor in message', 'json')
+ raise HttpError(400, 'no actor in message')
- actor: Message | None = await self.client.get(self.signature.keyid, True, Message)
+ try:
+ self.actor = await self.client.get(self.signature.keyid, True, Message)
- if actor is None:
+ except HttpError:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
- return Response.new(status=202)
+ raise HttpError(202, '')
- logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
- return Response.new_error(400, 'failed to fetch actor', 'json')
+ logging.verbose('Failed to fetch actor: %s', self.signature.keyid)
+ raise HttpError(400, 'failed to fetch actor')
- self.actor = actor
+ except Exception:
+ traceback.print_exc()
+ raise HttpError(500, 'unexpected error when fetching actor')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
- return Response.new_error(400, 'actor missing public key', 'json')
+ raise HttpError(400, 'actor missing public key')
try:
await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
- return Response.new_error(401, str(e), 'json')
-
- return None
+ raise HttpError(401, str(e))
@register_route('/outbox')
diff --git a/relay/views/api.py b/relay/views/api.py
index 70a9f0e..e7cb5fb 100644
--- a/relay/views/api.py
+++ b/relay/views/api.py
@@ -1,17 +1,20 @@
+import traceback
+
from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError
+from blib import HttpError, convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence
-from typing import Any
from urllib.parse import urlparse
from .base import View, register_route
from .. import __version__
-from ..database import ConfigData
-from ..misc import Message, Response, boolean, get_app
+from ..database import ConfigData, schema
+from ..misc import Message, Response, boolean
-ALLOWED_HEADERS = {
+DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
+ALLOWED_HEADERS: set[str] = {
'accept',
'authorization',
'content-type'
@@ -19,6 +22,8 @@ ALLOWED_HEADERS = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'),
+ ('POST', '/api/v1/app'),
+ ('POST', '/api/v1/login'),
('POST', '/api/v1/token')
)
@@ -34,64 +39,184 @@ def check_api_path(method: str, path: str) -> bool:
async def handle_api_path(
request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response:
- try:
- if (token := request.cookies.get('user-token')):
- request['token'] = token
- else:
- request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
-
- with get_app().database.session() as conn:
- request['user'] = conn.get_user_by_token(request['token'])
-
- except (KeyError, ValueError):
- request['token'] = None
- request['user'] = None
+ if not request.path.startswith('/api') or request.path == '/api/doc':
+ return await handler(request)
if request.method != "OPTIONS" and check_api_path(request.method, request.path):
- if not request['token']:
- return Response.new_error(401, 'Missing token', 'json')
+ if request['token'] is None:
+ raise HttpError(401, 'Missing token')
- if not request['user']:
- return Response.new_error(401, 'Invalid token', 'json')
+ if request['user'] is None:
+ raise HttpError(401, 'Invalid token')
response = await handler(request)
-
- if request.path.startswith('/api'):
- response.headers['Access-Control-Allow-Origin'] = '*'
- response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
+ response.headers['Access-Control-Allow-Origin'] = '*'
+ response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
return response
-@register_route('/api/v1/token')
-class Login(View):
+@register_route('/oauth/authorize')
+@register_route('/api/oauth/authorize')
+class OauthAuthorize(View):
async def get(self, request: Request) -> Response:
- return Response.new({'message': 'Token valid'}, ctype = 'json')
+ data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], [])
+
+ if data['response_type'] != 'code':
+ raise HttpError(400, 'Response type is not "code"')
+
+ with self.database.session(True) as conn:
+ with conn.select('apps', client_id = data['client_id']) as cur:
+ if (app := cur.one(schema.App)) is None:
+ raise HttpError(404, 'Could not find app')
+
+ if app.token is not None:
+ raise HttpError(400, 'Application has already been authorized')
+
+ if app.auth_code is not None:
+ context = {'application': app}
+ html = self.template.render(
+ 'page/authorize_show.haml', self.request, **context
+ )
+
+ return Response.new(html, ctype = 'html')
+
+ if data['redirect_uri'] != app.redirect_uri:
+ raise HttpError(400, 'redirect_uri does not match application')
+
+ context = {'application': app}
+ html = self.template.render('page/authorize_new.haml', self.request, **context)
+ return Response.new(html, ctype = 'html')
async def post(self, request: Request) -> Response:
- data = await self.get_api_data(['username', 'password'], [])
+ data = await self.get_api_data(
+ ['client_id', 'client_secret', 'redirect_uri', 'response'], []
+ )
- if isinstance(data, Response):
- return data
+ with self.database.session(True) as conn:
+ if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
+ return Response.new_error(404, 'Could not find app', 'json')
+
+ if convert_to_boolean(data['response']):
+ if app.token is not None:
+ raise HttpError(400, 'Application has already been authorized')
+
+ if app.auth_code is None:
+ app = conn.update_app(app, request['user'], True)
+
+ if app.redirect_uri == DEFAULT_REDIRECT:
+ context = {'application': app}
+ html = self.template.render(
+ 'page/authorize_show.haml', self.request, **context
+ )
+
+ return Response.new(html, ctype = 'html')
+
+ return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}')
+
+ if not conn.del_app(app.client_id, app.client_secret):
+ raise HttpError(404, 'App not found')
+
+ return Response.new_redir('/')
+
+
+@register_route('/oauth/token')
+@register_route('/api/oauth/token')
+class OauthToken(View):
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(
+ ['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], []
+ )
+
+ if data['grant_type'] != 'authorization_code':
+ raise HttpError(400, 'Invalid grant type')
+
+ with self.database.session(True) as conn:
+ if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
+ raise HttpError(404, 'Application not found')
+
+ if app.auth_code != data['code']:
+ raise HttpError(400, 'Invalid authentication code')
+
+ if app.redirect_uri != data['redirect_uri']:
+ raise HttpError(400, 'Invalid redirect uri')
+
+ app = conn.update_app(app, request['user'], False)
+
+ return Response.new(app.get_api_data(True), ctype = 'json')
+
+
+@register_route('/oauth/revoke')
+@register_route('/api/oauth/revoke')
+class OauthRevoke(View):
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(['client_id', 'client_secret', 'token'], [])
+
+ with self.database.session(True) as conn:
+ if (app := conn.get_app(**data)) is None:
+ raise HttpError(404, 'Could not find token')
+
+ if app.user != request['token'].username:
+ raise HttpError(403, 'Invalid token')
+
+ if not conn.del_app(**data):
+ raise HttpError(400, 'Token not removed')
+
+ return Response.new({'msg': 'Token deleted'}, ctype = 'json')
+
+
+@register_route('/api/v1/app')
+class App(View):
+ async def get(self, request: Request) -> Response:
+ return Response.new(request['token'].get_api_data(), ctype = 'json')
+
+
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(['name', 'redirect_uri'], ['website'])
+
+ with self.database.session(True) as conn:
+ app = conn.put_app(
+ name = data['name'],
+ redirect_uri = data['redirect_uri'],
+ website = data.get('website')
+ )
+
+ return Response.new(app.get_api_data(), ctype = 'json')
+
+
+ async def delete(self, request: Request) -> Response:
+ data = await self.get_api_data(['client_id', 'client_secret'], [])
+
+ with self.database.session(True) as conn:
+ if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code):
+ raise HttpError(400, 'Token not removed')
+
+ return Response.new({'msg': 'Token deleted'}, ctype = 'json')
+
+
+@register_route('/api/v1/login')
+class Login(View):
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(['username', 'password'], [])
with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])):
- return Response.new_error(401, 'User not found', 'json')
+ raise HttpError(401, 'User not found')
try:
conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError:
- return Response.new_error(401, 'Invalid password', 'json')
+ raise HttpError(401, 'Invalid password')
- token = conn.put_token(data['username'])
+ app = conn.put_app_login(user)
- resp = Response.new({'token': token['code']}, ctype = 'json')
+ resp = Response.new(app.get_api_data(True), ctype = 'json')
resp.set_cookie(
'user-token',
- token['code'],
+ app.token, # type: ignore[arg-type]
max_age = 60 * 60 * 24 * 365,
domain = self.config.domain,
path = '/',
@@ -103,19 +228,12 @@ class Login(View):
return resp
- async def delete(self, request: Request) -> Response:
- with self.database.session() as conn:
- conn.del_token(request['token'])
-
- return Response.new({'message': 'Token revoked'}, ctype = 'json')
-
-
@register_route('/api/v1/relay')
class RelayInfo(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
config = conn.get_config_all()
- inboxes = [row['domain'] for row in conn.get_inboxes()]
+ inboxes = [row.domain for row in conn.get_inboxes()]
data = {
'domain': self.config.domain,
@@ -152,17 +270,16 @@ class Config(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['key', 'value'], [])
-
- if isinstance(data, Response):
- return data
-
data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS():
- return Response.new_error(400, 'Invalid key', 'json')
+ raise HttpError(400, 'Invalid key')
with self.database.session() as conn:
- conn.put_config(data['key'], data['value'])
+ value = conn.put_config(data['key'], data['value'])
+
+ if data['key'] == 'log-level':
+ self.app.workers.set_log_level(value)
return Response.new({'message': 'Updated config'}, ctype = 'json')
@@ -170,14 +287,14 @@ class Config(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['key'], [])
- if isinstance(data, Response):
- return data
-
if data['key'] not in ConfigData.USER_KEYS():
- return Response.new_error(400, 'Invalid key', 'json')
+ raise HttpError(400, 'Invalid key')
with self.database.session() as conn:
- conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
+ value = conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
+
+ if data['key'] == 'log-level':
+ self.app.workers.set_log_level(value)
return Response.new({'message': 'Updated config'}, ctype = 'json')
@@ -186,40 +303,46 @@ class Config(View):
class Inbox(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
- data = conn.get_inboxes()
+ data = tuple(conn.get_inboxes())
return Response.new(data, ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn:
- if conn.get_inbox(data['domain']):
- return Response.new_error(404, 'Instance already in database', 'json')
+ if conn.get_inbox(data['domain']) is not None:
+ raise HttpError(404, 'Instance already in database')
data['domain'] = data['domain'].encode('idna').decode()
if not data.get('inbox'):
- actor_data: Message | None = await self.client.get(data['actor'], True, Message)
+ try:
+ actor_data = await self.client.get(data['actor'], True, Message)
- if actor_data is None:
- return Response.new_error(500, 'Failed to fetch actor', 'json')
+ except Exception:
+ traceback.print_exc()
+ raise HttpError(500, 'Failed to fetch actor') from None
data['inbox'] = actor_data.shared_inbox
if not data.get('software'):
- nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
-
- if nodeinfo is not None:
+ try:
+ nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
data['software'] = nodeinfo.sw_name
- row = conn.put_inbox(**data) # type: ignore[arg-type]
+ except Exception:
+ pass
+
+ row = conn.put_inbox(
+ domain = data['domain'],
+ actor = data['actor'],
+ inbox = data.get('inbox'),
+ software = data.get('software'),
+ followid = data.get('followid')
+ )
return Response.new(row, ctype = 'json')
@@ -227,16 +350,17 @@ class Inbox(View):
async def patch(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = data['domain'].encode('idna').decode()
- if not (instance := conn.get_inbox(data['domain'])):
- return Response.new_error(404, 'Instance with domain not found', 'json')
+ if (instance := conn.get_inbox(data['domain'])) is None:
+ raise HttpError(404, 'Instance with domain not found')
- instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type]
+ instance = conn.put_inbox(
+ instance.domain,
+ actor = data.get('actor'),
+ software = data.get('software'),
+ followid = data.get('followid')
+ )
return Response.new(instance, ctype = 'json')
@@ -244,14 +368,10 @@ class Inbox(View):
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']):
- return Response.new_error(404, 'Instance with domain not found', 'json')
+ raise HttpError(404, 'Instance with domain not found')
conn.del_inbox(data['domain'])
@@ -262,43 +382,38 @@ class Inbox(View):
class RequestView(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
- instances = conn.get_requests()
+ instances = tuple(conn.get_requests())
return Response.new(instances, ctype = 'json')
async def post(self, request: Request) -> Response:
- data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
-
- if isinstance(data, Response):
- return data
-
- data['accept'] = boolean(data['accept'])
+ data = await self.get_api_data(['domain', 'accept'], [])
data['domain'] = data['domain'].encode('idna').decode()
try:
with self.database.session(True) as conn:
- instance = conn.put_request_response(data['domain'], data['accept'])
+ instance = conn.put_request_response(data['domain'], boolean(data['accept']))
except KeyError:
- return Response.new_error(404, 'Request not found', 'json')
+ raise HttpError(404, 'Request not found') from None
message = Message.new_response(
host = self.config.domain,
- actor = instance['actor'],
- followid = instance['followid'],
- accept = data['accept']
+ actor = instance.actor,
+ followid = instance.followid,
+ accept = boolean(data['accept'])
)
- self.app.push_message(instance['inbox'], message, instance)
+ self.app.push_message(instance.inbox, message, instance)
- if data['accept'] and instance['software'] != 'mastodon':
+ if data['accept'] and instance.software != 'mastodon':
message = Message.new_follow(
host = self.config.domain,
- actor = instance['actor']
+ actor = instance.actor
)
- self.app.push_message(instance['inbox'], message, instance)
+ self.app.push_message(instance.inbox, message, instance)
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
return Response.new(resp_message, ctype = 'json')
@@ -308,24 +423,24 @@ class RequestView(View):
class DomainBan(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
- bans = tuple(conn.execute('SELECT * FROM domain_bans').all())
+ bans = tuple(conn.get_domain_bans())
return Response.new(bans, ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], ['note', 'reason'])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn:
- if conn.get_domain_ban(data['domain']):
- return Response.new_error(400, 'Domain already banned', 'json')
+ if conn.get_domain_ban(data['domain']) is not None:
+ raise HttpError(400, 'Domain already banned')
- ban = conn.put_domain_ban(**data)
+ ban = conn.put_domain_ban(
+ domain = data['domain'],
+ reason = data.get('reason'),
+ note = data.get('note')
+ )
return Response.new(ban, ctype = 'json')
@@ -334,18 +449,19 @@ class DomainBan(View):
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason'])
- if isinstance(data, Response):
- return data
+ if not any([data.get('note'), data.get('reason')]):
+ raise HttpError(400, 'Must include note and/or reason parameters')
data['domain'] = data['domain'].encode('idna').decode()
- if not conn.get_domain_ban(data['domain']):
- return Response.new_error(404, 'Domain not banned', 'json')
+ if conn.get_domain_ban(data['domain']) is None:
+ raise HttpError(404, 'Domain not banned')
- if not any([data.get('note'), data.get('reason')]):
- return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
-
- ban = conn.update_domain_ban(**data)
+ ban = conn.update_domain_ban(
+ domain = data['domain'],
+ reason = data.get('reason'),
+ note = data.get('note')
+ )
return Response.new(ban, ctype = 'json')
@@ -353,14 +469,10 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = data['domain'].encode('idna').decode()
- if not conn.get_domain_ban(data['domain']):
- return Response.new_error(404, 'Domain not banned', 'json')
+ if conn.get_domain_ban(data['domain']) is None:
+ raise HttpError(404, 'Domain not banned')
conn.del_domain_ban(data['domain'])
@@ -371,7 +483,7 @@ class DomainBan(View):
class SoftwareBan(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
- bans = tuple(conn.execute('SELECT * FROM software_bans').all())
+ bans = tuple(conn.get_software_bans())
return Response.new(bans, ctype = 'json')
@@ -379,14 +491,15 @@ class SoftwareBan(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason'])
- if isinstance(data, Response):
- return data
-
with self.database.session() as conn:
- if conn.get_software_ban(data['name']):
- return Response.new_error(400, 'Domain already banned', 'json')
+ if conn.get_software_ban(data['name']) is not None:
+ raise HttpError(400, 'Domain already banned')
- ban = conn.put_software_ban(**data)
+ ban = conn.put_software_ban(
+ name = data['name'],
+ reason = data.get('reason'),
+ note = data.get('note')
+ )
return Response.new(ban, ctype = 'json')
@@ -394,17 +507,18 @@ class SoftwareBan(View):
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason'])
- if isinstance(data, Response):
- return data
+ if not any([data.get('note'), data.get('reason')]):
+ raise HttpError(400, 'Must include note and/or reason parameters')
with self.database.session() as conn:
- if not conn.get_software_ban(data['name']):
- return Response.new_error(404, 'Software not banned', 'json')
+ if conn.get_software_ban(data['name']) is None:
+ raise HttpError(404, 'Software not banned')
- if not any([data.get('note'), data.get('reason')]):
- return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
-
- ban = conn.update_software_ban(**data)
+ ban = conn.update_software_ban(
+ name = data['name'],
+ reason = data.get('reason'),
+ note = data.get('note')
+ )
return Response.new(ban, ctype = 'json')
@@ -412,12 +526,9 @@ class SoftwareBan(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['name'], [])
- if isinstance(data, Response):
- return data
-
with self.database.session() as conn:
- if not conn.get_software_ban(data['name']):
- return Response.new_error(404, 'Software not banned', 'json')
+ if conn.get_software_ban(data['name']) is None:
+ raise HttpError(404, 'Software not banned')
conn.del_software_ban(data['name'])
@@ -430,7 +541,7 @@ class User(View):
with self.database.session() as conn:
items = []
- for row in conn.execute('SELECT * FROM users'):
+ for row in conn.get_users():
del row['hash']
items.append(row)
@@ -440,41 +551,40 @@ class User(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], ['handle'])
- if isinstance(data, Response):
- return data
-
with self.database.session() as conn:
- if conn.get_user(data['username']):
- return Response.new_error(404, 'User already exists', 'json')
+ if conn.get_user(data['username']) is not None:
+ raise HttpError(404, 'User already exists')
- user = conn.put_user(**data)
- del user['hash']
+ user = conn.put_user(
+ username = data['username'],
+ password = data['password'],
+ handle = data.get('handle')
+ )
+ del user['hash']
return Response.new(user, ctype = 'json')
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', 'handle'])
- if isinstance(data, Response):
- return data
-
with self.database.session(True) as conn:
- user = conn.put_user(**data)
- del user['hash']
+ user = conn.put_user(
+ username = data['username'],
+ password = data['password'],
+ handle = data.get('handle')
+ )
+ del user['hash']
return Response.new(user, ctype = 'json')
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['username'], [])
- if isinstance(data, Response):
- return data
-
with self.database.session(True) as conn:
- if not conn.get_user(data['username']):
- return Response.new_error(404, 'User does not exist', 'json')
+ if conn.get_user(data['username']) is None:
+ raise HttpError(404, 'User does not exist')
conn.del_user(data['username'])
@@ -485,7 +595,7 @@ class User(View):
class Whitelist(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
- items = tuple(conn.execute('SELECT * FROM whitelist').all())
+ items = tuple(conn.get_domains_whitelist())
return Response.new(items, ctype = 'json')
@@ -493,16 +603,13 @@ class Whitelist(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], [])
- if isinstance(data, Response):
- return data
-
- data['domain'] = data['domain'].encode('idna').decode()
+ domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
- if conn.get_domain_whitelist(data['domain']):
- return Response.new_error(400, 'Domain already added to whitelist', 'json')
+ if conn.get_domain_whitelist(domain) is not None:
+ raise HttpError(400, 'Domain already added to whitelist')
- item = conn.put_domain_whitelist(**data)
+ item = conn.put_domain_whitelist(domain)
return Response.new(item, ctype = 'json')
@@ -510,15 +617,12 @@ class Whitelist(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], [])
- if isinstance(data, Response):
- return data
-
- data['domain'] = data['domain'].encode('idna').decode()
+ domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
- if not conn.get_domain_whitelist(data['domain']):
- return Response.new_error(404, 'Domain not in whitelist', 'json')
+ if conn.get_domain_whitelist(domain) is None:
+ raise HttpError(404, 'Domain not in whitelist')
- conn.del_domain_whitelist(data['domain'])
+ conn.del_domain_whitelist(domain)
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')
diff --git a/relay/views/base.py b/relay/views/base.py
index 350016c..624ed9d 100644
--- a/relay/views/base.py
+++ b/relay/views/base.py
@@ -1,10 +1,9 @@
from __future__ import annotations
-from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
-from aiohttp.web import HTTPMethodNotAllowed, Request
-from base64 import b64encode
+from aiohttp.web import Request
+from blib import HttpError
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property
@@ -18,18 +17,12 @@ from ..http_client import HttpClient
from ..misc import Response, get_app
if TYPE_CHECKING:
+ from typing import Self
from ..application import Application
from ..template import Template
-try:
- from typing import Self
-
-except ImportError:
- from typing_extensions import Self
HandlerCallback = Callable[[Request], Awaitable[Response]]
-
-
VIEWS: list[tuple[str, type[View]]] = []
@@ -49,10 +42,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS:
- raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
+ raise HttpError(405, f'"{self.request.method}" method not allowed')
if not (handler := self.handlers.get(self.request.method)):
- raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
+ raise HttpError(405, f'"{self.request.method}" method not allowed')
return self._run_handler(handler).__await__()
@@ -64,7 +57,6 @@ class View(AbstractView):
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
- self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs)
@@ -123,17 +115,18 @@ class View(AbstractView):
async def get_api_data(self,
required: list[str],
- optional: list[str]) -> dict[str, str] | Response:
+ optional: list[str]) -> dict[str, str]:
- if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
+ if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post())
+ # post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json':
try:
post_data = convert_data(await self.request.json())
except JSONDecodeError:
- return Response.new_error(400, 'Invalid JSON data', 'json')
+ raise HttpError(400, 'Invalid JSON data')
else:
post_data = convert_data(self.request.query)
@@ -145,9 +138,9 @@ class View(AbstractView):
data[key] = post_data[key]
except KeyError as e:
- return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
+ raise HttpError(400, f'Missing {str(e)} pararmeter') from None
for key in optional:
- data[key] = post_data.get(key, '')
+ data[key] = post_data.get(key) # type: ignore[assignment]
return data
diff --git a/relay/views/frontend.py b/relay/views/frontend.py
index 5dfb43a..b6dba7b 100644
--- a/relay/views/frontend.py
+++ b/relay/views/frontend.py
@@ -1,18 +1,13 @@
from aiohttp import web
from collections.abc import Awaitable, Callable
from typing import Any
+from urllib.parse import unquote
from .base import View, register_route
from ..database import THEMES
from ..logger import LogLevel
-from ..misc import Response, get_app
-
-
-UNAUTH_ROUTES = {
- '/',
- '/login'
-}
+from ..misc import TOKEN_PATHS, Response
@web.middleware
@@ -20,28 +15,25 @@ async def handle_frontend_path(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
- app = get_app()
+ if request['user'] is not None and request.path == '/login':
+ return Response.new_redir('/')
- if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
- request['token'] = request.cookies.get('user-token')
- request['user'] = None
+ if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None:
+ if request.path == '/logout':
+ return Response.new_redir('/')
- if request['token']:
- with app.database.session(False) as conn:
- request['user'] = conn.get_user_by_token(request['token'])
+ response = Response.new_redir(f'/login?redir={request.path}')
- if request['user'] and request.path == '/login':
- return Response.new('', 302, {'Location': '/'})
-
- if not request['user'] and request.path.startswith('/admin'):
- response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
+ if request['token'] is not None:
response.del_cookie('user-token')
- return response
+
+ return response
response = await handler(request)
- if not request.path.startswith('/api') and not request['user'] and request['token']:
- response.del_cookie('user-token')
+ if not request.path.startswith('/api'):
+ if request['user'] is None and request['token'] is not None:
+ response.del_cookie('user-token')
return response
@@ -54,14 +46,15 @@ class HomeView(View):
'instances': tuple(conn.get_inboxes())
}
- data = self.template.render('page/home.haml', self, **context)
+ data = self.template.render('page/home.haml', self.request, **context)
return Response.new(data, ctype='html')
@register_route('/login')
class Login(View):
async def get(self, request: web.Request) -> Response:
- data = self.template.render('page/login.haml', self)
+ redir = unquote(request.query.get('redir', '/'))
+ data = self.template.render('page/login.haml', self.request, redir = redir)
return Response.new(data, ctype = 'html')
@@ -69,7 +62,7 @@ class Login(View):
class Logout(View):
async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn:
- conn.del_token(request['token'])
+ conn.del_app(request['token'].client_id, request['token'].client_secret)
resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = self.config.domain, path = '/')
@@ -79,7 +72,7 @@ class Logout(View):
@register_route('/admin')
class Admin(View):
async def get(self, request: web.Request) -> Response:
- return Response.new('', 302, {'Location': '/admin/instances'})
+ return Response.new_redir(f'/login?redir={request.path}', 301)
@register_route('/admin/instances')
@@ -101,7 +94,7 @@ class AdminInstances(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-instances.haml', self, **context)
+ data = self.template.render('page/admin-instances.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -123,7 +116,7 @@ class AdminWhitelist(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-whitelist.haml', self, **context)
+ data = self.template.render('page/admin-whitelist.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -145,7 +138,7 @@ class AdminDomainBans(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-domain_bans.haml', self, **context)
+ data = self.template.render('page/admin-domain_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -167,7 +160,7 @@ class AdminSoftwareBans(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-software_bans.haml', self, **context)
+ data = self.template.render('page/admin-software_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -189,7 +182,7 @@ class AdminUsers(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-users.haml', self, **context)
+ data = self.template.render('page/admin-users.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -199,10 +192,21 @@ class AdminConfig(View):
context: dict[str, Any] = {
'themes': tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel),
- 'message': message
+ 'message': message,
+ 'desc': {
+ "name": "Name of the relay to be displayed in the header of the pages and in " +
+ "the actor endpoint.", # noqa: E131
+ "note": "Description of the relay to be displayed on the front page and as the " +
+ "bio in the actor endpoint.",
+ "theme": "Color theme to use on the web pages.",
+ "log_level": "Minimum level of logging messages to print to the console.",
+ "whitelist_enabled": "Only allow instances in the whitelist to be able to follow.",
+ "approval_required": "Require instances not on the whitelist to be approved by " +
+ "and admin. The `whitelist-enabled` setting is ignored when this is enabled."
+ }
}
- data = self.template.render('page/admin-config.haml', self, **context)
+ data = self.template.render('page/admin-config.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -240,5 +244,5 @@ class ThemeCss(View):
except KeyError:
return Response.new('Invalid theme', 404)
- data = self.template.render('variables.css', self, **context)
+ data = self.template.render('variables.css', self.request, **context)
return Response.new(data, ctype = 'css')
diff --git a/relay/workers.py b/relay/workers.py
new file mode 100644
index 0000000..31cf4c3
--- /dev/null
+++ b/relay/workers.py
@@ -0,0 +1,139 @@
+from __future__ import annotations
+
+import asyncio
+import traceback
+
+from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
+from asyncio.exceptions import TimeoutError as AsyncTimeoutError
+from dataclasses import dataclass
+from multiprocessing import Event, Process, Queue, Value
+from multiprocessing.queues import Queue as QueueType
+from multiprocessing.sharedctypes import Synchronized
+from multiprocessing.synchronize import Event as EventType
+from pathlib import Path
+from queue import Empty
+from urllib.parse import urlparse
+
+from . import application, logger as logging
+from .database.schema import Instance
+from .http_client import HttpClient
+from .misc import IS_WINDOWS, Message, get_app
+
+
+@dataclass
+class PostItem:
+ inbox: str
+ message: Message
+ instance: Instance | None
+
+ @property
+ def domain(self) -> str:
+ return urlparse(self.inbox).netloc
+
+
+class PushWorker(Process):
+ client: HttpClient
+
+
+ def __init__(self, queue: QueueType[PostItem], log_level: Synchronized[int]) -> None:
+ Process.__init__(self)
+
+ self.queue: QueueType[PostItem] = queue
+ self.shutdown: EventType = Event()
+ self.path: Path = get_app().config.path
+ self.log_level: Synchronized[int] = log_level
+ self._log_level_changed: EventType = Event()
+
+
+ def stop(self) -> None:
+ self.shutdown.set()
+
+
+ def run(self) -> None:
+ asyncio.run(self.handle_queue())
+
+
+ async def handle_queue(self) -> None:
+ if IS_WINDOWS:
+ app = application.Application(self.path)
+ self.client = app.client
+
+ self.client.open()
+ app.database.connect()
+ app.cache.setup()
+
+ else:
+ self.client = HttpClient()
+ self.client.open()
+
+ logging.verbose("[%i] Starting worker", self.pid)
+
+ while not self.shutdown.is_set():
+ try:
+ if self._log_level_changed.is_set():
+ logging.set_level(logging.LogLevel.parse(self.log_level.value))
+ self._log_level_changed.clear()
+
+ item = self.queue.get(block=True, timeout=0.1)
+ asyncio.create_task(self.handle_post(item))
+
+ except Empty:
+ await asyncio.sleep(0)
+
+ except Exception:
+ traceback.print_exc()
+
+ if IS_WINDOWS:
+ app.database.disconnect()
+ app.cache.close()
+
+ await self.client.close()
+
+
+ async def handle_post(self, item: PostItem) -> None:
+ try:
+ await self.client.post(item.inbox, item.message, item.instance)
+
+ except AsyncTimeoutError:
+ logging.error('Timeout when pushing to %s', item.domain)
+
+ except ClientConnectionError as e:
+ logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
+
+ except ClientSSLError as e:
+ logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
+
+
+class PushWorkers(list[PushWorker]):
+ def __init__(self, count: int) -> None:
+ self.queue: QueueType[PostItem] = Queue()
+ self._log_level: Synchronized[int] = Value("i", logging.get_level())
+ self._count: int = count
+
+
+ def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
+ self.queue.put(PostItem(inbox, message, instance))
+
+
+ def set_log_level(self, value: logging.LogLevel) -> None:
+ self._log_level.value = value
+
+ for worker in self:
+ worker._log_level_changed.set()
+
+
+ def start(self) -> None:
+ if len(self) > 0:
+ return
+
+ for _ in range(self._count):
+ worker = PushWorker(self.queue, self._log_level)
+ worker.start()
+ self.append(worker)
+
+
+ def stop(self) -> None:
+ for worker in self:
+ worker.stop()
+
+ self.clear()