Compare commits

..

No commits in common. "b22b5bbefaa1b6cf13deaeb65396b135dc3fb192" and "5217516c8a5c6b00711a4784eea4caf26801de4d" have entirely different histories.

11 changed files with 99 additions and 981 deletions

47
dev.py
View file

@ -1,38 +1,25 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import click
import platform import platform
import shutil import shutil
import subprocess import subprocess
import sys import sys
import time import time
import tomllib
from datetime import datetime, timedelta from datetime import datetime, timedelta
from importlib.util import find_spec
from pathlib import Path from pathlib import Path
from relay import __version__, logger as logging
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Sequence from typing import Any, Sequence
try: try:
import tomllib from watchdog.observers import Observer
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
except ImportError: except ImportError:
if find_spec("toml") is None: class PatternMatchingEventHandler: # type: ignore
subprocess.run([sys.executable, "-m", "pip", "install", "toml"]) pass
import toml as tomllib # type: ignore[no-redef]
if None in [find_spec("click"), find_spec("watchdog")]:
CMD = [sys.executable, "-m", "pip", "install", "click >= 8.1.0", "watchdog >= 4.0.0"]
PROC = subprocess.run(CMD, check = False)
if PROC.returncode != 0:
sys.exit()
print("Successfully installed dependencies")
import click
from watchdog.observers import Observer
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
REPO = Path(__file__).parent REPO = Path(__file__).parent
@ -50,10 +37,12 @@ def cli() -> None:
@cli.command('install') @cli.command('install')
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies') @click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
def cli_install(no_dev: bool) -> None: def cli_install(no_dev: bool) -> None:
with open('pyproject.toml', 'r', encoding = 'utf-8') as fd: with open('pyproject.toml', 'rb') as fd:
data = tomllib.loads(fd.read()) data = tomllib.load(fd)
deps = data['project']['dependencies'] deps = data['project']['dependencies']
if not no_dev:
deps.extend(data['project']['optional-dependencies']['dev']) deps.extend(data['project']['optional-dependencies']['dev'])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False) subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
@ -71,7 +60,7 @@ def cli_lint(path: Path, watch: bool) -> None:
return return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)] mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)]
click.echo('----- flake8 -----') click.echo('----- flake8 -----')
subprocess.run(flake8) subprocess.run(flake8)
@ -100,8 +89,6 @@ def cli_clean() -> None:
@cli.command('build') @cli.command('build')
def cli_build() -> None: def cli_build() -> None:
from relay import __version__
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [ cmd = [
@ -184,7 +171,7 @@ class WatchHandler(PatternMatchingEventHandler):
if proc.poll() is not None: if proc.poll() is not None:
continue continue
print(f'Terminating process {proc.pid}') logging.info(f'Terminating process {proc.pid}')
proc.terminate() proc.terminate()
sec = 0.0 sec = 0.0
@ -193,11 +180,11 @@ class WatchHandler(PatternMatchingEventHandler):
sec += 0.1 sec += 0.1
if sec >= 5: if sec >= 5:
print('Failed to terminate. Killing process...') logging.error('Failed to terminate. Killing process...')
proc.kill() proc.kill()
break break
print('Process terminated') logging.info('Process terminated')
def run_procs(self, restart: bool = False) -> None: def run_procs(self, restart: bool = False) -> None:
@ -213,13 +200,13 @@ class WatchHandler(PatternMatchingEventHandler):
self.procs = [] self.procs = []
for cmd in self.commands: for cmd in self.commands:
print('Running command:', ' '.join(cmd)) logging.info('Running command: %s', ' '.join(cmd))
subprocess.run(cmd) subprocess.run(cmd)
else: else:
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands) self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
pids = (str(proc.pid) for proc in self.procs) pids = (str(proc.pid) for proc in self.procs)
print('Started processes with PIDs:', ', '.join(pids)) logging.info('Started processes with PIDs: %s', ', '.join(pids))
def on_any_event(self, event: FileSystemEvent) -> None: def on_any_event(self, event: FileSystemEvent) -> None:

View file

@ -9,27 +9,30 @@ license = {text = "AGPLv3"}
classifiers = [ classifiers = [
"Environment :: Console", "Environment :: Console",
"License :: OSI Approved :: GNU Affero General Public License v3", "License :: OSI Approved :: GNU Affero General Public License v3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12" "Programming Language :: Python :: 3.12",
] ]
dependencies = [ dependencies = [
"activitypub-utils >= 0.3.1.post1, < 0.4.0", "activitypub-utils >= 0.3.1, < 0.4.0",
"aiohttp >= 3.9.5", "aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.1.5rc1, < 0.2.0", "barkshark-lib >= 0.1.4, < 0.2.0",
"barkshark-sql >= 0.2.0rc2, < 0.3.0", "barkshark-sql >= 0.2.0-rc1, < 0.3.0",
"click == 8.1.2", "click == 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",
"idna == 3.4", "idna == 3.4",
"jinja2-haml == 0.3.5", "jinja2-haml == 0.3.5",
"markdown == 3.6", "markdown == 3.6",
"platformdirs == 4.2.2", "platformdirs == 4.2.2",
"pyyaml == 6.0.1", "pyyaml == 6.0",
"redis == 5.0.7" "redis == 5.0.5",
"importlib-resources == 6.4.0; python_version < '3.9'"
] ]
requires-python = ">=3.10" requires-python = ">=3.8"
dynamic = ["version"] dynamic = ["version"]
[project.readme] [project.readme]
@ -46,10 +49,11 @@ activityrelay = "relay.manage:main"
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"flake8 == 7.1.0", "flake8 == 7.0.0",
"mypy == 1.10.1", "mypy == 1.10.0",
"pyinstaller == 6.8.0", "pyinstaller == 6.8.0",
"watchdog == 4.0.1" "watchdog == 4.0.1",
"typing-extensions == 4.12.2; python_version < '3.11.0'"
] ]
[tool.setuptools] [tool.setuptools]

