\n';
+ text += `
\n`;
+ text += `
\n`;
+ text += `
\n`;
+ text += `
\n`;
+ text += `
`;
+ text += '';
+
+ return text;
+}
+
+
+function add_row_listeners(row) {
+ row.querySelector(".update-ban").addEventListener("click", async (event) => {
+ await update_ban(row.id);
+ });
+
+ row.querySelector(".remove a").addEventListener("click", async (event) => {
+ event.preventDefault();
+ await unban(row.id);
+ });
+}
+
+
+async function ban() {
+ var elems = {
+ name: document.getElementById("new-name"),
+ reason: document.getElementById("new-reason"),
+ note: document.getElementById("new-note")
+ }
+
+ var values = {
+ name: elems.name.value.trim(),
+ reason: elems.reason.value,
+ note: elems.note.value
+ }
+
+ if (values.name === "") {
+ toast("Domain is required");
+ return;
+ }
+
+ try {
+ var ban = await request("POST", "v1/software_ban", values);
+
+ } catch (err) {
+ toast(err);
+ return
+ }
+
+ var row = append_table_row(document.getElementById("bans"), ban.name, {
+ name: create_ban_object(ban.name, ban.reason, ban.note),
+ date: get_date_string(ban.created),
+ remove: `
✖`
+ });
+
+ add_row_listeners(row);
+
+ elems.name.value = null;
+ elems.reason.value = null;
+ elems.note.value = null;
+
+ document.querySelector("details.section").open = false;
+ toast("Banned software", "message");
+}
+
+
+async function update_ban(name) {
+ var row = document.getElementById(name);
+
+ var elems = {
+ "reason": row.querySelector("textarea.reason"),
+ "note": row.querySelector("textarea.note")
+ }
+
+ var values = {
+ "name": name,
+ "reason": elems.reason.value,
+ "note": elems.note.value
+ }
+
+ try {
+ await request("PATCH", "v1/software_ban", values)
+
+ } catch (error) {
+ toast(error);
+ return;
+ }
+
+ row.querySelector("details").open = false;
+ toast("Updated software ban", "message");
+}
+
+
+async function unban(name) {
+ try {
+ await request("DELETE", "v1/software_ban", {"name": name});
+
+ } catch (error) {
+ toast(error);
+ return;
+ }
+
+ document.getElementById(name).remove();
+ toast("Unbanned software", "message");
+}
+
+
+document.querySelector("#new-ban").addEventListener("click", async (event) => {
+ await ban();
+});
+
+for (var row of document.querySelector("#bans").rows) {
+ if (!row.querySelector(".update-ban")) {
+ continue;
+ }
+
+ add_row_listeners(row);
+}
diff --git a/relay/frontend/style.css b/relay/frontend/static/style.css
similarity index 74%
rename from relay/frontend/style.css
rename to relay/frontend/static/style.css
index f2a6fe1..635aa55 100644
--- a/relay/frontend/style.css
+++ b/relay/frontend/static/style.css
@@ -23,11 +23,29 @@ details summary {
cursor: pointer;
}
+fieldset {
+ margin-left: 0px;
+ margin-right: 0px;
+}
+
+fieldset > *:nth-child(2) {
+ margin-top: 0px !important;
+}
+
form input[type="submit"] {
display: block;
margin: 0 auto;
}
+legend {
+ background-color: var(--table-background);
+ padding: 5px;
+ border: 1px solid var(--border);
+ border-radius: 5px;
+ font-size: 10pt;
+ font-weight: bold;
+}
+
p {
line-height: 1em;
margin: 0px;
@@ -91,6 +109,17 @@ textarea {
margin: 0px auto;
}
+#content .title {
+ font-size: 24px;
+ text-align: center;
+ font-weight: bold;
+ margin-bottom: 10px;
+}
+
+#content .title:not(:first-child) {
+ margin-top: 10px;
+}
+
#header {
display: grid;
grid-template-columns: 50px auto 50px;
@@ -175,6 +204,37 @@ textarea {
text-align: center;
}
+#notifications {
+ position: fixed;
+ top: 40px;
+ left: 50%;
+ transform: translateX(-50%);
+}
+
+#notifications li {
+ position: relative;
+ overflow: hidden;
+ list-style: none;
+ border-radius: 5px;
+ padding: 5px;;
+ margin-bottom: var(--spacing);
+ animation: show_toast 0.3s ease forwards;
+ display: grid;
+ grid-template-columns: auto max-content;
+ grid-gap: 5px;
+ align-items: center;
+}
+
+#notifications a {
+ font-size: 1.5em;
+ line-height: 1em;
+ text-decoration: none;
+}
+
+#notifications li.hide {
+ animation: hide_toast 0.3s ease forwards;
+}
+
#footer {
display: grid;
grid-template-columns: auto auto;
@@ -193,15 +253,6 @@ textarea {
align-items: center;
}
-#data-table td:first-child {
- width: 100%;
-}
-
-#data-table .date {
- width: max-content;
- text-align: right;
-}
-
.button {
background-color: var(--primary);
border: 1px solid var(--primary);
@@ -220,6 +271,15 @@ textarea {
grid-template-columns: max-content auto;
}
+.data-table td:first-child {
+ width: 100%;
+}
+
+.data-table .date {
+ width: max-content;
+ text-align: right;
+}
+
.error, .message {
text-align: center;
}
@@ -267,6 +327,44 @@ textarea {
}
+@keyframes show_toast {
+ 0% {
+ transform: translateX(100%);
+ }
+
+ 40% {
+ transform: translateX(-5%);
+ }
+
+ 80% {
+ transform: translateX(0%);
+ }
+
+ 100% {
+ transform: translateX(-10px);
+ }
+}
+
+
+@keyframes hide_toast {
+ 0% {
+ transform: translateX(-10px);
+ }
+
+ 40% {
+ transform: translateX(0%);
+ }
+
+ 80% {
+ transform: translateX(-5%);
+ }
+
+ 100% {
+ transform: translateX(calc(100% + 20px));
+ }
+}
+
+
@media (max-width: 1026px) {
body {
margin: 0px;
diff --git a/relay/frontend/static/user.js b/relay/frontend/static/user.js
new file mode 100644
index 0000000..9c74359
--- /dev/null
+++ b/relay/frontend/static/user.js
@@ -0,0 +1,85 @@
+function add_row_listeners(row) {
+ row.querySelector(".remove a").addEventListener("click", async (event) => {
+ event.preventDefault();
+ await del_user(row.id);
+ });
+}
+
+
+async function add_user() {
+ var elems = {
+ username: document.getElementById("new-username"),
+ password: document.getElementById("new-password"),
+ password2: document.getElementById("new-password2"),
+ handle: document.getElementById("new-handle")
+ }
+
+ var values = {
+ username: elems.username.value.trim(),
+ password: elems.password.value.trim(),
+ password2: elems.password2.value.trim(),
+ handle: elems.handle.value.trim()
+ }
+
+ if (values.username === "" | values.password === "" | values.password2 === "") {
+ toast("Username, password, and password2 are required");
+ return;
+ }
+
+ if (values.password !== values.password2) {
+ toast("Passwords do not match");
+ return;
+ }
+
+ try {
+ var user = await request("POST", "v1/user", values);
+
+ } catch (err) {
+ toast(err);
+ return
+ }
+
+ var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
+ domain: user.username,
+ handle: user.handle ? self.handle : "n/a",
+ date: get_date_string(user.created),
+ remove: `
✖`
+ });
+
+ add_row_listeners(row);
+
+ elems.username.value = null;
+ elems.password.value = null;
+ elems.password2.value = null;
+ elems.handle.value = null;
+
+ document.querySelector("details.section").open = false;
+ toast("Created user", "message");
+}
+
+
+async function del_user(username) {
+ try {
+ await request("DELETE", "v1/user", {"username": username});
+
+ } catch (error) {
+ toast(error);
+ return;
+ }
+
+ document.getElementById(username).remove();
+ toast("Deleted user", "message");
+}
+
+
+document.querySelector("#new-user").addEventListener("click", async (event) => {
+ await add_user();
+});
+
+for (var row of document.querySelector("#users").rows) {
+ if (!row.querySelector(".remove a")) {
+ continue;
+ }
+
+ add_row_listeners(row);
+}
diff --git a/relay/frontend/static/whitelist.js b/relay/frontend/static/whitelist.js
new file mode 100644
index 0000000..70d4db1
--- /dev/null
+++ b/relay/frontend/static/whitelist.js
@@ -0,0 +1,64 @@
+function add_row_listeners(row) {
+ row.querySelector(".remove a").addEventListener("click", async (event) => {
+ event.preventDefault();
+ await del_whitelist(row.id);
+ });
+}
+
+
+async function add_whitelist() {
+ var domain_elem = document.getElementById("new-domain");
+ var domain = domain_elem.value.trim();
+
+ if (domain === "") {
+ toast("Domain is required");
+ return;
+ }
+
+ try {
+ var item = await request("POST", "v1/whitelist", {"domain": domain});
+
+ } catch (err) {
+ toast(err);
+ return;
+ }
+
+ var row = append_table_row(document.getElementById("whitelist"), item.domain, {
+ domain: item.domain,
+ date: get_date_string(item.created),
+ remove: `
✖`
+ });
+
+ add_row_listeners(row);
+
+ domain_elem.value = null;
+ document.querySelector("details.section").open = false;
+ toast("Added domain", "message");
+}
+
+
+async function del_whitelist(domain) {
+ try {
+ await request("DELETE", "v1/whitelist", {"domain": domain});
+
+ } catch (error) {
+ toast(error);
+ return;
+ }
+
+ document.getElementById(domain).remove();
+ toast("Removed domain", "message");
+}
+
+
+document.querySelector("#new-item").addEventListener("click", async (event) => {
+ await add_whitelist();
+});
+
+for (var row of document.querySelector("fieldset.section table").rows) {
+ if (!row.querySelector(".remove a")) {
+ continue;
+ }
+
+ add_row_listeners(row);
+}
diff --git a/relay/http_client.py b/relay/http_client.py
index 7e7bbd9..04533c5 100644
--- a/relay/http_client.py
+++ b/relay/http_client.py
@@ -7,7 +7,7 @@ import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
-from aputils.objects import Nodeinfo, WellKnownNodeinfo
+from aputils import AlgorithmType, JsonBase, Nodeinfo, ObjectType, WellKnownNodeinfo
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
@@ -17,12 +17,13 @@ from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from aputils import Signer
- from tinysql import Row
+ from bsql import Row
from typing import Any
from .application import Application
from .cache import Cache
+T = typing.TypeVar('T', bound = JsonBase)
HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}'
@@ -33,12 +34,12 @@ class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10):
self.limit = limit
self.timeout = timeout
- self._conn = None
- self._session = None
+ self._conn: TCPConnector | None = None
+ self._session: ClientSession | None = None
async def __aenter__(self) -> HttpClient:
- await self.open()
+ self.open()
return self
@@ -61,7 +62,7 @@ class HttpClient:
return self.app.signer
- async def open(self) -> None:
+ def open(self) -> None:
if self._session:
return
@@ -79,23 +80,19 @@ class HttpClient:
async def close(self) -> None:
- if not self._session:
- return
+ if self._session:
+ await self._session.close()
- await self._session.close()
- await self._conn.close()
+ if self._conn:
+ await self._conn.close()
self._conn = None
self._session = None
- async def get(self, # pylint: disable=too-many-branches
- url: str,
- sign_headers: bool = False,
- loads: callable = json.loads,
- force: bool = False) -> dict | None:
-
- await self.open()
+ async def _get(self, url: str, sign_headers: bool, force: bool) -> dict[str, Any] | None:
+ if not self._session:
+ raise RuntimeError('Client not open')
try:
url, _ = url.split('#', 1)
@@ -105,10 +102,8 @@ class HttpClient:
if not force:
try:
- item = self.cache.get('request', url)
-
- if not item.older_than(48):
- return loads(item.value)
+ if not (item := self.cache.get('request', url)).older_than(48):
+ return json.loads(item.value)
except KeyError:
logging.verbose('No cached data for url: %s', url)
@@ -116,38 +111,39 @@ class HttpClient:
headers = {}
if sign_headers:
- self.signer.sign_headers('GET', url, algorithm = 'original')
+ headers = self.signer.sign_headers('GET', url, algorithm = AlgorithmType.HS2019)
try:
logging.debug('Fetching resource: %s', url)
- async with self._session.get(url, headers=headers) as resp:
- ## Not expecting a response with 202s, so just return
+ async with self._session.get(url, headers = headers) as resp:
+ # Not expecting a response with 202s, so just return
if resp.status == 202:
return None
- data = await resp.read()
+ data = await resp.text()
if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status)
- logging.debug(await resp.read())
+ logging.debug(data)
return None
- message = loads(data)
- self.cache.set('request', url, data.decode('utf-8'), 'str')
- logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
+ self.cache.set('request', url, data, 'str')
+ logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
- return message
+ return json.loads(data)
except JSONDecodeError:
logging.verbose('Failed to parse JSON')
return None
- except ClientSSLError:
+ except ClientSSLError as e:
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
+ logging.warning(str(e))
- except (AsyncTimeoutError, ClientConnectionError):
+ except (AsyncTimeoutError, ClientConnectionError) as e:
logging.verbose('Failed to connect to %s', urlparse(url).netloc)
+ logging.warning(str(e))
except Exception:
traceback.print_exc()
@@ -155,39 +151,74 @@ class HttpClient:
return None
- async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
- await self.open()
+ async def get(self,
+ url: str,
+ sign_headers: bool,
+ cls: type[T],
+ force: bool = False) -> T | None:
- ## Using the old algo by default is probably a better idea right now
- # pylint: disable=consider-ternary-expression
+ if not issubclass(cls, JsonBase):
+ raise TypeError('cls must be a sub-class of "aputils.JsonBase"')
+
+ if (data := (await self._get(url, sign_headers, force))) is None:
+ return None
+
+ return cls.parse(data)
+
+
+ async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
+ if not self._session:
+ raise RuntimeError('Client not open')
+
+ # akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance and instance['software'] in {'mastodon'}:
- algorithm = 'hs2019'
+ algorithm = AlgorithmType.HS2019
else:
- algorithm = 'original'
- # pylint: enable=consider-ternary-expression
+ algorithm = AlgorithmType.RSASHA256
- headers = {'Content-Type': 'application/activity+json'}
- headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
+ body: bytes
+ message: Message
+
+ if isinstance(data, bytes):
+ body = data
+ message = Message.parse(data)
+
+ else:
+ body = data.to_json().encode("utf-8")
+ message = data
+
+ mtype = message.type.value if isinstance(message.type, ObjectType) else message.type
+ headers = self.signer.sign_headers(
+ 'POST',
+ url,
+ body,
+ headers = {'Content-Type': 'application/activity+json'},
+ algorithm = algorithm
+ )
try:
- logging.verbose('Sending "%s" to %s', message.type, url)
+ logging.verbose('Sending "%s" to %s', mtype, url)
- async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
+ async with self._session.post(url, headers = headers, data = body) as resp:
# Not expecting a response, so just return
if resp.status in {200, 202}:
- logging.verbose('Successfully sent "%s" to %s', message.type, url)
+ logging.verbose('Successfully sent "%s" to %s', mtype, url)
return
logging.verbose('Received error when pushing to %s: %i', url, resp.status)
logging.debug(await resp.read())
+ logging.debug("message: %s", body.decode("utf-8"))
+ logging.debug("headers: %s", json.dumps(headers, indent = 4))
return
- except ClientSSLError:
+ except ClientSSLError as e:
logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
+ logging.warning(str(e))
- except (AsyncTimeoutError, ClientConnectionError):
+ except (AsyncTimeoutError, ClientConnectionError) as e:
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
+ logging.warning(str(e))
# prevent workers from being brought down
except Exception:
@@ -198,10 +229,11 @@ class HttpClient:
nodeinfo_url = None
wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo',
- loads = WellKnownNodeinfo.parse
+ False,
+ WellKnownNodeinfo
)
- if not wk_nodeinfo:
+ if wk_nodeinfo is None:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None
@@ -212,14 +244,14 @@ class HttpClient:
except KeyError:
pass
- if not nodeinfo_url:
+ if nodeinfo_url is None:
logging.verbose('Failed to fetch nodeinfo url for %s', domain)
return None
- return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
+ return await self.get(nodeinfo_url, False, Nodeinfo)
-async def get(*args: Any, **kwargs: Any) -> Message | dict | None:
+async def get(*args: Any, **kwargs: Any) -> Any:
async with HttpClient() as client:
return await client.get(*args, **kwargs)
diff --git a/relay/logger.py b/relay/logger.py
index 8aff62d..916fa71 100644
--- a/relay/logger.py
+++ b/relay/logger.py
@@ -11,6 +11,12 @@ if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
+ try:
+ from typing import Self
+
+ except ImportError:
+ from typing_extensions import Self
+
class LogLevel(IntEnum):
DEBUG = logging.DEBUG
@@ -26,7 +32,13 @@ class LogLevel(IntEnum):
@classmethod
- def parse(cls: type[IntEnum], data: object) -> IntEnum:
+ def parse(cls: type[Self], data: Any) -> Self:
+ try:
+ data = int(data)
+
+ except ValueError:
+ pass
+
if isinstance(data, cls):
return data
@@ -57,10 +69,10 @@ def set_level(level: LogLevel | str) -> None:
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
- if not logging.root.isEnabledFor(LogLevel['VERBOSE']):
+ if not logging.root.isEnabledFor(LogLevel.VERBOSE):
return
- logging.log(LogLevel['VERBOSE'], message, *args, **kwargs)
+ logging.log(LogLevel.VERBOSE, message, *args, **kwargs)
debug: Callable = logging.debug
@@ -70,23 +82,27 @@ error: Callable = logging.error
critical: Callable = logging.critical
-env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
-
try:
- env_log_file = Path(os.environ['LOG_FILE']).expanduser().resolve()
+ env_log_file: Path | None = Path(os.environ['LOG_FILE']).expanduser().resolve()
except KeyError:
env_log_file = None
-handlers = [logging.StreamHandler()]
+handlers: list[Any] = [logging.StreamHandler()]
if env_log_file:
handlers.append(logging.FileHandler(env_log_file))
-logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
+if os.environ.get('IS_SYSTEMD'):
+ logging_format = '%(levelname)s: %(message)s'
+
+else:
+ logging_format = '[%(asctime)s] %(levelname)s: %(message)s'
+
+logging.addLevelName(LogLevel.VERBOSE, 'VERBOSE')
logging.basicConfig(
level = LogLevel.INFO,
- format = '[%(asctime)s] %(levelname)s: %(message)s',
+ format = logging_format,
datefmt = '%Y-%m-%d %H:%M:%S',
handlers = handlers
)
diff --git a/relay/manage.py b/relay/manage.py
index 796ec0b..d768284 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -21,19 +21,10 @@ from .database import RELAY_SOFTWARE, get_database
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
if typing.TYPE_CHECKING:
- from tinysql import Row
+ from bsql import Row
from typing import Any
-# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
-
-
-CONFIG_IGNORE = (
- 'schema-version',
- 'private-key'
-)
-
-
def check_alphanumeric(text: str) -> str:
if not text.isalnum():
raise click.BadParameter('String not alphanumeric')
@@ -50,7 +41,7 @@ def cli(ctx: click.Context, config: str | None) -> None:
if not ctx.invoked_subcommand:
if ctx.obj.config.domain.endswith('example.com'):
- cli_setup.callback()
+ cli_setup.callback() # type: ignore
else:
click.echo(
@@ -58,7 +49,7 @@ def cli(ctx: click.Context, config: str | None) -> None:
'future.'
)
- cli_run.callback()
+ cli_run.callback() # type: ignore
@cli.command('setup')
@@ -184,7 +175,7 @@ def cli_setup(ctx: click.Context) -> None:
conn.put_config(key, value)
if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
- cli_run.callback()
+ cli_run.callback() # type: ignore
@cli.command('run')
@@ -257,7 +248,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
- with click.progressbar(
+ with click.progressbar( # type: ignore
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
@@ -281,7 +272,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software']
)
- with click.progressbar(
+ with click.progressbar( # type: ignore
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
@@ -293,7 +284,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None
)
- with click.progressbar(
+ with click.progressbar( # type: ignore
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
@@ -302,7 +293,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software:
conn.put_domain_ban(domain)
- with click.progressbar(
+ with click.progressbar( # type: ignore
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
@@ -339,10 +330,17 @@ def cli_config_list(ctx: click.Context) -> None:
click.echo('Relay Config:')
with ctx.obj.database.session() as conn:
- for key, value in conn.get_config_all().items():
- if key not in CONFIG_IGNORE:
- key = f'{key}:'.ljust(20)
- click.echo(f'- {key} {value}')
+ config = conn.get_config_all()
+
+ for key, value in config.to_dict().items():
+ if key in type(config).SYSTEM_KEYS():
+ continue
+
+ if key == 'log-level':
+ value = value.name
+
+ key_str = f'{key}:'.ljust(20)
+ click.echo(f'- {key_str} {repr(value)}')
@cli_config.command('set')
@@ -477,7 +475,7 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:')
with ctx.obj.database.session() as conn:
- for inbox in conn.execute('SELECT * FROM inboxes'):
+ for inbox in conn.get_inboxes():
click.echo(f'- {inbox["inbox"]}')
@@ -520,7 +518,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)'
- inbox_data: Row = None
+ inbox_data: Row | None = None
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor):
@@ -540,6 +538,11 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
actor = f'https://{actor}/actor'
actor_data = asyncio.run(http.get(actor, sign_headers = True))
+
+ if not actor_data:
+ click.echo("Failed to fetch actor")
+ return
+
inbox = actor_data.shared_inbox
message = Message.new_unfollow(
host = ctx.obj.config.domain,
@@ -618,6 +621,80 @@ def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
click.echo(f'Removed inbox from the database: {inbox}')
+@cli.group('request')
+def cli_request() -> None:
+ 'Manage follow requests'
+
+
+@cli_request.command('list')
+@click.pass_context
+def cli_request_list(ctx: click.Context) -> None:
+ 'List all current follow requests'
+
+ click.echo('Follow requests:')
+
+ with ctx.obj.database.session() as conn:
+ for instance in conn.get_requests():
+ date = instance['created'].strftime('%Y-%m-%d')
+ click.echo(f'- [{date}] {instance["domain"]}')
+
+
+@cli_request.command('accept')
+@click.argument('domain')
+@click.pass_context
+def cli_request_accept(ctx: click.Context, domain: str) -> None:
+ 'Accept a follow request'
+
+ try:
+ with ctx.obj.database.session() as conn:
+ instance = conn.put_request_response(domain, True)
+
+ except KeyError:
+ click.echo('Request not found')
+ return
+
+ message = Message.new_response(
+ host = ctx.obj.config.domain,
+ actor = instance['actor'],
+ followid = instance['followid'],
+ accept = True
+ )
+
+ asyncio.run(http.post(instance['inbox'], message, instance))
+
+ if instance['software'] != 'mastodon':
+ message = Message.new_follow(
+ host = ctx.obj.config.domain,
+ actor = instance['actor']
+ )
+
+ asyncio.run(http.post(instance['inbox'], message, instance))
+
+
+@cli_request.command('deny')
+@click.argument('domain')
+@click.pass_context
+def cli_request_deny(ctx: click.Context, domain: str) -> None:
+ 'Accept a follow request'
+
+ try:
+ with ctx.obj.database.session() as conn:
+ instance = conn.put_request_response(domain, False)
+
+ except KeyError:
+ click.echo('Request not found')
+ return
+
+ response = Message.new_response(
+ host = ctx.obj.config.domain,
+ actor = instance['actor'],
+ followid = instance['followid'],
+ accept = False
+ )
+
+ asyncio.run(http.post(instance['inbox'], response, instance))
+
+
@cli.group('instance')
def cli_instance() -> None:
'Manage instance bans'
@@ -893,7 +970,6 @@ def cli_whitelist_import(ctx: click.Context) -> None:
def main() -> None:
- # pylint: disable=no-value-for-parameter
cli(prog_name='relay')
diff --git a/relay/misc.py b/relay/misc.py
index 33e7a06..82b1fd2 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -8,27 +8,44 @@ import typing
from aiohttp.web import Response as AiohttpResponse
from datetime import datetime
+from pathlib import Path
from uuid import uuid4
try:
from importlib.resources import files as pkgfiles
except ImportError:
- from importlib_resources import files as pkgfiles
+ from importlib_resources import files as pkgfiles # type: ignore
if typing.TYPE_CHECKING:
- from pathlib import Path
from typing import Any
from .application import Application
+ try:
+ from typing import Self
+
+ except ImportError:
+ from typing_extensions import Self
+
+
+T = typing.TypeVar('T')
+ResponseType = typing.TypedDict('ResponseType', {
+ 'status': int,
+ 'headers': dict[str, typing.Any] | None,
+ 'content_type': str,
+ 'body': bytes | None,
+ 'text': str | None
+})
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
+
MIMETYPES = {
'activity': 'application/activity+json',
'css': 'text/css',
'html': 'text/html',
'json': 'application/json',
- 'text': 'text/plain'
+ 'text': 'text/plain',
+ 'webmanifest': 'application/manifest+json'
}
NODEINFO_NS = {
@@ -92,7 +109,7 @@ def check_open_port(host: str, port: int) -> bool:
def get_app() -> Application:
- from .application import Application # pylint: disable=import-outside-toplevel
+ from .application import Application
if not Application.DEFAULT:
raise ValueError('No default application set')
@@ -101,7 +118,7 @@ def get_app() -> Application:
def get_resource(path: str) -> Path:
- return pkgfiles('relay').joinpath(path)
+ return Path(str(pkgfiles('relay'))).joinpath(path)
class JsonEncoder(json.JSONEncoder):
@@ -114,18 +131,18 @@ class JsonEncoder(json.JSONEncoder):
class Message(aputils.Message):
@classmethod
- def new_actor(cls: type[Message], # pylint: disable=arguments-differ
+ def new_actor(cls: type[Self], # type: ignore
host: str,
pubkey: str,
- description: str | None = None) -> Message:
+ description: str | None = None,
+ approves: bool = False) -> Self:
- return cls({
- '@context': 'https://www.w3.org/ns/activitystreams',
+ return cls.new(aputils.ObjectType.APPLICATION, {
'id': f'https://{host}/actor',
- 'type': 'Application',
'preferredUsername': 'relay',
'name': 'ActivityRelay',
'summary': description or 'ActivityRelay bot',
+ 'manuallyApprovesFollowers': approves,
'followers': f'https://{host}/followers',
'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox',
@@ -142,11 +159,9 @@ class Message(aputils.Message):
@classmethod
- def new_announce(cls: type[Message], host: str, obj: str) -> Message:
- return cls({
- '@context': 'https://www.w3.org/ns/activitystreams',
+ def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self:
+ return cls.new(aputils.ObjectType.ANNOUNCE, {
'id': f'https://{host}/activities/{uuid4()}',
- 'type': 'Announce',
'to': [f'https://{host}/followers'],
'actor': f'https://{host}/actor',
'object': obj
@@ -154,23 +169,19 @@ class Message(aputils.Message):
@classmethod
- def new_follow(cls: type[Message], host: str, actor: str) -> Message:
- return cls({
- '@context': 'https://www.w3.org/ns/activitystreams',
- 'type': 'Follow',
+ def new_follow(cls: type[Self], host: str, actor: str) -> Self:
+ return cls.new(aputils.ObjectType.FOLLOW, {
+ 'id': f'https://{host}/activities/{uuid4()}',
'to': [actor],
'object': actor,
- 'id': f'https://{host}/activities/{uuid4()}',
'actor': f'https://{host}/actor'
})
@classmethod
- def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message:
- return cls({
- '@context': 'https://www.w3.org/ns/activitystreams',
+ def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self:
+ return cls.new(aputils.ObjectType.UNDO, {
'id': f'https://{host}/activities/{uuid4()}',
- 'type': 'Undo',
'to': [actor],
'actor': f'https://{host}/actor',
'object': follow
@@ -178,16 +189,9 @@ class Message(aputils.Message):
@classmethod
- def new_response(cls: type[Message],
- host: str,
- actor: str,
- followid: str,
- accept: bool) -> Message:
-
- return cls({
- '@context': 'https://www.w3.org/ns/activitystreams',
+ def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self:
+ return cls.new(aputils.ObjectType.ACCEPT if accept else aputils.ObjectType.REJECT, {
'id': f'https://{host}/activities/{uuid4()}',
- 'type': 'Accept' if accept else 'Reject',
'to': [actor],
'actor': f'https://{host}/actor',
'object': {
@@ -206,16 +210,18 @@ class Response(AiohttpResponse):
@classmethod
- def new(cls: type[Response],
- body: str | bytes | dict = '',
+ def new(cls: type[Self],
+ body: str | bytes | dict | tuple | list | set = '',
status: int = 200,
headers: dict[str, str] | None = None,
- ctype: str = 'text') -> Response:
+ ctype: str = 'text') -> Self:
- kwargs = {
+ kwargs: ResponseType = {
'status': status,
'headers': headers,
- 'content_type': MIMETYPES[ctype]
+ 'content_type': MIMETYPES[ctype],
+ 'body': None,
+ 'text': None
}
if isinstance(body, bytes):
@@ -231,10 +237,10 @@ class Response(AiohttpResponse):
@classmethod
- def new_error(cls: type[Response],
+ def new_error(cls: type[Self],
status: int,
body: str | bytes | dict,
- ctype: str = 'text') -> Response:
+ ctype: str = 'text') -> Self:
if ctype == 'json':
body = {'error': body}
@@ -243,14 +249,14 @@ class Response(AiohttpResponse):
@classmethod
- def new_redir(cls: type[Response], path: str) -> Response:
+ def new_redir(cls: type[Self], path: str) -> Self:
body = f'Redirect to
{path}'
return cls.new(body, 302, {'Location': path})
@property
def location(self) -> str:
- return self.headers.get('Location')
+ return self.headers.get('Location', '')
@location.setter
diff --git a/relay/processors.py b/relay/processors.py
index 824a975..910ecf3 100644
--- a/relay/processors.py
+++ b/relay/processors.py
@@ -7,10 +7,10 @@ from .database import Connection
from .misc import Message
if typing.TYPE_CHECKING:
- from .views import ActorView
+ from .views.activitypub import ActorView
-def person_check(actor: str, software: str) -> bool:
+def person_check(actor: Message, software: str | None) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
@@ -35,8 +35,8 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
message = Message.new_announce(view.config.domain, view.message.object_id)
logging.debug('>> relay: %s', message)
- for inbox in conn.distill_inboxes(view.message):
- view.app.push_message(inbox, message, view.instance)
+ for instance in conn.distill_inboxes(view.message):
+ view.app.push_message(instance["inbox"], message, instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
@@ -53,8 +53,8 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
message = Message.new_announce(view.config.domain, view.message)
logging.debug('>> forward: %s', message)
- for inbox in conn.distill_inboxes(view.message):
- view.app.push_message(inbox, message, view.instance)
+ for instance in conn.distill_inboxes(view.message):
+ view.app.push_message(instance["inbox"], await view.request.read(), instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str')
@@ -62,9 +62,12 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
+ config = conn.get_config_all()
# reject if software used by actor is banned
- if conn.get_software_ban(software):
+ if software and conn.get_software_ban(software):
+ logging.verbose('Rejected banned actor: %s', view.actor.id)
+
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
@@ -72,7 +75,8 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
actor = view.actor.id,
followid = view.message.id,
accept = False
- )
+ ),
+ view.instance
)
logging.verbose(
@@ -83,8 +87,10 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
return
- ## reject if the actor is not an instance actor
+ # reject if the actor is not an instance actor
if person_check(view.actor, software):
+ logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
+
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
@@ -92,23 +98,54 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
actor = view.actor.id,
followid = view.message.id,
accept = False
- )
+ ),
+ view.instance
)
- logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
- with conn.transaction():
- if conn.get_inbox(view.actor.shared_inbox):
- view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
+ if not conn.get_domain_whitelist(view.actor.domain):
+ # add request if approval-required is enabled
+ if config.approval_required:
+ logging.verbose('New follow request fromm actor: %s', view.actor.id)
- else:
- view.instance = conn.put_inbox(
- view.actor.domain,
+ with conn.transaction():
+ view.instance = conn.put_inbox(
+ domain = view.actor.domain,
+ inbox = view.actor.shared_inbox,
+ actor = view.actor.id,
+ followid = view.message.id,
+ software = software,
+ accepted = False
+ )
+
+ return
+
+ # reject if the actor isn't whitelisted while the whiltelist is enabled
+ if config.whitelist_enabled:
+ logging.verbose('Rejected actor for not being in the whitelist: %s', view.actor.id)
+
+ view.app.push_message(
view.actor.shared_inbox,
- view.actor.id,
- view.message.id,
- software
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = False
+ ),
+ view.instance
+ )
+
+ return
+
+ with conn.transaction():
+ view.instance = conn.put_inbox(
+ domain = view.actor.domain,
+ inbox = view.actor.shared_inbox,
+ actor = view.actor.id,
+ followid = view.message.id,
+ software = software,
+ accepted = True
)
view.app.push_message(
@@ -136,7 +173,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
async def handle_undo(view: ActorView, conn: Connection) -> None:
- ## If the object is not a Follow, forward it
+ # If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
await handle_forward(view, conn)
return
@@ -150,7 +187,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
- view.message.object['id']
+ view.message.object_id
)
view.app.push_message(
@@ -189,15 +226,15 @@ async def run_processor(view: ActorView) -> None:
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with conn.transaction():
- view.instance = conn.update_inbox(
- view.instance['inbox'],
+ view.instance = conn.put_inbox(
+ domain = view.instance['domain'],
software = nodeinfo.sw_name
)
if not view.instance['actor']:
with conn.transaction():
- view.instance = conn.update_inbox(
- view.instance['inbox'],
+ view.instance = conn.put_inbox(
+ domain = view.instance['domain'],
actor = view.actor.id
)
diff --git a/relay/template.py b/relay/template.py
index 64738e0..1335fab 100644
--- a/relay/template.py
+++ b/relay/template.py
@@ -1,15 +1,22 @@
from __future__ import annotations
+import textwrap
import typing
-from hamlish_jinja.extension import HamlishExtension
+from collections.abc import Callable
+from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader
+from jinja2.ext import Extension
+from jinja2.nodes import CallBlock
+from markdown import Markdown
+
from . import __version__
-from .database.config import THEMES
from .misc import get_resource
if typing.TYPE_CHECKING:
+ from jinja2.nodes import Node
+ from jinja2.parser import Parser
from typing import Any
from .application import Application
from .views.base import View
@@ -22,7 +29,8 @@ class Template(Environment):
trim_blocks = True,
lstrip_blocks = True,
extensions = [
- HamlishExtension
+ HamlishExtension,
+ MarkdownExtension
],
loader = FileSystemLoader([
get_resource('frontend'),
@@ -36,16 +44,52 @@ class Template(Environment):
def render(self, path: str, view: View | None = None, **context: Any) -> str:
- with self.app.database.session(False) as s:
- config = s.get_config_all()
+ with self.app.database.session(False) as conn:
+ config = conn.get_config_all()
new_context = {
'view': view,
'domain': self.app.config.domain,
'version': __version__,
'config': config,
- 'theme_name': config['theme'] or 'Default',
**(context or {})
}
return self.get_template(path).render(new_context)
+
+
+ def render_markdown(self, text: str) -> str:
+ return self._render_markdown(text) # type: ignore
+
+
+class MarkdownExtension(Extension):
+ tags = {'markdown'}
+ extensions = (
+ 'attr_list',
+ 'smarty',
+ 'tables'
+ )
+
+
+ def __init__(self, environment: Environment):
+ Extension.__init__(self, environment)
+ self._markdown = Markdown(extensions = MarkdownExtension.extensions)
+ environment.extend(
+ _render_markdown = self._render_markdown
+ )
+
+
+ def parse(self, parser: Parser) -> Node | list[Node]:
+ lineno = next(parser.stream).lineno
+ body = parser.parse_statements(
+ ('name:endmarkdown',),
+ drop_needle = True
+ )
+
+ output = CallBlock(self.call_method('_render_markdown'), [], [], body)
+ return output.set_lineno(lineno)
+
+
+ def _render_markdown(self, caller: Callable[[], str] | str) -> str:
+ text = caller if isinstance(caller, str) else caller()
+ return self._markdown.convert(textwrap.dedent(text.strip('\n')))
diff --git a/relay/views/__init__.py b/relay/views/__init__.py
index 6366592..25a7a62 100644
--- a/relay/views/__init__.py
+++ b/relay/views/__init__.py
@@ -1,4 +1,4 @@
from __future__ import annotations
from . import activitypub, api, frontend, misc
-from .base import VIEWS
+from .base import VIEWS, View
diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py
index 31266f6..f2eff48 100644
--- a/relay/views/activitypub.py
+++ b/relay/views/activitypub.py
@@ -12,27 +12,31 @@ from ..processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
- from tinysql import Row
+ from bsql import Row
-# pylint: disable=unused-argument
-
@register_route('/actor', '/inbox')
class ActorView(View):
+ signature: aputils.Signature
+ message: Message
+ actor: Message
+ instancce: Row
+ signer: aputils.Signer
+
+
def __init__(self, request: Request):
View.__init__(self, request)
- self.signature: aputils.Signature = None
- self.message: Message = None
- self.actor: Message = None
- self.instance: Row = None
- self.signer: aputils.Signer = None
-
async def get(self, request: Request) -> Response:
+ with self.database.session(False) as conn:
+ config = conn.get_config_all()
+
data = Message.new_actor(
host = self.config.domain,
- pubkey = self.app.signer.pubkey
+ pubkey = self.app.signer.pubkey,
+ description = self.app.template.render_markdown(config.note),
+ approves = config.approval_required
)
return Response.new(data, ctype='activity')
@@ -44,19 +48,13 @@ class ActorView(View):
with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox)
- config = conn.get_config_all()
- ## reject if the actor isn't whitelisted while the whiltelist is enabled
- if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
- logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
- return Response.new_error(403, 'access denied', 'json')
-
- ## reject if actor is banned
+ # reject if actor is banned
if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
- ## reject if activity type isn't 'Follow' and the actor isn't following
+ # reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance:
logging.verbose(
'Rejected actor for trying to post while not following: %s',
@@ -73,35 +71,33 @@ class ActorView(View):
async def get_post_data(self) -> Response | None:
try:
- self.signature = aputils.Signature.new_from_signature(self.request.headers['signature'])
+ self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
- self.message = await self.request.json(loads = Message.parse)
+ message: Message | None = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
- if self.message is None:
+ if message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
+ self.message = message
+
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
- self.actor = await self.client.get(
- self.signature.keyid,
- sign_headers = True,
- loads = Message.parse
- )
+ actor: Message | None = await self.client.get(self.signature.keyid, True, Message)
- if not self.actor:
+ if actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
@@ -110,6 +106,8 @@ class ActorView(View):
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
+ self.actor = actor
+
try:
self.signer = self.actor.signer
@@ -118,42 +116,13 @@ class ActorView(View):
return Response.new_error(400, 'actor missing public key', 'json')
try:
- self.validate_signature(await self.request.read())
+ await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
-
- def validate_signature(self, body: bytes) -> None:
- headers = {key.lower(): value for key, value in self.request.headers.items()}
- headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
-
- if (digest := aputils.Digest.new_from_digest(headers.get("digest"))):
- if not body:
- raise aputils.SignatureFailureError("Missing body for digest verification")
-
- if not digest.validate(body):
- raise aputils.SignatureFailureError("Body digest does not match")
-
- if self.signature.algorithm_type == "hs2019":
- if "(created)" not in self.signature.headers:
- raise aputils.SignatureFailureError("'(created)' header not used")
-
- current_timestamp = aputils.HttpDate.new_utc().timestamp()
-
- if self.signature.created > current_timestamp:
- raise aputils.SignatureFailureError("Creation date after current date")
-
- if current_timestamp > self.signature.expires:
- raise aputils.SignatureFailureError("Expiration date before current date")
-
- headers["(created)"] = self.signature.created
- headers["(expires)"] = self.signature.expires
-
- # pylint: disable=protected-access
- if not self.signer._validate_signature(headers, self.signature):
- raise aputils.SignatureFailureError("Signature does not match")
+ return None
@register_route('/.well-known/webfinger')
diff --git a/relay/views/api.py b/relay/views/api.py
index 5a32cac..04b9af8 100644
--- a/relay/views/api.py
+++ b/relay/views/api.py
@@ -9,23 +9,22 @@ from urllib.parse import urlparse
from .base import View, register_route
from .. import __version__
-from .. import logger as logging
-from ..database.config import CONFIG_DEFAULTS
-from ..misc import Message, Response
+from ..database import ConfigData
+from ..misc import Message, Response, boolean, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
- from collections.abc import Coroutine
+ from collections.abc import Callable, Sequence
+ from typing import Any
-CONFIG_IGNORE = (
- 'schema-version',
- 'private-key'
-)
+ALLOWED_HEADERS = {
+ 'accept',
+ 'authorization',
+ 'content-type'
+}
-CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE}
-
-PUBLIC_API_PATHS: tuple[tuple[str, str]] = (
+PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'),
('GET', '/api/v1/instance'),
('POST', '/api/v1/token')
@@ -40,28 +39,36 @@ def check_api_path(method: str, path: str) -> bool:
@web.middleware
-async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Response:
+async def handle_api_path(request: Request, handler: Callable) -> Response:
try:
- request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
+ if (token := request.cookies.get('user-token')):
+ request['token'] = token
- with request.app.database.session() as conn:
+ else:
+ request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
+
+ with get_app().database.session() as conn:
request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError):
request['token'] = None
request['user'] = None
- if check_api_path(request.method, request.path):
+ if request.method != "OPTIONS" and check_api_path(request.method, request.path):
if not request['token']:
return Response.new_error(401, 'Missing token', 'json')
if not request['user']:
return Response.new_error(401, 'Invalid token', 'json')
- return await handler(request)
+ response = await handler(request)
+ if request.path.startswith('/api'):
+ response.headers['Access-Control-Allow-Origin'] = '*'
+ response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
+
+ return response
-# pylint: disable=no-self-use,unused-argument
@register_route('/api/v1/token')
class Login(View):
@@ -87,7 +94,19 @@ class Login(View):
token = conn.put_token(data['username'])
- return Response.new({'token': token['code']}, ctype = 'json')
+ resp = Response.new({'token': token['code']}, ctype = 'json')
+ resp.set_cookie(
+ 'user-token',
+ token['code'],
+ max_age = 60 * 60 * 24 * 365,
+ domain = self.config.domain,
+ path = '/',
+ secure = True,
+ httponly = False,
+ samesite = 'lax'
+ )
+
+ return resp
async def delete(self, request: Request) -> Response:
@@ -102,14 +121,14 @@ class RelayInfo(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
config = conn.get_config_all()
- inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')]
+ inboxes = [row['domain'] for row in conn.get_inboxes()]
data = {
'domain': self.config.domain,
- 'name': config['name'],
- 'description': config['note'],
+ 'name': config.name,
+ 'description': config.note,
'version': __version__,
- 'whitelist_enabled': config['whitelist-enabled'],
+ 'whitelist_enabled': config.whitelist_enabled,
'email': None,
'admin': None,
'icon': None,
@@ -122,12 +141,17 @@ class RelayInfo(View):
@register_route('/api/v1/config')
class Config(View):
async def get(self, request: Request) -> Response:
- with self.database.session() as conn:
- data = conn.get_config_all()
- data['log-level'] = data['log-level'].name
+ data = {}
- for key in CONFIG_IGNORE:
- del data[key]
+ with self.database.session() as conn:
+ for key, value in conn.get_config_all().to_dict().items():
+ if key in ConfigData.SYSTEM_KEYS():
+ continue
+
+ if key == 'log-level':
+ value = value.name
+
+ data[key] = value
return Response.new(data, ctype = 'json')
@@ -138,7 +162,9 @@ class Config(View):
if isinstance(data, Response):
return data
- if data['key'] not in CONFIG_VALID:
+ data['key'] = data['key'].replace('-', '_')
+
+ if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn:
@@ -153,11 +179,11 @@ class Config(View):
if isinstance(data, Response):
return data
- if data['key'] not in CONFIG_VALID:
+ if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn:
- conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1])
+ conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
return Response.new({'message': 'Updated config'}, ctype = 'json')
@@ -166,7 +192,7 @@ class Config(View):
class Inbox(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
- data = tuple(conn.execute('SELECT * FROM inboxes').all())
+ data = conn.get_inboxes()
return Response.new(data, ctype = 'json')
@@ -184,19 +210,19 @@ class Inbox(View):
return Response.new_error(404, 'Instance already in database', 'json')
if not data.get('inbox'):
- try:
- actor_data = await self.client.get(
- data['actor'],
- sign_headers = True,
- loads = Message.parse
- )
+ actor_data: Message | None = await self.client.get(data['actor'], True, Message)
- data['inbox'] = actor_data.shared_inbox
-
- except Exception as e:
- logging.error('Failed to fetch actor: %s', str(e))
+ if actor_data is None:
return Response.new_error(500, 'Failed to fetch actor', 'json')
+ data['inbox'] = actor_data.shared_inbox
+
+ if not data.get('software'):
+ nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
+
+ if nodeinfo is not None:
+ data['software'] = nodeinfo.sw_name
+
row = conn.put_inbox(**data)
return Response.new(row, ctype = 'json')
@@ -212,12 +238,12 @@ class Inbox(View):
if not (instance := conn.get_inbox(data['domain'])):
return Response.new_error(404, 'Instance with domain not found', 'json')
- instance = conn.update_inbox(instance['inbox'], **data)
+ instance = conn.put_inbox(instance['domain'], **data)
return Response.new(instance, ctype = 'json')
- async def delete(self, request: Request, domain: str) -> Response:
+ async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
@@ -232,6 +258,47 @@ class Inbox(View):
return Response.new({'message': 'Deleted instance'}, ctype = 'json')
+@register_route('/api/v1/request')
+class RequestView(View):
+ async def get(self, request: Request) -> Response:
+ with self.database.session() as conn:
+ instances = conn.get_requests()
+
+ return Response.new(instances, ctype = 'json')
+
+
+ async def post(self, request: Request) -> Response:
+ data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
+ data['accept'] = boolean(data['accept'])
+
+ try:
+ with self.database.session(True) as conn:
+ instance = conn.put_request_response(data['domain'], data['accept'])
+
+ except KeyError:
+ return Response.new_error(404, 'Request not found', 'json')
+
+ message = Message.new_response(
+ host = self.config.domain,
+ actor = instance['actor'],
+ followid = instance['followid'],
+ accept = data['accept']
+ )
+
+ self.app.push_message(instance['inbox'], message, instance)
+
+ if data['accept'] and instance['software'] != 'mastodon':
+ message = Message.new_follow(
+ host = self.config.domain,
+ actor = instance['actor']
+ )
+
+ self.app.push_message(instance['inbox'], message, instance)
+
+ resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
+ return Response.new(resp_message, ctype = 'json')
+
+
@register_route('/api/v1/domain_ban')
class DomainBan(View):
async def get(self, request: Request) -> Response:
@@ -269,7 +336,7 @@ class DomainBan(View):
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
- ban = conn.update_domain_ban(data['domain'], **data)
+ ban = conn.update_domain_ban(**data)
return Response.new(ban, ctype = 'json')
@@ -326,7 +393,7 @@ class SoftwareBan(View):
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
- ban = conn.update_software_ban(data['name'], **data)
+ ban = conn.update_software_ban(**data)
return Response.new(ban, ctype = 'json')
@@ -346,6 +413,63 @@ class SoftwareBan(View):
return Response.new({'message': 'Unbanned software'}, ctype = 'json')
+@register_route('/api/v1/user')
+class User(View):
+ async def get(self, request: Request) -> Response:
+ with self.database.session() as conn:
+ items = []
+
+ for row in conn.execute('SELECT * FROM users'):
+ del row['hash']
+ items.append(row)
+
+ return Response.new(items, ctype = 'json')
+
+
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(['username', 'password'], ['handle'])
+
+ if isinstance(data, Response):
+ return data
+
+ with self.database.session() as conn:
+ if conn.get_user(data['username']):
+ return Response.new_error(404, 'User already exists', 'json')
+
+ user = conn.put_user(**data)
+ del user['hash']
+
+ return Response.new(user, ctype = 'json')
+
+
+ async def patch(self, request: Request) -> Response:
+ data = await self.get_api_data(['username'], ['password', 'handle'])
+
+ if isinstance(data, Response):
+ return data
+
+ with self.database.session(True) as conn:
+ user = conn.put_user(**data)
+ del user['hash']
+
+ return Response.new(user, ctype = 'json')
+
+
+ async def delete(self, request: Request) -> Response:
+ data = await self.get_api_data(['username'], [])
+
+ if isinstance(data, Response):
+ return data
+
+ with self.database.session(True) as conn:
+ if not conn.get_user(data['username']):
+ return Response.new_error(404, 'User does not exist', 'json')
+
+ conn.del_user(data['username'])
+
+ return Response.new({'message': 'Deleted user'}, ctype = 'json')
+
+
@register_route('/api/v1/whitelist')
class Whitelist(View):
async def get(self, request: Request) -> Response:
diff --git a/relay/views/base.py b/relay/views/base.py
index f568525..93b3e3b 100644
--- a/relay/views/base.py
+++ b/relay/views/base.py
@@ -2,40 +2,52 @@ from __future__ import annotations
import typing
+from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import HTTPMethodNotAllowed
+from base64 import b64encode
from functools import cached_property
from json.decoder import JSONDecodeError
-from ..misc import Response
+from ..misc import Response, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
- from collections.abc import Callable, Coroutine, Generator
+ from collections.abc import Callable, Generator, Sequence, Mapping
from bsql import Database
- from typing import Any, Self
+ from typing import Any
from ..application import Application
from ..cache import Cache
from ..config import Config
from ..http_client import HttpClient
from ..template import Template
+ try:
+ from typing import Self
-VIEWS = []
+ except ImportError:
+ from typing_extensions import Self
+
+
+VIEWS: list[tuple[str, type[View]]] = []
+
+
+def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
+ return {key: str(value) for key, value in data.items()}
def register_route(*paths: str) -> Callable:
- def wrapper(view: View) -> View:
+ def wrapper(view: type[View]) -> type[View]:
for path in paths:
- VIEWS.append([path, view])
+ VIEWS.append((path, view))
return view
return wrapper
class View(AbstractView):
- def __await__(self) -> Generator[Response]:
+ def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
@@ -46,22 +58,27 @@ class View(AbstractView):
@classmethod
- async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Self:
+ async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Response:
view = cls(request)
return await view.handlers[method](request, **kwargs)
- async def _run_handler(self, handler: Coroutine, **kwargs: Any) -> Response:
+ async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response:
+ self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs)
+ async def options(self, request: Request) -> Response:
+ return Response.new()
+
+
@cached_property
- def allowed_methods(self) -> tuple[str]:
+ def allowed_methods(self) -> Sequence[str]:
return tuple(self.handlers.keys())
@cached_property
- def handlers(self) -> dict[str, Coroutine]:
+ def handlers(self) -> dict[str, Callable[..., Any]]:
data = {}
for method in METHODS:
@@ -74,10 +91,9 @@ class View(AbstractView):
return data
- # app components
@property
def app(self) -> Application:
- return self.request.app
+ return get_app()
@property
@@ -110,17 +126,17 @@ class View(AbstractView):
optional: list[str]) -> dict[str, str] | Response:
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
- post_data = await self.request.post()
+ post_data = convert_data(await self.request.post())
elif self.request.content_type == 'application/json':
try:
- post_data = await self.request.json()
+ post_data = convert_data(await self.request.json())
except JSONDecodeError:
return Response.new_error(400, 'Invalid JSON data', 'json')
else:
- post_data = self.request.query
+ post_data = convert_data(self.request.query)
data = {}
@@ -132,6 +148,6 @@ class View(AbstractView):
return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
for key in optional:
- data[key] = post_data.get(key)
+ data[key] = post_data.get(key, '')
return data
diff --git a/relay/views/frontend.py b/relay/views/frontend.py
index bd63417..2b5bec0 100644
--- a/relay/views/frontend.py
+++ b/relay/views/frontend.py
@@ -3,60 +3,59 @@ from __future__ import annotations
import typing
from aiohttp import web
-from argon2.exceptions import VerifyMismatchError
-from urllib.parse import urlparse
from .base import View, register_route
-from ..database import CONFIG_DEFAULTS, THEMES
+from ..database import THEMES
from ..logger import LogLevel
-from ..misc import ACTOR_FORMATS, Message, Response
+from ..misc import Response, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
- from collections.abc import Coroutine
+ from collections.abc import Callable
+ from typing import Any
-# pylint: disable=no-self-use
-
UNAUTH_ROUTES = {
'/',
'/login'
}
-CONFIG_IGNORE = (
- 'schema-version',
- 'private-key'
-)
-
@web.middleware
-async def handle_frontend_path(request: web.Request, handler: Coroutine) -> Response:
+async def handle_frontend_path(request: web.Request, handler: Callable) -> Response:
+ app = get_app()
+
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
request['token'] = request.cookies.get('user-token')
request['user'] = None
if request['token']:
- with request.app.database.session(False) as conn:
+ with app.database.session(False) as conn:
request['user'] = conn.get_user_by_token(request['token'])
if request['user'] and request.path == '/login':
return Response.new('', 302, {'Location': '/'})
if not request['user'] and request.path.startswith('/admin'):
- return Response.new('', 302, {'Location': f'/login?redir={request.path}'})
+ response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
+ response.del_cookie('user-token')
+ return response
- return await handler(request)
+ response = await handler(request)
+ if not request.path.startswith('/api') and not request['user'] and request['token']:
+ response.del_cookie('user-token')
+
+ return response
-# pylint: disable=unused-argument
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
- context = {
- 'instances': tuple(conn.execute('SELECT * FROM inboxes').all())
+ context: dict[str, Any] = {
+ 'instances': tuple(conn.get_inboxes())
}
data = self.template.render('page/home.haml', self, **context)
@@ -70,47 +69,6 @@ class Login(View):
return Response.new(data, ctype = 'html')
- async def post(self, request: Request) -> Response:
- form = await request.post()
- params = {}
-
- with self.database.session(True) as conn:
- if not (user := conn.get_user(form['username'])):
- params = {
- 'username': form['username'],
- 'error': 'User not found'
- }
-
- else:
- try:
- conn.hasher.verify(user['hash'], form['password'])
-
- except VerifyMismatchError:
- params = {
- 'username': form['username'],
- 'error': 'Invalid password'
- }
-
- if params:
- data = self.template.render('page/login.haml', self, **params)
- return Response.new(data, ctype = 'html')
-
- token = conn.put_token(user['username'])
- resp = Response.new_redir(request.query.getone('redir', '/'))
- resp.set_cookie(
- 'user-token',
- token['code'],
- max_age = 60 * 60 * 24 * 365,
- domain = self.config.domain,
- path = '/',
- secure = True,
- httponly = True,
- samesite = 'Strict'
- )
-
- return resp
-
-
@register_route('/logout')
class Logout(View):
async def get(self, request: Request) -> Response:
@@ -136,8 +94,9 @@ class AdminInstances(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
- context = {
- 'instances': tuple(conn.execute('SELECT * FROM inboxes').all())
+ context: dict[str, Any] = {
+ 'instances': tuple(conn.get_inboxes()),
+ 'requests': tuple(conn.get_requests())
}
if error:
@@ -150,44 +109,6 @@ class AdminInstances(View):
return Response.new(data, ctype = 'html')
- async def post(self, request: Request) -> Response:
- data = await request.post()
-
- if not data.get('actor') and not data.get('domain'):
- return await self.get(request, error = 'Missing actor and/or domain')
-
- if not data.get('domain'):
- data['domain'] = urlparse(data['actor']).netloc
-
- if not data.get('software'):
- nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
- data['software'] = nodeinfo.sw_name
-
- if not data.get('actor') and data['software'] in ACTOR_FORMATS:
- data['actor'] = ACTOR_FORMATS[data['software']].format(domain = data['domain'])
-
- if not data.get('inbox') and data['actor']:
- actor = await self.client.get(data['actor'], sign_headers = True, loads = Message.parse)
- data['inbox'] = actor.shared_inbox
-
- with self.database.session(True) as conn:
- conn.put_inbox(**data)
-
- return await self.get(request, message = "Added new inbox")
-
-
-@register_route('/admin/instances/delete/{domain}')
-class AdminInstancesDelete(View):
- async def get(self, request: Request, domain: str) -> Response:
- with self.database.session() as conn:
- if not conn.get_inbox(domain):
- return await AdminInstances(request).get(request, message = 'Instance not found')
-
- conn.del_inbox(domain)
-
- return await AdminInstances(request).get(request, message = 'Removed instance')
-
-
@register_route('/admin/whitelist')
class AdminWhitelist(View):
async def get(self,
@@ -196,8 +117,8 @@ class AdminWhitelist(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
- context = {
- 'whitelist': tuple(conn.execute('SELECT * FROM whitelist').all())
+ context: dict[str, Any] = {
+ 'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC'))
}
if error:
@@ -210,34 +131,6 @@ class AdminWhitelist(View):
return Response.new(data, ctype = 'html')
- async def post(self, request: Request) -> Response:
- data = await request.post()
-
- if not data['domain']:
- return await self.get(request, error = 'Missing domain')
-
- with self.database.session(True) as conn:
- if conn.get_domain_whitelist(data['domain']):
- return await self.get(request, message = "Domain already in whitelist")
-
- conn.put_domain_whitelist(data['domain'])
-
- return await self.get(request, message = "Added/updated domain ban")
-
-
-@register_route('/admin/whitelist/delete/{domain}')
-class AdminWhitlistDelete(View):
- async def get(self, request: Request, domain: str) -> Response:
- with self.database.session() as conn:
- if not conn.get_domain_whitelist(domain):
- msg = 'Whitelisted domain not found'
- return await AdminWhitelist.run("GET", request, message = msg)
-
- conn.del_domain_whitelist(domain)
-
- return await AdminWhitelist.run("GET", request, message = 'Removed domain from whitelist')
-
-
@register_route('/admin/domain_bans')
class AdminDomainBans(View):
async def get(self,
@@ -246,8 +139,8 @@ class AdminDomainBans(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
- context = {
- 'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC').all())
+ context: dict[str, Any] = {
+ 'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC'))
}
if error:
@@ -260,42 +153,6 @@ class AdminDomainBans(View):
return Response.new(data, ctype = 'html')
- async def post(self, request: Request) -> Response:
- data = await request.post()
-
- if not data['domain']:
- return await self.get(request, error = 'Missing domain')
-
- with self.database.session(True) as conn:
- if conn.get_domain_ban(data['domain']):
- conn.update_domain_ban(
- data['domain'],
- data.get('reason'),
- data.get('note')
- )
-
- else:
- conn.put_domain_ban(
- data['domain'],
- data.get('reason'),
- data.get('note')
- )
-
- return await self.get(request, message = "Added/updated domain ban")
-
-
-@register_route('/admin/domain_bans/delete/{domain}')
-class AdminDomainBansDelete(View):
- async def get(self, request: Request, domain: str) -> Response:
- with self.database.session() as conn:
- if not conn.get_domain_ban(domain):
- return await AdminDomainBans.run("GET", request, message = 'Domain ban not found')
-
- conn.del_domain_ban(domain)
-
- return await AdminDomainBans.run("GET", request, message = 'Unbanned domain')
-
-
@register_route('/admin/software_bans')
class AdminSoftwareBans(View):
async def get(self,
@@ -304,8 +161,8 @@ class AdminSoftwareBans(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
- context = {
- 'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC').all())
+ context: dict[str, Any] = {
+ 'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC'))
}
if error:
@@ -318,42 +175,6 @@ class AdminSoftwareBans(View):
return Response.new(data, ctype = 'html')
- async def post(self, request: Request) -> Response:
- data = await request.post()
-
- if not data['name']:
- return await self.get(request, error = 'Missing name')
-
- with self.database.session(True) as conn:
- if conn.get_software_ban(data['name']):
- conn.update_software_ban(
- data['name'],
- data.get('reason'),
- data.get('note')
- )
-
- else:
- conn.put_software_ban(
- data['name'],
- data.get('reason'),
- data.get('note')
- )
-
- return await self.get(request, message = "Added/updated software ban")
-
-
-@register_route('/admin/software_bans/delete/{name}')
-class AdminSoftwareBansDelete(View):
- async def get(self, request: Request, name: str) -> Response:
- with self.database.session() as conn:
- if not conn.get_software_ban(name):
- return await AdminSoftwareBans.run("GET", request, message = 'Software ban not found')
-
- conn.del_software_ban(name)
-
- return await AdminSoftwareBans.run("GET", request, message = 'Unbanned software')
-
-
@register_route('/admin/users')
class AdminUsers(View):
async def get(self,
@@ -362,8 +183,8 @@ class AdminUsers(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
- context = {
- 'users': tuple(conn.execute('SELECT * FROM users').all())
+ context: dict[str, Any] = {
+ 'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC'))
}
if error:
@@ -376,82 +197,47 @@ class AdminUsers(View):
return Response.new(data, ctype = 'html')
- async def post(self, request: Request) -> Response:
- data = await request.post()
- required_fields = {'username', 'password', 'password2'}
-
- if not all(data.get(field) for field in required_fields):
- return await self.get(request, error = 'Missing username and/or password')
-
- if data['password'] != data['password2']:
- return await self.get(request, error = 'Passwords do not match')
-
- with self.database.session(True) as conn:
- if conn.get_user(data['username']):
- return await self.get(request, message = "User already exists")
-
- conn.put_user(data['username'], data['password'], data['handle'])
-
- return await self.get(request, message = "Added user")
-
-
-@register_route('/admin/users/delete/{name}')
-class AdminUsersDelete(View):
- async def get(self, request: Request, name: str) -> Response:
- with self.database.session() as conn:
- if not conn.get_user(name):
- return await AdminUsers.run("GET", request, message = 'User not found')
-
- conn.del_user(name)
-
- return await AdminUsers.run("GET", request, message = 'User deleted')
-
-
@register_route('/admin/config')
class AdminConfig(View):
async def get(self, request: Request, message: str | None = None) -> Response:
- context = {
+ context: dict[str, Any] = {
'themes': tuple(THEMES.keys()),
- 'LogLevel': LogLevel,
+ 'levels': tuple(level.name for level in LogLevel),
'message': message
}
+
data = self.template.render('page/admin-config.haml', self, **context)
return Response.new(data, ctype = 'html')
- async def post(self, request: Request) -> Response:
- form = dict(await request.post())
-
- with self.database.session(True) as conn:
- for key in CONFIG_DEFAULTS:
- value = form.get(key)
-
- if key == 'whitelist-enabled':
- value = bool(value)
-
- elif key.lower() in CONFIG_IGNORE:
- continue
-
- if value is None:
- continue
-
- conn.put_config(key, value)
-
- return await self.get(request, message = 'Updated config')
-
-
-@register_route('/style.css')
-class StyleCss(View):
+@register_route('/manifest.json')
+class ManifestJson(View):
async def get(self, request: Request) -> Response:
- data = self.template.render('style.css', self)
- return Response.new(data, ctype = 'css')
+ with self.database.session(False) as conn:
+ config = conn.get_config_all()
+ theme = THEMES[config.theme]
+
+ data = {
+ 'background_color': theme['background'],
+ 'categories': ['activitypub'],
+ 'description': 'Message relay for the ActivityPub network',
+ 'display': 'standalone',
+ 'name': config['name'],
+ 'orientation': 'portrait',
+ 'scope': f"https://{self.config.domain}/",
+ 'short_name': 'ActivityRelay',
+ 'start_url': f"https://{self.config.domain}/",
+ 'theme_color': theme['primary']
+ }
+
+ return Response.new(data, ctype = 'webmanifest')
@register_route('/theme/{theme}.css')
class ThemeCss(View):
async def get(self, request: Request, theme: str) -> Response:
try:
- context = {
+ context: dict[str, Any] = {
'theme': THEMES[theme]
}
diff --git a/relay/views/misc.py b/relay/views/misc.py
index 65025e3..f10a877 100644
--- a/relay/views/misc.py
+++ b/relay/views/misc.py
@@ -27,28 +27,26 @@ if Path(__file__).parent.parent.joinpath('.git').exists():
pass
-# pylint: disable=unused-argument
-
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
- # pylint: disable=no-self-use
async def get(self, request: Request, niversion: str) -> Response:
with self.database.session() as conn:
- inboxes = conn.execute('SELECT * FROM inboxes').all()
+ inboxes = conn.get_inboxes()
- data = {
- 'name': 'activityrelay',
- 'version': VERSION,
- 'protocols': ['activitypub'],
- 'open_regs': not conn.get_config('whitelist-enabled'),
- 'users': 1,
- 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
- }
+ nodeinfo = aputils.Nodeinfo.new(
+ name = 'activityrelay',
+ version = VERSION,
+ protocols = ['activitypub'],
+ open_regs = not conn.get_config('whitelist-enabled'),
+ users = 1,
+ repo = 'https://git.pleroma.social/pleroma/relay' if niversion == '2.1' else None,
+ metadata = {
+ 'approval_required': conn.get_config('approval-required'),
+ 'peers': [inbox['domain'] for inbox in inboxes]
+ }
+ )
- if niversion == '2.1':
- data['repo'] = 'https://git.pleroma.social/pleroma/relay'
-
- return Response.new(aputils.Nodeinfo.new(**data), ctype = 'json')
+ return Response.new(nodeinfo, ctype = 'json')
@register_route('/.well-known/nodeinfo')
diff --git a/requirements.txt b/requirements.txt
index 4c43b87..5649873 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,14 @@
-aiohttp>=3.9.1
-aiohttp-swagger[performance]==1.0.16
-aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.7.tar.gz
-argon2-cffi==23.1.0
-barkshark-sql@https://git.barkshark.xyz/barkshark/bsql/archive/0.1.2.tar.gz
-click>=8.1.2
-hamlish-jinja@https://git.barkshark.xyz/barkshark/hamlish-jinja/archive/0.3.5.tar.gz
-hiredis==2.3.2
-platformdirs==4.2.0
-pyyaml>=6.0
-redis==5.0.1
+activitypub-utils == 0.2.1
+aiohttp >= 3.9.1
+aiohttp-swagger[performance] == 1.0.16
+argon2-cffi == 23.1.0
+barkshark-sql @ https://git.barkshark.xyz/barkshark/bsql/archive/0.1.2.tar.gz
+click >= 8.1.2
+hamlish-jinja @ https://git.barkshark.xyz/barkshark/hamlish-jinja/archive/0.3.5.tar.gz
+hiredis == 2.3.2
+markdown == 3.5.2
+platformdirs == 4.2.0
+pyyaml >= 6.0
+redis == 5.0.1
-importlib_resources==6.1.1;python_version<'3.9'
+importlib_resources == 6.1.1; python_version < '3.9'
diff --git a/setup.cfg b/setup.cfg
index 41c2a30..b7d4fdc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,6 +44,8 @@ console_scripts =
[flake8]
-select = F401
+extend-ignore = E128,E251,E261,E303,W191
+max-line-length = 100
+indent-size = 4
per-file-ignores =
__init__.py: F401