View file

@ -4,13 +4,12 @@ import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from blib import Date
from bsql import Database, Row from bsql import Database, Row
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import timedelta from datetime import datetime, timedelta, timezone
from redis import Redis from redis import Redis
from typing import TYPE_CHECKING, Any, TypedDict from typing import TYPE_CHECKING, Any
from .database import Connection, get_database from .database import Connection, get_database
from .misc import Message, boolean from .misc import Message, boolean
@ -32,14 +31,6 @@ CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
} }
class RedisConnectType(TypedDict):
client_name: str
decode_responses: bool
username: str | None
password: str | None
db: int
def get_cache(app: Application) -> Cache: def get_cache(app: Application) -> Cache:
return BACKENDS[app.config.ca_type](app) return BACKENDS[app.config.ca_type](app)
@ -66,11 +57,12 @@ class Item:
key: str key: str
value: Any value: Any
value_type: str value_type: str
updated: Date updated: datetime
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.updated = Date.parse(self.updated) if isinstance(self.updated, str): # type: ignore[unreachable]
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
@classmethod @classmethod
@ -78,11 +70,14 @@ class Item:
data = cls(*args) data = cls(*args)
data.value = deserialize_value(data.value, data.value_type) data.value = deserialize_value(data.value, data.value_type)
if not isinstance(data.updated, datetime):
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore
return data return data
def older_than(self, hours: int) -> bool: def older_than(self, hours: int) -> bool:
delta = Date.new_utc() - self.updated delta = datetime.now(tz = timezone.utc) - self.updated
return (delta.total_seconds()) > hours * 3600 return (delta.total_seconds()) > hours * 3600
@ -211,7 +206,7 @@ class SqlCache(Cache):
'key': key, 'key': key,
'value': serialize_value(value, value_type), 'value': serialize_value(value, value_type),
'type': value_type, 'type': value_type,
'date': Date.new_utc() 'date': datetime.now(tz = timezone.utc)
} }
with self._db.session(True) as conn: with self._db.session(True) as conn:
@ -241,7 +236,7 @@ class SqlCache(Cache):
if self._db is None: if self._db is None:
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
limit = Date.new_utc() - timedelta(days = days) limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
params = {"limit": limit.timestamp()} params = {"limit": limit.timestamp()}
with self._db.session(True) as conn: with self._db.session(True) as conn:
@ -285,7 +280,7 @@ class RedisCache(Cache):
def __init__(self, app: Application): def __init__(self, app: Application):
Cache.__init__(self, app) Cache.__init__(self, app)
self._rd: Redis | None = None self._rd: Redis = None # type: ignore
@property @property
@ -298,38 +293,28 @@ class RedisCache(Cache):
def get(self, namespace: str, key: str) -> Item: def get(self, namespace: str, key: str) -> Item:
if self._rd is None:
raise ConnectionError("Not connected")
key_name = self.get_key_name(namespace, key) key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)): if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}') raise KeyError(f'{namespace}:{key}')
value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr] value_type, updated, value = raw_value.split(':', 2) # type: ignore
return Item.from_data( return Item.from_data(
namespace, namespace,
key, key,
value, value,
value_type, value_type,
Date.parse(float(updated)) datetime.fromtimestamp(float(updated), tz = timezone.utc)
) )
def get_keys(self, namespace: str) -> Iterator[str]: def get_keys(self, namespace: str) -> Iterator[str]:
if self._rd is None:
raise ConnectionError("Not connected")
for key in self._rd.scan_iter(self.get_key_name(namespace, '*')): for key in self._rd.scan_iter(self.get_key_name(namespace, '*')):
*_, key_name = key.split(':', 2) *_, key_name = key.split(':', 2)
yield key_name yield key_name
def get_namespaces(self) -> Iterator[str]: def get_namespaces(self) -> Iterator[str]:
if self._rd is None:
raise ConnectionError("Not connected")
namespaces = [] namespaces = []
for key in self._rd.scan_iter(f'{self.prefix}:*'): for key in self._rd.scan_iter(f'{self.prefix}:*'):
@ -341,10 +326,7 @@ class RedisCache(Cache):
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
if self._rd is None: date = datetime.now(tz = timezone.utc).timestamp()
raise ConnectionError("Not connected")
date = Date.new_utc().timestamp()
value = serialize_value(value, value_type) value = serialize_value(value, value_type)
self._rd.set( self._rd.set(
@ -356,17 +338,11 @@ class RedisCache(Cache):
def delete(self, namespace: str, key: str) -> None: def delete(self, namespace: str, key: str) -> None:
if self._rd is None:
raise ConnectionError("Not connected")
self._rd.delete(self.get_key_name(namespace, key)) self._rd.delete(self.get_key_name(namespace, key))
def delete_old(self, days: int = 14) -> None: def delete_old(self, days: int = 14) -> None:
if self._rd is None: limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
raise ConnectionError("Not connected")
limit = Date.new_utc() - timedelta(days = days)
for full_key in self._rd.scan_iter(f'{self.prefix}:*'): for full_key in self._rd.scan_iter(f'{self.prefix}:*'):
_, namespace, key = full_key.split(':', 2) _, namespace, key = full_key.split(':', 2)
@ -377,17 +353,14 @@ class RedisCache(Cache):
def clear(self) -> None: def clear(self) -> None:
if self._rd is None:
raise ConnectionError("Not connected")
self._rd.delete(f"{self.prefix}:*") self._rd.delete(f"{self.prefix}:*")
def setup(self) -> None: def setup(self) -> None:
if self._rd is not None: if self._rd:
return return
options: RedisConnectType = { options = {
'client_name': f'ActivityRelay_{self.app.config.domain}', 'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True, 'decode_responses': True,
'username': self.app.config.rd_user, 'username': self.app.config.rd_user,
@ -396,22 +369,18 @@ class RedisCache(Cache):
} }
if os.path.exists(self.app.config.rd_host): if os.path.exists(self.app.config.rd_host):
self._rd = Redis( options['unix_socket_path'] = self.app.config.rd_host
unix_socket_path = self.app.config.rd_host,
**options
)
return
self._rd = Redis( else:
host = self.app.config.rd_host, options['host'] = self.app.config.rd_host
port = self.app.config.rd_port, options['port'] = self.app.config.rd_port
**options
) self._rd = Redis(**options) # type: ignore
def close(self) -> None: def close(self) -> None:
if not self._rd: if not self._rd:
return return
self._rd.close() # type: ignore[no-untyped-call] self._rd.close() # type: ignore
self._rd = None self._rd = None # type: ignore

View file

@ -13,8 +13,12 @@ from typing import TYPE_CHECKING, Any
from .misc import IS_DOCKER from .misc import IS_DOCKER
if TYPE_CHECKING: if TYPE_CHECKING:
try:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
if platform.system() == 'Windows': if platform.system() == 'Windows':
import multiprocessing import multiprocessing
@ -80,7 +84,7 @@ class Config:
def DEFAULT(cls: type[Self], key: str) -> str | int | None: def DEFAULT(cls: type[Self], key: str) -> str | int | None:
for field in fields(cls): for field in fields(cls):
if field.name == key: if field.name == key:
return field.default # type: ignore[return-value] return field.default # type: ignore
raise KeyError(key) raise KeyError(key)

View file

@ -1,862 +0,0 @@
// toast notifications
const notifications = document.querySelector("#notifications")
function remove_toast(toast) {
toast.classList.add("hide");
if (toast.timeoutId) {
clearTimeout(toast.timeoutId);
}
setTimeout(() => toast.remove(), 300);
}
function toast(text, type="error", timeout=5) {
const toast = document.createElement("li");
toast.className = `section ${type}`
toast.innerHTML = `<span class=".text">${text}</span><a href="#">&#10006;</span>`
toast.querySelector("a").addEventListener("click", async (event) => {
event.preventDefault();
await remove_toast(toast);
});
notifications.appendChild(toast);
toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000);
}
// menu
const body = document.getElementById("container")
const menu = document.getElementById("menu");
const menu_open = document.querySelector("#menu-open i");
const menu_close = document.getElementById("menu-close");
function toggle_menu() {
let new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
menu.attributes.visible.nodeValue = new_value;
}
menu_open.addEventListener("click", toggle_menu);
menu_close.addEventListener("click", toggle_menu);
body.addEventListener("click", (event) => {
if (event.target === menu_open) {
return;
}
menu.attributes.visible.nodeValue = "false";
});
for (const elem of document.querySelectorAll("#menu-open div")) {
elem.addEventListener("click", toggle_menu);
}
// misc
function get_date_string(date) {
var year = date.getUTCFullYear().toString();
var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
var day = date.getUTCDate().toString().padStart(2, "0");
return `${year}-${month}-${day}`;
}
function append_table_row(table, row_name, row) {
var table_row = table.insertRow(-1);
table_row.id = row_name;
index = 0;
for (var prop in row) {
if (Object.prototype.hasOwnProperty.call(row, prop)) {
var cell = table_row.insertCell(index);
cell.className = prop;
cell.innerHTML = row[prop];
index += 1;
}
}
return table_row;
}
async function request(method, path, body = null) {
var headers = {
"Accept": "application/json"
}
if (body !== null) {
headers["Content-Type"] = "application/json"
body = JSON.stringify(body)
}
const response = await fetch("/api/" + path, {
method: method,
mode: "cors",
cache: "no-store",
redirect: "follow",
body: body,
headers: headers
});
const message = await response.json();
if (Object.hasOwn(message, "error")) {
throw new Error(message.error);
}
if (Array.isArray(message)) {
message.forEach((msg) => {
if (Object.hasOwn(msg, "created")) {
msg.created = new Date(msg.created);
}
});
} else {
if (Object.hasOwn(message, "created")) {
message.created = new Date(message.created);
}
}
return message;
}
// page functions
function page_config() {
const elems = [
document.querySelector("#name"),
document.querySelector("#note"),
document.querySelector("#theme"),
document.querySelector("#log-level"),
document.querySelector("#whitelist-enabled"),
document.querySelector("#approval-required")
]
async function handle_config_change(event) {
params = {
key: event.target.id,
value: event.target.type === "checkbox" ? event.target.checked : event.target.value
}
try {
await request("POST", "v1/config", params);
} catch (error) {
toast(error);
return;
}
if (params.key === "name") {
document.querySelector("#header .title").innerHTML = params.value;
document.querySelector("title").innerHTML = params.value;
}
if (params.key === "theme") {
document.querySelector("link.theme").href = `/theme/${params.value}.css`;
}
toast("Updated config", "message");
}
document.querySelector("#name").addEventListener("keydown", async (event) => {
if (event.which === 13) {
await handle_config_change(event);
}
});
document.querySelector("#note").addEventListener("keydown", async (event) => {
if (event.which === 13 && event.ctrlKey) {
await handle_config_change(event);
}
});
for (const elem of elems) {
elem.addEventListener("change", handle_config_change);
}
}
function page_domain_ban() {
function create_ban_object(domain, reason, note) {
var text = '<details>\n';
text += `<summary>${domain}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${domain}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${domain}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${domain}-note" class="note">Note</label>\n`;
text += `<textarea id="${domain}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var table = document.querySelector("table");
var elems = {
domain: document.getElementById("new-domain"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
domain: elems.domain.value.trim(),
reason: elems.reason.value.trim(),
note: elems.note.value.trim()
}
if (values.domain === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/domain_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("table"), ban.domain, {
domain: create_ban_object(ban.domain, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban domain">&#10006;</a>`
});
add_row_listeners(row);
elems.domain.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned domain", "message");
}
async function update_ban(domain) {
var row = document.getElementById(domain);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"domain": domain,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/domain_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated baned domain", "message");
}
async function unban(domain) {
try {
await request("DELETE", "v1/domain_ban", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Unbanned domain", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var elem of document.querySelectorAll("#add-item input")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13) {
await ban();
}
});
}
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}
}
function page_instance() {
function add_instance_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_instance(row.id);
});
}
function add_request_listeners(row) {
row.querySelector(".approve a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, true);
});
row.querySelector(".deny a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, false);
});
}
async function add_instance() {
var elems = {
actor: document.getElementById("new-actor"),
inbox: document.getElementById("new-inbox"),
followid: document.getElementById("new-followid"),
software: document.getElementById("new-software")
}
var values = {
actor: elems.actor.value.trim(),
inbox: elems.inbox.value.trim(),
followid: elems.followid.value.trim(),
software: elems.software.value.trim()
}
if (values.actor === "") {
toast("Actor is required");
return;
}
try {
var instance = await request("POST", "v1/instance", values);
} catch (err) {
toast(err);
return
}
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
elems.actor.value = null;
elems.inbox.value = null;
elems.followid.value = null;
elems.software.value = null;
document.querySelector("details.section").open = false;
toast("Added instance", "message");
}
async function del_instance(domain) {
try {
await request("DELETE", "v1/instance", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
}
async function req_response(domain, accept) {
params = {
"domain": domain,
"accept": accept
}
try {
await request("POST", "v1/request", params);
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
if (document.getElementById("requests").rows.length < 2) {
document.querySelector("fieldset.requests").remove()
}
if (!accept) {
toast("Denied instance request", "message");
return;
}
instances = await request("GET", `v1/instance`, null);
instances.forEach((instance) => {
if (instance.domain === domain) {
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
}
});
toast("Accepted instance request", "message");
}
document.querySelector("#add-instance").addEventListener("click", async (event) => {
await add_instance();
})
for (var elem of document.querySelectorAll("#add-item input")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13) {
await add_instance();
}
});
}
for (var row of document.querySelector("#instances").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_instance_listeners(row);
}
if (document.querySelector("#requests")) {
for (var row of document.querySelector("#requests").rows) {
if (!row.querySelector(".approve a")) {
continue;
}
add_request_listeners(row);
}
}
}
function page_login() {
const fields = {
username: document.querySelector("#username"),
password: document.querySelector("#password")
}
async function login(event) {
const values = {
username: fields.username.value.trim(),
password: fields.password.value.trim()
}
if (values.username === "" | values.password === "") {
toast("Username and/or password field is blank");
return;
}
try {
await request("POST", "v1/token", values);
} catch (error) {
toast(error);
return;
}
document.location = "/";
}
document.querySelector("#username").addEventListener("keydown", async (event) => {
if (event.which === 13) {
fields.password.focus();
fields.password.select();
}
});
document.querySelector("#password").addEventListener("keydown", async (event) => {
if (event.which === 13) {
await login(event);
}
});
document.querySelector(".submit").addEventListener("click", login);
}
function page_software_ban() {
function create_ban_object(name, reason, note) {
var text = '<details>\n';
text += `<summary>${name}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${name}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${name}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${name}-note" class="note">Note</label>\n`;
text += `<textarea id="${name}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var elems = {
name: document.getElementById("new-name"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
name: elems.name.value.trim(),
reason: elems.reason.value,
note: elems.note.value
}
if (values.name === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/software_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.getElementById("bans"), ban.name, {
name: create_ban_object(ban.name, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban software">&#10006;</a>`
});
add_row_listeners(row);
elems.name.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned software", "message");
}
async function update_ban(name) {
var row = document.getElementById(name);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"name": name,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/software_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated software ban", "message");
}
async function unban(name) {
try {
await request("DELETE", "v1/software_ban", {"name": name});
} catch (error) {
toast(error);
return;
}
document.getElementById(name).remove();
toast("Unbanned software", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var elem of document.querySelectorAll("#add-item input")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13) {
await ban();
}
});
}
for (var elem of document.querySelectorAll("#add-item textarea")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13 && event.ctrlKey) {
await ban();
}
});
}
for (var row of document.querySelector("#bans").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}
}
function page_user() {
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_user(row.id);
});
}
async function add_user() {
var elems = {
username: document.getElementById("new-username"),
password: document.getElementById("new-password"),
password2: document.getElementById("new-password2"),
handle: document.getElementById("new-handle")
}
var values = {
username: elems.username.value.trim(),
password: elems.password.value.trim(),
password2: elems.password2.value.trim(),
handle: elems.handle.value.trim()
}
if (values.username === "" | values.password === "" | values.password2 === "") {
toast("Username, password, and password2 are required");
return;
}
if (values.password !== values.password2) {
toast("Passwords do not match");
return;
}
try {
var user = await request("POST", "v1/user", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
domain: user.username,
handle: user.handle ? self.handle : "n/a",
date: get_date_string(user.created),
remove: `<a href="#" title="Delete User">&#10006;</a>`
});
add_row_listeners(row);
elems.username.value = null;
elems.password.value = null;
elems.password2.value = null;
elems.handle.value = null;
document.querySelector("details.section").open = false;
toast("Created user", "message");
}
async function del_user(username) {
try {
await request("DELETE", "v1/user", {"username": username});
} catch (error) {
toast(error);
return;
}
document.getElementById(username).remove();
toast("Deleted user", "message");
}
document.querySelector("#new-user").addEventListener("click", async (event) => {
await add_user();
});
for (var elem of document.querySelectorAll("#add-item input")) {
elem.addEventListener("keydown", async (event) => {
if (event.which === 13) {
await add_user();
}
});
}
for (var row of document.querySelector("#users").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}
}
function page_whitelist() {
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_whitelist(row.id);
});
}
async function add_whitelist() {
var domain_elem = document.getElementById("new-domain");
var domain = domain_elem.value.trim();
if (domain === "") {
toast("Domain is required");
return;
}
try {
var item = await request("POST", "v1/whitelist", {"domain": domain});
} catch (err) {
toast(err);
return;
}
var row = append_table_row(document.getElementById("whitelist"), item.domain, {
domain: item.domain,
date: get_date_string(item.created),
remove: `<a href="#" title="Remove whitelisted domain">&#10006;</a>`
});
add_row_listeners(row);
domain_elem.value = null;
document.querySelector("details.section").open = false;
toast("Added domain", "message");
}
async function del_whitelist(domain) {
try {
await request("DELETE", "v1/whitelist", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Removed domain", "message");
}
document.querySelector("#new-item").addEventListener("click", async (event) => {
await add_whitelist();
});
document.querySelector("#add-item").addEventListener("keydown", async (event) => {
if (event.which === 13) {
await add_whitelist();
}
});
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}
}
if (location.pathname.startsWith("/admin/config")) {
page_config();
} else if (location.pathname.startsWith("/admin/domain_bans")) {
page_domain_ban();
} else if (location.pathname.startsWith("/admin/instances")) {
page_instance();
} else if (location.pathname.startsWith("/admin/login")) {
page_login();
} else if (location.pathname.startsWith("/admin/software_bans")) {
page_software_ban();
} else if (location.pathname.startsWith("/admin/users")) {
page_user();
} else if (location.pathname.startsWith("/admin/whitelist")) {
page_whitelist();
}

View file

@ -8,8 +8,12 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING: if TYPE_CHECKING:
try:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
class LoggingMethod(Protocol): class LoggingMethod(Protocol):
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ... def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...

View file

@ -9,15 +9,25 @@ import socket
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from importlib.resources import files as pkgfiles
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4 from uuid import uuid4
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles # type: ignore
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self
from .application import Application from .application import Application
try:
from typing import Self
except ImportError:
from typing_extensions import Self
T = TypeVar('T') T = TypeVar('T')
ResponseType = TypedDict('ResponseType', { ResponseType = TypedDict('ResponseType', {

View file

@ -20,9 +20,6 @@ if TYPE_CHECKING:
class Template(Environment): class Template(Environment):
_render_markdown: Callable[[str], str]
def __init__(self, app: Application): def __init__(self, app: Application):
Environment.__init__(self, Environment.__init__(self,
autoescape = True, autoescape = True,
@ -59,7 +56,7 @@ class Template(Environment):
def render_markdown(self, text: str) -> str: def render_markdown(self, text: str) -> str:
return self._render_markdown(text) return self._render_markdown(text) # type: ignore
class MarkdownExtension(Extension): class MarkdownExtension(Extension):

View file

@ -13,8 +13,7 @@ from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app from ..misc import Message, Response, boolean, get_app
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' ALLOWED_HEADERS = {
ALLOWED_HEADERS: set[str] = {
'accept', 'accept',
'authorization', 'authorization',
'content-type' 'content-type'

View file

@ -18,12 +18,18 @@ from ..http_client import HttpClient
from ..misc import Response, get_app from ..misc import Response, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self
from ..application import Application from ..application import Application
from ..template import Template from ..template import Template
try:
from typing import Self
except ImportError:
from typing_extensions import Self
HandlerCallback = Callable[[Request], Awaitable[Response]] HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = [] VIEWS: list[tuple[str, type[View]]] = []

View file

@ -1,17 +1,14 @@
from __future__ import annotations
import asyncio import asyncio
import traceback import traceback
import typing
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value from multiprocessing import Event, Process, Queue, Value
from multiprocessing.queues import Queue as QueueType
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.synchronize import Event as EventType from multiprocessing.synchronize import Event as EventType
from pathlib import Path from pathlib import Path
from queue import Empty from queue import Empty, Queue as QueueType
from urllib.parse import urlparse from urllib.parse import urlparse
from . import application, logger as logging from . import application, logger as logging
@ -19,6 +16,9 @@ from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, get_app from .misc import IS_WINDOWS, Message, get_app
if typing.TYPE_CHECKING:
from .multiprocessing.synchronize import Syncronized
@dataclass @dataclass
class QueueItem: class QueueItem:
@ -40,13 +40,13 @@ class PushWorker(Process):
client: HttpClient client: HttpClient
def __init__(self, queue: QueueType[QueueItem], log_level: Synchronized[int]) -> None: def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None:
Process.__init__(self) Process.__init__(self)
self.queue: QueueType[QueueItem] = queue self.queue: QueueType[QueueItem] = queue
self.shutdown: EventType = Event() self.shutdown: EventType = Event()
self.path: Path = get_app().config.path self.path: Path = get_app().config.path
self.log_level: Synchronized[int] = log_level self.log_level: "Syncronized[str]" = log_level
self._log_level_changed: EventType = Event() self._log_level_changed: EventType = Event()
@ -113,8 +113,8 @@ class PushWorker(Process):
class PushWorkers(list[PushWorker]): class PushWorkers(list[PushWorker]):
def __init__(self, count: int) -> None: def __init__(self, count: int) -> None:
self.queue: QueueType[QueueItem] = Queue() self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment]
self._log_level: Synchronized[int] = Value("i", logging.get_level()) self._log_level: "Syncronized[str]" = Value("i", logging.get_level())
self._count: int = count self._count: int = count