replace single quotes with double quotes

Izalia Mae 2024-11-28 07:09:53 -05:00
parent 29ebba7999
commit 338fd26688
23 changed files with 1113 additions and 1109 deletions
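The commit itself does not record how the conversion was performed (by hand, editor find-and-replace, or a script); the diff below only shows the result. As a rough, hypothetical sketch, the standard-library tokenize module could mechanize the same single-to-double quote rewrite while leaving prefixed, triple-quoted, escaped, and already-double-quote-containing literals alone. The requote helper below is made up for illustration and is not part of this repository; because it only swaps the quote characters, tokenize.untokenize should preserve the surrounding layout, though the output would still want a review diff.

# Illustrative sketch only -- "requote" is a hypothetical helper, not code from this commit.
import io
import tokenize


def requote(source: str) -> str:
    """Rewrite plain single-quoted string literals to double-quoted ones."""
    tokens = []
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        if (tok.type == tokenize.STRING
                and tok.string.startswith("'")        # skips f'...', r'...', b'...' prefixes
                and not tok.string.startswith("'''")  # skips triple-quoted strings
                and '"' not in tok.string             # skips literals already containing "
                and "\\" not in tok.string):          # skips literals with escape sequences
            tok = tok._replace(string='"' + tok.string[1:-1] + '"')
        tokens.append(tok)
    return tokenize.untokenize(tokens)


# Hypothetical usage: rewrite one file in place, then inspect the result with git diff.
# with open("dev.py", "r", encoding="utf-8") as fd:
#     text = fd.read()
# with open("dev.py", "w", encoding="utf-8") as fd:
#     fd.write(requote(text))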

dev.py (103 changed lines)
View file

@@ -37,32 +37,32 @@ from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
REPO = Path(__file__).parent
IGNORE_EXT = {
'.py',
'.pyc'
".py",
".pyc"
}
@click.group('cli')
@click.group("cli")
def cli() -> None:
'Useful commands for development'
"Useful commands for development"
@cli.command('install')
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
@cli.command("install")
@click.option("--no-dev", "-d", is_flag = True, help = "Do not install development dependencies")
def cli_install(no_dev: bool) -> None:
with open('pyproject.toml', 'r', encoding = 'utf-8') as fd:
with open("pyproject.toml", "r", encoding = "utf-8") as fd:
data = tomllib.loads(fd.read())
deps = data['project']['dependencies']
deps.extend(data['project']['optional-dependencies']['dev'])
deps = data["project"]["dependencies"]
deps.extend(data["project"]["optional-dependencies"]["dev"])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
subprocess.run([sys.executable, "-m", "pip", "install", "-U", *deps], check = False)
@cli.command('lint')
@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay'))
@click.option('--watch', '-w', is_flag = True,
help = 'Automatically, re-run the linters on source change')
@cli.command("lint")
@click.argument("path", required = False, type = Path, default = REPO.joinpath("relay"))
@click.option("--watch", "-w", is_flag = True,
help = "Automatically, re-run the linters on source change")
def cli_lint(path: Path, watch: bool) -> None:
path = path.expanduser().resolve()
@@ -70,74 +70,73 @@ def cli_lint(path: Path, watch: bool) -> None:
handle_run_watcher([sys.executable, "dev.py", "lint", str(path)], wait = True)
return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)]
flake8 = [sys.executable, "-m", "flake8", "dev.py", str(path)]
mypy = [sys.executable, "-m", "mypy", "--python-version", "3.12", "dev.py", str(path)]
click.echo('----- flake8 -----')
click.echo("----- flake8 -----")
subprocess.run(flake8)
click.echo('\n\n----- mypy -----')
click.echo("\n\n----- mypy -----")
subprocess.run(mypy)
@cli.command('clean')
@cli.command("clean")
def cli_clean() -> None:
dirs = {
'dist',
'build',
'dist-pypi'
"dist",
"build",
"dist-pypi"
}
for directory in dirs:
shutil.rmtree(directory, ignore_errors = True)
for path in REPO.glob('*.egg-info'):
for path in REPO.glob("*.egg-info"):
shutil.rmtree(path)
for path in REPO.glob('*.spec'):
for path in REPO.glob("*.spec"):
path.unlink()
@cli.command('build')
@cli.command("build")
def cli_build() -> None:
from relay import __version__
with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
arch = "amd64" if sys.maxsize >= 2**32 else "i386"
cmd = [
sys.executable, '-m', 'PyInstaller',
'--collect-data', 'relay',
'--collect-data', 'aiohttp_swagger',
'--hidden-import', 'pg8000',
'--hidden-import', 'sqlite3',
'--name', f'activityrelay-{__version__}-{platform.system().lower()}-{arch}',
'--workpath', tmp,
'--onefile', 'relay/__main__.py',
sys.executable, "-m", "PyInstaller",
"--collect-data", "relay",
"--hidden-import", "pg8000",
"--hidden-import", "sqlite3",
"--name", f"activityrelay-{__version__}-{platform.system().lower()}-{arch}",
"--workpath", tmp,
"--onefile", "relay/__main__.py",
]
if platform.system() == 'Windows':
cmd.append('--console')
if platform.system() == "Windows":
cmd.append("--console")
# putting the spec path on a different drive than the source dir breaks
if str(REPO)[0] == tmp[0]:
cmd.extend(['--specpath', tmp])
cmd.extend(["--specpath", tmp])
else:
cmd.append('--strip')
cmd.extend(['--specpath', tmp])
cmd.append("--strip")
cmd.extend(["--specpath", tmp])
subprocess.run(cmd, check = False)
@cli.command('run')
@click.option('--dev', '-d', is_flag = True)
@cli.command("run")
@click.option("--dev", "-d", is_flag = True)
def cli_run(dev: bool) -> None:
print('Starting process watcher')
print("Starting process watcher")
cmd = [sys.executable, '-m', 'relay', 'run']
cmd = [sys.executable, "-m", "relay", "run"]
if dev:
cmd.append('-d')
cmd.append("-d")
handle_run_watcher(cmd, watch_path = REPO.joinpath("relay"))
@@ -167,7 +166,7 @@ def handle_run_watcher(
class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py']
patterns = ["*.py"]
def __init__(self, *commands: Sequence[str], wait: bool = False) -> None:
@@ -184,7 +183,7 @@ class WatchHandler(PatternMatchingEventHandler):
if proc.poll() is not None:
continue
print(f'Terminating process {proc.pid}')
print(f"Terminating process {proc.pid}")
proc.terminate()
sec = 0.0
@@ -193,11 +192,11 @@ class WatchHandler(PatternMatchingEventHandler):
sec += 0.1
if sec >= 5:
print('Failed to terminate. Killing process...')
print("Failed to terminate. Killing process...")
proc.kill()
break
print('Process terminated')
print("Process terminated")
def run_procs(self, restart: bool = False) -> None:
@@ -213,21 +212,21 @@ class WatchHandler(PatternMatchingEventHandler):
self.procs = []
for cmd in self.commands:
print('Running command:', ' '.join(cmd))
print("Running command:", " ".join(cmd))
subprocess.run(cmd)
else:
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
pids = (str(proc.pid) for proc in self.procs)
print('Started processes with PIDs:', ', '.join(pids))
print("Started processes with PIDs:", ", ".join(pids))
def on_any_event(self, event: FileSystemEvent) -> None:
if event.event_type not in ['modified', 'created', 'deleted']:
if event.event_type not in ["modified", "created", "deleted"]:
return
self.run_procs(restart = True)
if __name__ == '__main__':
if __name__ == "__main__":
cli()

View file

@@ -1 +1 @@
__version__ = '0.3.3'
__version__ = "0.3.3"

View file

@@ -1,8 +1,5 @@
import multiprocessing
from relay.manage import main
if __name__ == '__main__':
multiprocessing.freeze_support()
if __name__ == "__main__":
main()

View file

@@ -41,10 +41,10 @@ def get_csp(request: web.Request) -> str:
"img-src 'self'",
"object-src 'none'",
"frame-ancestors 'none'",
f"manifest-src 'self' https://{request.app['config'].domain}"
f"manifest-src 'self' https://{request.app["config"].domain}"
]
return '; '.join(data) + ';'
return "; ".join(data) + ";"
class Application(web.Application):
@@ -61,20 +61,20 @@ class Application(web.Application):
Application.DEFAULT = self
self['running'] = False
self['signer'] = None
self['start_time'] = None
self['cleanup_thread'] = None
self['dev'] = dev
self["running"] = False
self["signer"] = None
self["start_time"] = None
self["cleanup_thread"] = None
self["dev"] = dev
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['cache'] = get_cache(self)
self['cache'].setup()
self['template'] = Template(self)
self['push_queue'] = multiprocessing.Queue()
self['workers'] = PushWorkers(self.config.workers)
self["config"] = Config(cfgpath, load = True)
self["database"] = get_database(self.config)
self["client"] = HttpClient()
self["cache"] = get_cache(self)
self["cache"].setup()
self["template"] = Template(self)
self["push_queue"] = multiprocessing.Queue()
self["workers"] = PushWorkers(self.config.workers)
self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore
@@ -82,69 +82,69 @@ class Application(web.Application):
@property
def cache(self) -> Cache:
return cast(Cache, self['cache'])
return cast(Cache, self["cache"])
@property
def client(self) -> HttpClient:
return cast(HttpClient, self['client'])
return cast(HttpClient, self["client"])
@property
def config(self) -> Config:
return cast(Config, self['config'])
return cast(Config, self["config"])
@property
def database(self) -> Database[Connection]:
return cast(Database[Connection], self['database'])
return cast(Database[Connection], self["database"])
@property
def signer(self) -> Signer:
return cast(Signer, self['signer'])
return cast(Signer, self["signer"])
@signer.setter
def signer(self, value: Signer | str) -> None:
if isinstance(value, Signer):
self['signer'] = value
self["signer"] = value
return
self['signer'] = Signer(value, self.config.keyid)
self["signer"] = Signer(value, self.config.keyid)
@property
def template(self) -> Template:
return cast(Template, self['template'])
return cast(Template, self["template"])
@property
def uptime(self) -> timedelta:
if not self['start_time']:
if not self["start_time"]:
return timedelta(seconds=0)
uptime = datetime.now() - self['start_time']
uptime = datetime.now() - self["start_time"]
return timedelta(seconds=uptime.seconds)
@property
def workers(self) -> PushWorkers:
return cast(PushWorkers, self['workers'])
return cast(PushWorkers, self["workers"])
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self['workers'].push_message(inbox, message, instance)
self["workers"].push_message(inbox, message, instance)
def register_static_routes(self) -> None:
if self['dev']:
static = StaticResource('/static', File.from_resource('relay', 'frontend/static'))
if self["dev"]:
static = StaticResource("/static", File.from_resource("relay", "frontend/static"))
else:
static = CachedStaticResource(
'/static', Path(File.from_resource('relay', 'frontend/static'))
"/static", Path(File.from_resource("relay", "frontend/static"))
)
self.router.register_resource(static)
@@ -158,18 +158,18 @@ class Application(web.Application):
host = self.config.listen
port = self.config.port
if port_check(port, '127.0.0.1' if host == '0.0.0.0' else host):
logging.error(f'A server is already running on {host}:{port}')
if port_check(port, "127.0.0.1" if host == "0.0.0.0" else host):
logging.error(f"A server is already running on {host}:{port}")
return
self.register_static_routes()
logging.info(f'Starting webserver at {domain} ({host}:{port})')
logging.info(f"Starting webserver at {domain} ({host}:{port})")
asyncio.run(self.handle_run())
def set_signal_handler(self, startup: bool) -> None:
for sig in ('SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'):
for sig in ("SIGHUP", "SIGINT", "SIGQUIT", "SIGTERM"):
try:
signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL)
@@ -179,22 +179,25 @@ class Application(web.Application):
def stop(self, *_: Any) -> None:
self['running'] = False
self["running"] = False
async def handle_run(self) -> None:
self['running'] = True
self["running"] = True
self.set_signal_handler(True)
self['client'].open()
self['database'].connect()
self['cache'].setup()
self['cleanup_thread'] = CacheCleanupThread(self)
self['cleanup_thread'].start()
self['workers'].start()
self["client"].open()
self["database"].connect()
self["cache"].setup()
self["cleanup_thread"] = CacheCleanupThread(self)
self["cleanup_thread"].start()
self["workers"].start()
runner = web.AppRunner(
self, access_log_format = "%{X-Forwarded-For}i \"%r\" %s %b \"%{User-Agent}i\""
)
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup()
site = web.TCPSite(
@@ -205,22 +208,22 @@
)
await site.start()
self['starttime'] = datetime.now()
self["starttime"] = datetime.now()
while self['running']:
while self["running"]:
await asyncio.sleep(0.25)
await site.stop()
self['workers'].stop()
self["workers"].stop()
self.set_signal_handler(False)
self['starttime'] = None
self['running'] = False
self['cleanup_thread'].stop()
self['database'].disconnect()
self['cache'].close()
self["starttime"] = None
self["running"] = False
self["cleanup_thread"].stop()
self["database"].disconnect()
self["cache"].close()
class CachedStaticResource(StaticResource):
@@ -229,19 +232,19 @@ class CachedStaticResource(StaticResource):
self.cache: dict[str, bytes] = {}
for filename in path.rglob('*'):
for filename in path.rglob("*"):
if filename.is_dir():
continue
rel_path = str(filename.relative_to(path))
with filename.open('rb') as fd:
logging.debug('Loading static resource "%s"', rel_path)
with filename.open("rb") as fd:
logging.debug("Loading static resource \"%s\"", rel_path)
self.cache[rel_path] = fd.read()
async def _handle(self, request: web.Request) -> web.StreamResponse:
rel_url = request.match_info['filename']
rel_url = request.match_info["filename"]
if Path(rel_url).anchor:
raise web.HTTPForbidden()
@@ -281,8 +284,8 @@ class CacheCleanupThread(Thread):
def format_error(request: web.Request, error: HttpError) -> Response:
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.message}, error.status, ctype = 'json')
if request.path.startswith(JSON_PATHS) or "json" in request.headers.get("accept", ""):
return Response.new({"error": error.message}, error.status, ctype = "json")
else:
context = {"e": error}
@@ -294,27 +297,27 @@ async def handle_response_headers(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
request['token'] = None
request['user'] = None
request["hash"] = b64encode(get_random_bytes(16)).decode("ascii")
request["token"] = None
request["user"] = None
app: Application = request.app # type: ignore[assignment]
if request.path in {"/", "/docs"} or request.path.startswith(TOKEN_PATHS):
with app.database.session() as conn:
tokens = (
request.headers.get('Authorization', '').replace('Bearer', '').strip(),
request.cookies.get('user-token')
request.headers.get("Authorization", "").replace("Bearer", "").strip(),
request.cookies.get("user-token")
)
for token in tokens:
if not token:
continue
request['token'] = conn.get_app_by_token(token)
request["token"] = conn.get_app_by_token(token)
if request['token'] is not None:
request['user'] = conn.get_user(request['token'].user)
if request["token"] is not None:
request["user"] = conn.get_user(request["token"].user)
break
@@ -338,21 +341,21 @@ async def handle_response_headers(
raise
except Exception:
resp = format_error(request, HttpError(500, 'Internal server error'))
resp = format_error(request, HttpError(500, "Internal server error"))
traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay'
resp.headers["Server"] = "ActivityRelay"
# Still have to figure out how csp headers work
if resp.content_type == 'text/html' and not request.path.startswith("/api"):
resp.headers['Content-Security-Policy'] = get_csp(request)
if resp.content_type == "text/html" and not request.path.startswith("/api"):
resp.headers["Content-Security-Policy"] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')):
if not request.app["dev"] and request.path.endswith((".css", ".js", ".woff2")):
# cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'
resp.headers["Cache-Control"] = "public,max-age=1209600,immutable"
else:
resp.headers['Cache-Control'] = 'no-store'
resp.headers["Cache-Control"] = "no-store"
return resp
@@ -362,25 +365,25 @@ async def handle_frontend_path(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
if request['user'] is not None and request.path == '/login':
return Response.new_redir('/')
if request["user"] is not None and request.path == "/login":
return Response.new_redir("/")
if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None:
if request.path == '/logout':
return Response.new_redir('/')
if request.path.startswith(TOKEN_PATHS[:2]) and request["user"] is None:
if request.path == "/logout":
return Response.new_redir("/")
response = Response.new_redir(f'/login?redir={request.path}')
response = Response.new_redir(f"/login?redir={request.path}")
if request['token'] is not None:
response.del_cookie('user-token')
if request["token"] is not None:
response.del_cookie("user-token")
return response
response = await handler(request)
if not request.path.startswith('/api'):
if request['user'] is None and request['token'] is not None:
response.del_cookie('user-token')
if not request.path.startswith("/api"):
if request["user"] is None and request["token"] is not None:
response.del_cookie("user-token")
return response

View file

@@ -24,11 +24,11 @@ DeserializerCallback = Callable[[str], Any]
BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, convert_to_boolean),
'json': (json.dumps, json.loads),
'message': (lambda x: x.to_json(), Message.parse)
"str": (str, str),
"int": (str, int),
"bool": (str, convert_to_boolean),
"json": (json.dumps, json.loads),
"message": (lambda x: x.to_json(), Message.parse)
}
@@ -49,14 +49,14 @@ def register_cache(backend: type[Cache]) -> type[Cache]:
return backend
def serialize_value(value: Any, value_type: str = 'str') -> str:
def serialize_value(value: Any, value_type: str = "str") -> str:
if isinstance(value, str):
return value
return CONVERTERS[value_type][0](value)
def deserialize_value(value: str, value_type: str = 'str') -> Any:
def deserialize_value(value: str, value_type: str = "str") -> Any:
return CONVERTERS[value_type][1](value)
@@ -93,7 +93,7 @@ class Item:
class Cache(ABC):
name: str = 'null'
name: str
def __init__(self, app: Application):
@@ -116,7 +116,7 @@ class Cache(ABC):
@abstractmethod
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
def set(self, namespace: str, key: str, value: Any, value_type: str = "key") -> Item:
...
@@ -160,7 +160,7 @@ class Cache(ABC):
@register_cache
class SqlCache(Cache):
name: str = 'database'
name: str = "database"
def __init__(self, app: Application):
@@ -173,16 +173,16 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup")
params = {
'namespace': namespace,
'key': key
"namespace": namespace,
"key": key
}
with self._db.session(False) as conn:
with conn.run('get-cache-item', params) as cur:
with conn.run("get-cache-item", params) as cur:
if not (row := cur.one(Row)):
raise KeyError(f'{namespace}:{key}')
raise KeyError(f"{namespace}:{key}")
row.pop('id', None)
row.pop("id", None)
return Item.from_data(*tuple(row.values()))
@@ -191,8 +191,8 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn:
for row in conn.run('get-cache-keys', {'namespace': namespace}):
yield row['key']
for row in conn.run("get-cache-keys", {"namespace": namespace}):
yield row["key"]
def get_namespaces(self) -> Iterator[str]:
@@ -200,28 +200,28 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn:
for row in conn.run('get-cache-namespaces', None):
yield row['namespace']
for row in conn.run("get-cache-namespaces", None):
yield row["namespace"]
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
def set(self, namespace: str, key: str, value: Any, value_type: str = "str") -> Item:
if self._db is None:
raise RuntimeError("Database has not been setup")
params = {
'namespace': namespace,
'key': key,
'value': serialize_value(value, value_type),
'type': value_type,
'date': Date.new_utc()
"namespace": namespace,
"key": key,
"value": serialize_value(value, value_type),
"type": value_type,
"date": Date.new_utc()
}
with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as cur:
with conn.run("set-cache-item", params) as cur:
if (row := cur.one(Row)) is None:
raise RuntimeError("Cache item not set")
row.pop('id', None)
row.pop("id", None)
return Item.from_data(*tuple(row.values()))
@@ -230,12 +230,12 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup")
params = {
'namespace': namespace,
'key': key
"namespace": namespace,
"key": key
}
with self._db.session(True) as conn:
with conn.run('del-cache-item', params):
with conn.run("del-cache-item", params):
pass
@@ -267,7 +267,7 @@ class SqlCache(Cache):
self._db.connect()
with self._db.session(True) as conn:
with conn.run(f'create-cache-table-{self._db.backend_type.value}', None):
with conn.run(f"create-cache-table-{self._db.backend_type.value}", None):
pass
@@ -281,7 +281,7 @@ class SqlCache(Cache):
@register_cache
class RedisCache(Cache):
name: str = 'redis'
name: str = "redis"
def __init__(self, app: Application):
@@ -295,7 +295,7 @@ class RedisCache(Cache):
def get_key_name(self, namespace: str, key: str) -> str:
return f'{self.prefix}:{namespace}:{key}'
return f"{self.prefix}:{namespace}:{key}"
def get(self, namespace: str, key: str) -> Item:
@@ -305,9 +305,9 @@ class RedisCache(Cache):
key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}')
raise KeyError(f"{namespace}:{key}")
value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr]
value_type, updated, value = raw_value.split(":", 2) # type: ignore[union-attr]
return Item.from_data(
namespace,
@@ -322,8 +322,8 @@ class RedisCache(Cache):
if self._rd is None:
raise ConnectionError("Not connected")
for key in self._rd.scan_iter(self.get_key_name(namespace, '*')):
*_, key_name = key.split(':', 2)
for key in self._rd.scan_iter(self.get_key_name(namespace, "*")):
*_, key_name = key.split(":", 2)
yield key_name
@@ -333,15 +333,15 @@ class RedisCache(Cache):
namespaces = []
for key in self._rd.scan_iter(f'{self.prefix}:*'):
_, namespace, _ = key.split(':', 2)
for key in self._rd.scan_iter(f"{self.prefix}:*"):
_, namespace, _ = key.split(":", 2)
if namespace not in namespaces:
namespaces.append(namespace)
yield namespace
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
def set(self, namespace: str, key: str, value: Any, value_type: str = "key") -> Item:
if self._rd is None:
raise ConnectionError("Not connected")
@@ -350,7 +350,7 @@ class RedisCache(Cache):
self._rd.set(
self.get_key_name(namespace, key),
f'{value_type}:{date}:{value}'
f"{value_type}:{date}:{value}"
)
return self.get(namespace, key)
@@ -369,8 +369,8 @@ class RedisCache(Cache):
limit = Date.new_utc() - timedelta(days = days)
for full_key in self._rd.scan_iter(f'{self.prefix}:*'):
_, namespace, key = full_key.split(':', 2)
for full_key in self._rd.scan_iter(f"{self.prefix}:*"):
_, namespace, key = full_key.split(":", 2)
item = self.get(namespace, key)
if item.updated < limit:
@@ -389,11 +389,11 @@ class RedisCache(Cache):
return
options: RedisConnectType = {
'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True,
'username': self.app.config.rd_user,
'password': self.app.config.rd_pass,
'db': self.app.config.rd_database
"client_name": f"ActivityRelay_{self.app.config.domain}",
"decode_responses": True,
"username": self.app.config.rd_user,
"password": self.app.config.rd_pass,
"db": self.app.config.rd_database
}
if os.path.exists(self.app.config.rd_host):

View file

@@ -14,21 +14,21 @@ class RelayConfig(dict[str, Any]):
dict.__init__(self, {})
if self.is_docker:
path = '/data/config.yaml'
path = "/data/config.yaml"
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in {'blocked_instances', 'blocked_software', 'whitelist'}:
if key in {"blocked_instances", "blocked_software", "whitelist"}:
assert isinstance(value, (list, set, tuple))
elif key in {'port', 'workers', 'json_cache', 'timeout'}:
elif key in {"port", "workers", "json_cache", "timeout"}:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
elif key == "whitelist_enabled":
if not isinstance(value, bool):
value = convert_to_boolean(value)
@@ -37,45 +37,45 @@ class RelayConfig(dict[str, Any]):
@property
def db(self) -> Path:
return Path(self['db']).expanduser().resolve()
return Path(self["db"]).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self["host"]}/actor'
return f"https://{self['host']}/actor"
@property
def inbox(self) -> str:
return f'https://{self["host"]}/inbox'
return f"https://{self['host']}/inbox"
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
return f"{self.actor}#main-key"
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
return bool(os.environ.get("DOCKER_RUNNING"))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
"db": str(self._path.parent.joinpath(f"{self._path.stem}.jsonld")),
"listen": "0.0.0.0",
"port": 8080,
"note": "Make a note about your instance here.",
"push_limit": 512,
"json_cache": 1024,
"timeout": 10,
"workers": 0,
"host": "relay.example.com",
"whitelist_enabled": False,
"blocked_software": [],
"blocked_instances": [],
"whitelist": []
})
@@ -85,13 +85,13 @@ class RelayConfig(dict[str, Any]):
options = {}
try:
options['Loader'] = yaml.FullLoader
options["Loader"] = yaml.FullLoader
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
with self._path.open("r", encoding = "UTF-8") as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
@@ -101,7 +101,7 @@ class RelayConfig(dict[str, Any]):
return
for key, value in config.items():
if key == 'ap':
if key == "ap":
for k, v in value.items():
if k not in self:
continue
@@ -119,10 +119,10 @@ class RelayConfig(dict[str, Any]):
class RelayDatabase(dict[str, Any]):
def __init__(self, config: RelayConfig):
dict.__init__(self, {
'relay-list': {},
'private-key': None,
'follow-requests': {},
'version': 1
"relay-list": {},
"private-key": None,
"follow-requests": {},
"version": 1
})
self.config = config
@@ -131,12 +131,12 @@ class RelayDatabase(dict[str, Any]):
@property
def hostnames(self) -> tuple[str]:
return tuple(self['relay-list'].keys())
return tuple(self["relay-list"].keys())
@property
def inboxes(self) -> tuple[dict[str, str]]:
return tuple(data['inbox'] for data in self['relay-list'].values())
return tuple(data["inbox"] for data in self["relay-list"].values())
def load(self) -> None:
@@ -144,29 +144,29 @@ class RelayDatabase(dict[str, Any]):
with self.config.db.open() as fd:
data = json.load(fd)
self['version'] = data.get('version', None)
self['private-key'] = data.get('private-key')
self["version"] = data.get("version", None)
self["private-key"] = data.get("private-key")
if self['version'] is None:
self['version'] = 1
if self["version"] is None:
self["version"] = 1
if 'actorKeys' in data:
self['private-key'] = data['actorKeys']['privateKey']
if "actorKeys" in data:
self["private-key"] = data["actorKeys"]["privateKey"]
for item in data.get('relay-list', []):
for item in data.get("relay-list", []):
domain = urlparse(item).hostname
self['relay-list'][domain] = {
'domain': domain,
'inbox': item,
'followid': None
self["relay-list"][domain] = {
"domain": domain,
"inbox": item,
"followid": None
}
else:
self['relay-list'] = data.get('relay-list', {})
self["relay-list"] = data.get("relay-list", {})
for domain, instance in self['relay-list'].items():
if not instance.get('domain'):
instance['domain'] = domain
for domain, instance in self["relay-list"].items():
if not instance.get("domain"):
instance["domain"] = domain
except FileNotFoundError:
pass

View file

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from typing import Self
if platform.system() == 'Windows':
if platform.system() == "Windows":
import multiprocessing
CORE_COUNT = multiprocessing.cpu_count()
@@ -25,9 +25,9 @@ else:
DOCKER_VALUES = {
'listen': '0.0.0.0',
'port': 8080,
'sq_path': '/data/relay.sqlite3'
"listen": "0.0.0.0",
"port": 8080,
"sq_path": "/data/relay.sqlite3"
}
@@ -37,26 +37,26 @@ class NOVALUE:
@dataclass(init = False)
class Config:
listen: str = '0.0.0.0'
listen: str = "0.0.0.0"
port: int = 8080
domain: str = 'relay.example.com'
domain: str = "relay.example.com"
workers: int = CORE_COUNT
db_type: str = 'sqlite'
ca_type: str = 'database'
sq_path: str = 'relay.sqlite3'
db_type: str = "sqlite"
ca_type: str = "database"
sq_path: str = "relay.sqlite3"
pg_host: str = '/var/run/postgresql'
pg_host: str = "/var/run/postgresql"
pg_port: int = 5432
pg_user: str = getpass.getuser()
pg_pass: str | None = None
pg_name: str = 'activityrelay'
pg_name: str = "activityrelay"
rd_host: str = 'localhost'
rd_host: str = "localhost"
rd_port: int = 6470
rd_user: str | None = None
rd_pass: str | None = None
rd_database: int = 0
rd_prefix: str = 'activityrelay'
rd_prefix: str = "activityrelay"
def __init__(self, path: Path | None = None, load: bool = False):
@@ -116,17 +116,17 @@ class Config:
@property
def actor(self) -> str:
return f'https://{self.domain}/actor'
return f"https://{self.domain}/actor"
@property
def inbox(self) -> str:
return f'https://{self.domain}/inbox'
return f"https://{self.domain}/inbox"
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
return f"{self.actor}#main-key"
def load(self) -> None:
@@ -134,43 +134,43 @@ class Config:
options = {}
try:
options['Loader'] = yaml.FullLoader
options["Loader"] = yaml.FullLoader
except AttributeError:
pass
with self.path.open('r', encoding = 'UTF-8') as fd:
with self.path.open("r", encoding = "UTF-8") as fd:
config = yaml.load(fd, **options)
if not config:
raise ValueError('Config is empty')
raise ValueError("Config is empty")
pgcfg = config.get('postgres', {})
rdcfg = config.get('redis', {})
pgcfg = config.get("postgres", {})
rdcfg = config.get("redis", {})
for key in type(self).KEYS():
if IS_DOCKER and key in {'listen', 'port', 'sq_path'}:
if IS_DOCKER and key in {"listen", "port", "sq_path"}:
self.set(key, DOCKER_VALUES[key])
continue
if key.startswith('pg'):
if key.startswith("pg"):
self.set(key, pgcfg.get(key[3:], NOVALUE))
continue
elif key.startswith('rd'):
elif key.startswith("rd"):
self.set(key, rdcfg.get(key[3:], NOVALUE))
continue
cfgkey = key
if key == 'db_type':
cfgkey = 'database_type'
if key == "db_type":
cfgkey = "database_type"
elif key == 'ca_type':
cfgkey = 'cache_type'
elif key == "ca_type":
cfgkey = "cache_type"
elif key == 'sq_path':
cfgkey = 'sqlite_path'
elif key == "sq_path":
cfgkey = "sqlite_path"
self.set(key, config.get(cfgkey, NOVALUE))
@@ -186,32 +186,32 @@ class Config:
data: dict[str, Any] = {}
for key, value in asdict(self).items():
if key.startswith('pg_'):
if 'postgres' not in data:
data['postgres'] = {}
if key.startswith("pg_"):
if "postgres" not in data:
data["postgres"] = {}
data['postgres'][key[3:]] = value
data["postgres"][key[3:]] = value
continue
if key.startswith('rd_'):
if 'redis' not in data:
data['redis'] = {}
if key.startswith("rd_"):
if "redis" not in data:
data["redis"] = {}
data['redis'][key[3:]] = value
data["redis"][key[3:]] = value
continue
if key == 'db_type':
key = 'database_type'
if key == "db_type":
key = "database_type"
elif key == 'ca_type':
key = 'cache_type'
elif key == "ca_type":
key = "cache_type"
elif key == 'sq_path':
key = 'sqlite_path'
elif key == "sq_path":
key = "sqlite_path"
data[key] = value
with self.path.open('w', encoding = 'utf-8') as fd:
with self.path.open("w", encoding = "utf-8") as fd:
yaml.dump(data, fd, sort_keys = False)

View file

@@ -16,17 +16,17 @@ sqlite3.register_adapter(Date, Date.timestamp)
def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
options = {
'connection_class': Connection,
'pool_size': 5,
'tables': TABLES
"connection_class": Connection,
"pool_size": 5,
"tables": TABLES
}
db: Database[Connection]
if config.db_type == 'sqlite':
if config.db_type == "sqlite":
db = Database.sqlite(config.sqlite_path, **options)
elif config.db_type == 'postgres':
elif config.db_type == "postgres":
db = Database.postgresql(
config.pg_name,
config.pg_host,
@@ -36,30 +36,30 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
**options
)
db.load_prepared_statements(File.from_resource('relay', 'data/statements.sql'))
db.load_prepared_statements(File.from_resource("relay", "data/statements.sql"))
db.connect()
if not migrate:
return db
with db.session(True) as conn:
if 'config' not in conn.get_tables():
if "config" not in conn.get_tables():
logging.info("Creating database tables")
migrate_0(conn)
return db
if (schema_ver := conn.get_config('schema-version')) < ConfigData.DEFAULT('schema-version'):
if (schema_ver := conn.get_config("schema-version")) < ConfigData.DEFAULT("schema-version"):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS.items():
if schema_ver < ver:
func(conn)
conn.put_config('schema-version', ver)
conn.put_config("schema-version", ver)
logging.info("Updated database to %i", ver)
if (privkey := conn.get_config('private-key')):
if (privkey := conn.get_config("private-key")):
conn.app.signer = privkey
logging.set_level(conn.get_config('log-level'))
logging.set_level(conn.get_config("log-level"))
return db

View file

@@ -15,76 +15,76 @@ if TYPE_CHECKING:
THEMES = {
'default': {
'text': '#DDD',
'background': '#222',
'primary': '#D85',
'primary-hover': '#DA8',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
"default": {
"text": "#DDD",
"background": "#222",
"primary": "#D85",
"primary-hover": "#DA8",
"section-background": "#333",
"table-background": "#444",
"border": "#444",
"message-text": "#DDD",
"message-background": "#335",
"message-border": "#446",
"error-text": "#DDD",
"error-background": "#533",
"error-border": "#644"
},
'pink': {
'text': '#DDD',
'background': '#222',
'primary': '#D69',
'primary-hover': '#D36',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
"pink": {
"text": "#DDD",
"background": "#222",
"primary": "#D69",
"primary-hover": "#D36",
"section-background": "#333",
"table-background": "#444",
"border": "#444",
"message-text": "#DDD",
"message-background": "#335",
"message-border": "#446",
"error-text": "#DDD",
"error-background": "#533",
"error-border": "#644"
},
'blue': {
'text': '#DDD',
'background': '#222',
'primary': '#69D',
'primary-hover': '#36D',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
"blue": {
"text": "#DDD",
"background": "#222",
"primary": "#69D",
"primary-hover": "#36D",
"section-background": "#333",
"table-background": "#444",
"border": "#444",
"message-text": "#DDD",
"message-background": "#335",
"message-border": "#446",
"error-text": "#DDD",
"error-background": "#533",
"error-border": "#644"
}
}
# serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, convert_to_boolean),
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse)
"str": (str, str),
"int": (str, int),
"bool": (str, convert_to_boolean),
"logging.LogLevel": (lambda x: x.name, logging.LogLevel.parse)
}
@dataclass()
class ConfigData:
schema_version: int = 20240625
private_key: str = ''
private_key: str = ""
approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO
name: str = 'ActivityRelay'
note: str = ''
theme: str = 'default'
name: str = "ActivityRelay"
note: str = ""
theme: str = "default"
whitelist_enabled: bool = False
def __getitem__(self, key: str) -> Any:
if (value := getattr(self, key.replace('-', '_'), None)) is None:
if (value := getattr(self, key.replace("-", "_"), None)) is None:
raise KeyError(key)
return value
@@ -101,7 +101,7 @@ class ConfigData:
@staticmethod
def SYSTEM_KEYS() -> Sequence[str]:
return ('schema-version', 'schema_version', 'private-key', 'private_key')
return ("schema-version", "schema_version", "private-key", "private_key")
@classmethod
@@ -111,12 +111,12 @@ class ConfigData:
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value]
return cls.FIELD(key.replace("-", "_")).default # type: ignore[return-value]
@classmethod
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
parsed_key = key.replace('-', '_')
parsed_key = key.replace("-", "_")
for field in fields(cls):
if field.name == parsed_key:
@@ -131,9 +131,9 @@ class ConfigData:
set_schema_version = False
for row in rows:
data.set(row['key'], row['value'])
data.set(row["key"], row["value"])
if row['key'] == 'schema-version':
if row["key"] == "schema-version":
set_schema_version = True
if not set_schema_version:
@@ -161,4 +161,4 @@ class ConfigData:
def to_dict(self) -> dict[str, Any]:
return {key.replace('_', '-'): value for key, value in asdict(self).items()}
return {key.replace("_", "-"): value for key, value in asdict(self).items()}

View file

@@ -23,16 +23,16 @@ if TYPE_CHECKING:
from ..application import Application
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'activity-relay', # https://github.com/yukimochi/Activity-Relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
"activityrelay", # https://git.pleroma.social/pleroma/relay
"activity-relay", # https://github.com/yukimochi/Activity-Relay
"aoderelay", # https://git.asonix.dog/asonix/relay
"feditools-relay" # https://git.ptzo.gdn/feditools/relay
]
class Connection(SqlConnection):
hasher = PasswordHasher(
encoding = 'utf-8'
encoding = "utf-8"
)
@property
@@ -63,49 +63,49 @@ class Connection(SqlConnection):
def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
self.update('apps', data, client_id = app.client_id)
for app in self.select("apps").all(schema.App):
data = {"created": app.created.timestamp(), "accessed": app.accessed.timestamp()}
self.update("apps", data, client_id = app.client_id)
for item in self.select('cache'):
data = {'updated': Date.parse(item['updated']).timestamp()}
self.update('cache', data, id = item['id'])
for item in self.select("cache"):
data = {"updated": Date.parse(item["updated"]).timestamp()}
self.update("cache", data, id = item["id"])
for dban in self.select('domain_bans').all(schema.DomainBan):
data = {'created': dban.created.timestamp()}
self.update('domain_bans', data, domain = dban.domain)
for dban in self.select("domain_bans").all(schema.DomainBan):
data = {"created": dban.created.timestamp()}
self.update("domain_bans", data, domain = dban.domain)
for instance in self.select('inboxes').all(schema.Instance):
data = {'created': instance.created.timestamp()}
self.update('inboxes', data, domain = instance.domain)
for instance in self.select("inboxes").all(schema.Instance):
data = {"created": instance.created.timestamp()}
self.update("inboxes", data, domain = instance.domain)
for sban in self.select('software_bans').all(schema.SoftwareBan):
data = {'created': sban.created.timestamp()}
self.update('software_bans', data, name = sban.name)
for sban in self.select("software_bans").all(schema.SoftwareBan):
data = {"created": sban.created.timestamp()}
self.update("software_bans", data, name = sban.name)
for user in self.select('users').all(schema.User):
data = {'created': user.created.timestamp()}
self.update('users', data, username = user.username)
for user in self.select("users").all(schema.User):
data = {"created": user.created.timestamp()}
self.update("users", data, username = user.username)
for wlist in self.select('whitelist').all(schema.Whitelist):
data = {'created': wlist.created.timestamp()}
self.update('whitelist', data, domain = wlist.domain)
for wlist in self.select("whitelist").all(schema.Whitelist):
data = {"created": wlist.created.timestamp()}
self.update("whitelist", data, domain = wlist.domain)
def get_config(self, key: str) -> Any:
key = key.replace('_', '-')
key = key.replace("_", "-")
with self.run('get-config', {'key': key}) as cur:
with self.run("get-config", {"key": key}) as cur:
if (row := cur.one(Row)) is None:
return ConfigData.DEFAULT(key)
data = ConfigData()
data.set(row['key'], row['value'])
data.set(row["key"], row["value"])
return data.get(key)
def get_config_all(self) -> ConfigData:
rows = tuple(self.run('get-config-all', None).all(schema.Row))
rows = tuple(self.run("get-config-all", None).all(schema.Row))
return ConfigData.from_rows(rows)
@@ -119,7 +119,7 @@ class Connection(SqlConnection):
case "log_level":
value = logging.LogLevel.parse(value)
logging.set_level(value)
self.app['workers'].set_log_level(value)
self.app["workers"].set_log_level(value)
case "approval_required":
value = convert_to_boolean(value)
@@ -129,25 +129,25 @@ class Connection(SqlConnection):
case "theme":
if value not in THEMES:
raise ValueError(f'"{value}" is not a valid theme')
raise ValueError(f"\"{value}\" is not a valid theme")
data = ConfigData()
data.set(key, value)
params = {
'key': key,
'value': data.get(key, serialize = True),
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore
"key": key,
"value": data.get(key, serialize = True),
"type": "LogLevel" if field.type == "logging.LogLevel" else field.type # type: ignore
}
with self.run('put-config', params):
with self.run("put-config", params):
pass
return data.get(key)
def get_inbox(self, value: str) -> schema.Instance | None:
with self.run('get-inbox', {'value': value}) as cur:
with self.run("get-inbox", {"value": value}) as cur:
return cur.one(schema.Instance)
@@ -165,21 +165,21 @@ class Connection(SqlConnection):
accepted: bool = True) -> schema.Instance:
params: dict[str, Any] = {
'inbox': inbox,
'actor': actor,
'followid': followid,
'software': software,
'accepted': accepted
"inbox": inbox,
"actor": actor,
"followid": followid,
"software": software,
"accepted": accepted
}
if self.get_inbox(domain) is None:
if not inbox:
raise ValueError("Missing inbox")
params['domain'] = domain
params['created'] = datetime.now(tz = timezone.utc)
params["domain"] = domain
params["created"] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur:
with self.run("put-inbox", params) as cur:
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert instance: {domain}")
@@ -189,7 +189,7 @@ class Connection(SqlConnection):
if value is None:
del params[key]
with self.update('inboxes', params, domain = domain) as cur:
with self.update("inboxes", params, domain = domain) as cur:
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to update instance: {domain}")
@@ -197,20 +197,20 @@ class Connection(SqlConnection):
def del_inbox(self, value: str) -> bool:
with self.run('del-inbox', {'value': value}) as cur:
with self.run("del-inbox", {"value": value}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
return cur.row_count == 1
def get_request(self, domain: str) -> schema.Instance | None:
with self.run('get-request', {'domain': domain}) as cur:
with self.run("get-request", {"domain": domain}) as cur:
return cur.one(schema.Instance)
def get_requests(self) -> Iterator[schema.Instance]:
return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance)
return self.execute("SELECT * FROM inboxes WHERE accepted = false").all(schema.Instance)
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
@@ -219,16 +219,16 @@ class Connection(SqlConnection):
if not accepted:
if not self.del_inbox(domain):
raise RuntimeError(f'Failed to delete request: {domain}')
raise RuntimeError(f"Failed to delete request: {domain}")
return instance
params = {
'domain': domain,
'accepted': accepted
"domain": domain,
"accepted": accepted
}
with self.run('put-inbox-accept', params) as cur:
with self.run("put-inbox-accept", params) as cur:
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert response for domain: {domain}")
@@ -236,12 +236,12 @@ class Connection(SqlConnection):
def get_user(self, value: str) -> schema.User | None:
with self.run('get-user', {'value': value}) as cur:
with self.run("get-user", {"value": value}) as cur:
return cur.one(schema.User)
def get_user_by_token(self, token: str) -> schema.User | None:
with self.run('get-user-by-token', {'token': token}) as cur:
with self.run("get-user-by-token", {"token": token}) as cur:
return cur.one(schema.User)
@@ -254,10 +254,10 @@ class Connection(SqlConnection):
data: dict[str, str | datetime | None] = {}
if password:
data['hash'] = self.hasher.hash(password)
data["hash"] = self.hasher.hash(password)
if handle:
data['handle'] = handle
data["handle"] = handle
stmt = Update("users", data)
stmt.set_where("username", username)
@@ -269,16 +269,16 @@ class Connection(SqlConnection):
return row
if password is None:
raise ValueError('Password cannot be empty')
raise ValueError("Password cannot be empty")
data = {
'username': username,
'hash': self.hasher.hash(password),
'handle': handle,
'created': datetime.now(tz = timezone.utc)
"username": username,
"hash": self.hasher.hash(password),
"handle": handle,
"created": datetime.now(tz = timezone.utc)
}
with self.run('put-user', data) as cur:
with self.run("put-user", data) as cur:
if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to insert user: {username}")
@@ -289,10 +289,10 @@ class Connection(SqlConnection):
if (user := self.get_user(username)) is None:
raise KeyError(username)
with self.run('del-token-user', {'username': user.username}):
with self.run("del-token-user", {"username": user.username}):
pass
with self.run('del-user', {'username': user.username}):
with self.run("del-user", {"username": user.username}):
pass
@@ -302,61 +302,61 @@ class Connection(SqlConnection):
token: str | None = None) -> schema.App | None:
params = {
'id': client_id,
'secret': client_secret
"id": client_id,
"secret": client_secret
}
if token is not None:
command = 'get-app-with-token'
params['token'] = token
command = "get-app-with-token"
params["token"] = token
else:
command = 'get-app'
command = "get-app"
with self.run(command, params) as cur:
return cur.one(schema.App)
def get_app_by_token(self, token: str) -> schema.App | None:
with self.run('get-app-by-token', {'token': token}) as cur:
with self.run("get-app-by-token", {"token": token}) as cur:
return cur.one(schema.App)
def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
params = {
'name': name,
'redirect_uri': redirect_uri,
'website': website,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'created': Date.new_utc(),
'accessed': Date.new_utc()
"name": name,
"redirect_uri": redirect_uri,
"website": website,
"client_id": secrets.token_hex(20),
"client_secret": secrets.token_hex(20),
"created": Date.new_utc(),
"accessed": Date.new_utc()
}
with self.insert('apps', params) as cur:
with self.insert("apps", params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to insert app: {name}')
raise RuntimeError(f"Failed to insert app: {name}")
return row
def put_app_login(self, user: schema.User) -> schema.App:
params = {
'name': 'Web',
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob',
'website': None,
'user': user.username,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'auth_code': None,
'token': secrets.token_hex(20),
'created': Date.new_utc(),
'accessed': Date.new_utc()
"name": "Web",
"redirect_uri": "urn:ietf:wg:oauth:2.0:oob",
"website": None,
"user": user.username,
"client_id": secrets.token_hex(20),
"client_secret": secrets.token_hex(20),
"auth_code": None,
"token": secrets.token_hex(20),
"created": Date.new_utc(),
"accessed": Date.new_utc()
}
with self.insert('apps', params) as cur:
with self.insert("apps", params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to create app for "{user.username}"')
raise RuntimeError(f"Failed to create app for \"{user.username}\"")
return row
@@ -365,52 +365,52 @@ class Connection(SqlConnection):
data: dict[str, str | None] = {}
if user is not None:
data['user'] = user.username
data["user"] = user.username
if set_auth:
data['auth_code'] = secrets.token_hex(20)
data["auth_code"] = secrets.token_hex(20)
else:
data['token'] = secrets.token_hex(20)
data['auth_code'] = None
data["token"] = secrets.token_hex(20)
data["auth_code"] = None
params = {
'client_id': app.client_id,
'client_secret': app.client_secret
"client_id": app.client_id,
"client_secret": app.client_secret
}
with self.update('apps', data, **params) as cur: # type: ignore[arg-type]
with self.update("apps", data, **params) as cur: # type: ignore[arg-type]
if (row := cur.one(schema.App)) is None:
raise RuntimeError('Failed to update row')
raise RuntimeError("Failed to update row")
return row
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
params = {
'id': client_id,
'secret': client_secret
"id": client_id,
"secret": client_secret
}
if token is not None:
command = 'del-app-with-token'
params['token'] = token
command = "del-app-with-token"
params["token"] = token
else:
command = 'del-app'
command = "del-app"
with self.run(command, params) as cur:
if cur.row_count > 1:
raise RuntimeError('More than 1 row was deleted')
raise RuntimeError("More than 1 row was deleted")
return cur.row_count == 0
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
if domain.startswith('http'):
if domain.startswith("http"):
domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur:
with self.run("get-domain-ban", {"domain": domain}) as cur:
return cur.one(schema.DomainBan)
@@ -424,13 +424,13 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.DomainBan:
params = {
'domain': domain,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
"domain": domain,
"reason": reason,
"note": note,
"created": datetime.now(tz = timezone.utc)
}
with self.run('put-domain-ban', params) as cur:
with self.run("put-domain-ban", params) as cur:
if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to insert domain ban: {domain}")
@@ -443,22 +443,22 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.DomainBan:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
raise ValueError("\"reason\" and/or \"note\" must be specified")
params = {}
if reason is not None:
params['reason'] = reason
params["reason"] = reason
if note is not None:
params['note'] = note
params["note"] = note
statement = Update('domain_bans', params)
statement = Update("domain_bans", params)
statement.set_where("domain", domain)
with self.query(statement) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to update domain ban: {domain}")
@@ -467,20 +467,20 @@ class Connection(SqlConnection):
def del_domain_ban(self, domain: str) -> bool:
with self.run('del-domain-ban', {'domain': domain}) as cur:
with self.run("del-domain-ban", {"domain": domain}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
return cur.row_count == 1
def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
with self.run('get-software-ban', {'name': name}) as cur:
with self.run("get-software-ban", {"name": name}) as cur:
return cur.one(schema.SoftwareBan)
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
return self.execute("SELECT * FROM software_bans").all(schema.SoftwareBan)
def put_software_ban(self,
@@ -489,15 +489,15 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.SoftwareBan:
params = {
'name': name,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
"name": name,
"reason": reason,
"note": note,
"created": datetime.now(tz = timezone.utc)
}
with self.run('put-software-ban', params) as cur:
with self.run("put-software-ban", params) as cur:
if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to insert software ban: {name}')
raise RuntimeError(f"Failed to insert software ban: {name}")
return row
@@ -508,39 +508,39 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.SoftwareBan:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
raise ValueError("\"reason\" and/or \"note\" must be specified")
params = {}
if reason is not None:
params['reason'] = reason
params["reason"] = reason
if note is not None:
params['note'] = note
params["note"] = note
statement = Update('software_bans', params)
statement = Update("software_bans", params)
statement.set_where("name", name)
with self.query(statement) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to update software ban: {name}')
raise RuntimeError(f"Failed to update software ban: {name}")
return row
def del_software_ban(self, name: str) -> bool:
with self.run('del-software-ban', {'name': name}) as cur:
with self.run("del-software-ban", {"name": name}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
with self.run("get-domain-whitelist", {"domain": domain}) as cur:
return cur.one()
@@ -550,20 +550,20 @@ class Connection(SqlConnection):
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
params = {
'domain': domain,
'created': datetime.now(tz = timezone.utc)
"domain": domain,
"created": datetime.now(tz = timezone.utc)
}
with self.run('put-domain-whitelist', params) as cur:
with self.run("put-domain-whitelist", params) as cur:
if (row := cur.one(schema.Whitelist)) is None:
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
raise RuntimeError(f"Failed to insert whitelisted domain: {domain}")
return row
def del_domain_whitelist(self, domain: str) -> bool:
with self.run('del-domain-whitelist', {'domain': domain}) as cur:
with self.run("del-domain-whitelist", {"domain": domain}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
return cur.row_count == 1

View file

@@ -32,113 +32,113 @@ def deserialize_timestamp(value: Any) -> Date:
@TABLES.add_row
class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
value: Column[str] = Column('value', 'text')
type: Column[str] = Column('type', 'text', default = 'str')
key: Column[str] = Column("key", "text", primary_key = True, unique = True, nullable = False)
value: Column[str] = Column("value", "text")
type: Column[str] = Column("type", "text", default = "str")
@TABLES.add_row
class Instance(Row):
table_name: str = 'inboxes'
table_name: str = "inboxes"
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean')
"domain", "text", primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column("actor", "text", unique = True)
inbox: Column[str] = Column("inbox", "text", unique = True, nullable = False)
followid: Column[str] = Column("followid", "text")
software: Column[str] = Column("software", "text")
accepted: Column[Date] = Column("accepted", "boolean")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
"domain", "text", primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class DomainBan(Row):
table_name: str = 'domain_bans'
table_name: str = "domain_bans"
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
"domain", "text", primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column("reason", "text")
note: Column[str] = Column("note", "text")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class SoftwareBan(Row):
table_name: str = 'software_bans'
table_name: str = "software_bans"
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
name: Column[str] = Column("name", "text", primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column("reason", "text")
note: Column[str] = Column("note", "text")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class User(Row):
table_name: str = 'users'
table_name: str = "users"
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
"username", "text", primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column("hash", "text", nullable = False)
handle: Column[str] = Column("handle", "text")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class App(Row):
table_name: str = 'apps'
table_name: str = "apps"
client_id: Column[str] = Column(
'client_id', 'text', primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
name: Column[str] = Column('name', 'text')
website: Column[str] = Column('website', 'text')
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
"client_id", "text", primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column("client_secret", "text", nullable = False)
name: Column[str] = Column("name", "text")
website: Column[str] = Column("website", "text")
redirect_uri: Column[str] = Column("redirect_uri", "text", nullable = False)
token: Column[str | None] = Column("token", "text")
auth_code: Column[str | None] = Column("auth_code", "text")
user: Column[str | None] = Column("user", "text")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"accessed", "timestamp", nullable = False, deserializer = deserialize_timestamp)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
data = deepcopy(self)
data.pop('user')
data.pop('auth_code')
data.pop("user")
data.pop("auth_code")
if not include_token:
data.pop('token')
data.pop("token")
return data
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
ver = int(func.__name__.replace('migrate_', ''))
ver = int(func.__name__.replace("migrate_", ""))
VERSIONS[ver] = func
return func
def migrate_0(conn: Connection) -> None:
conn.create_tables()
conn.put_config('schema-version', ConfigData.DEFAULT('schema-version'))
conn.put_config("schema-version", ConfigData.DEFAULT("schema-version"))
@migration
@ -148,11 +148,11 @@ def migrate_20240206(conn: Connection) -> None:
@migration
def migrate_20240310(conn: Connection) -> None:
conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close()
conn.execute('UPDATE "inboxes" SET "accepted" = true').close()
conn.execute("ALTER TABLE \"inboxes\" ADD COLUMN \"accepted\" BOOLEAN").close()
conn.execute("UPDATE \"inboxes\" SET \"accepted\" = true").close()
@migration
def migrate_20240625(conn: Connection) -> None:
conn.create_tables()
conn.execute('DROP TABLE "tokens"').close()
conn.execute("DROP TABLE \"tokens\"").close()

View file

@ -17,6 +17,13 @@ if TYPE_CHECKING:
from .application import Application
T = TypeVar("T", bound = JsonBase[Any])
HEADERS = {
"Accept": f"{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9",
"User-Agent": f"ActivityRelay/{__version__}"
}
SUPPORTS_HS2019 = {
'friendica',
'gotosocial',
@ -32,12 +39,6 @@ SUPPORTS_HS2019 = {
'sharkey'
}
T = TypeVar('T', bound = JsonBase[Any])
HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}'
}
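One side effect of switching the outer f-string quotes here is that the new Accept header line nests double quotes inside a double-quoted f-string, which PEP 701 only permits on Python 3.12 and later; on older interpreters that line is a SyntaxError. A minimal illustration, using the MIMETYPES values defined later in this diff:

MIMETYPES = {"activity": "application/activity+json", "json": "application/json"}

accept = f"{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9"         # requires Python >= 3.12 (PEP 701)
accept_compat = f"{MIMETYPES['activity']}, {MIMETYPES['json']};q=0.9"  # also valid on older versions
assert accept == accept_compat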
class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10):
@ -106,25 +107,25 @@ class HttpClient:
old_algo: bool) -> str | None:
if not self._session:
raise RuntimeError('Client not open')
raise RuntimeError("Client not open")
url = url.split("#", 1)[0]
if not force:
try:
if not (item := self.cache.get('request', url)).older_than(48):
if not (item := self.cache.get("request", url)).older_than(48):
return item.value # type: ignore [no-any-return]
except KeyError:
logging.verbose('No cached data for url: %s', url)
logging.verbose("No cached data for url: %s", url)
headers = {}
if sign_headers:
algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
headers = self.signer.sign_headers('GET', url, algorithm = algo)
headers = self.signer.sign_headers("GET", url, algorithm = algo)
logging.debug('Fetching resource: %s', url)
logging.debug("Fetching resource: %s", url)
async with self._session.get(url, headers = headers) as resp:
# Not expecting a response with 202s, so just return
@ -142,7 +143,7 @@ class HttpClient:
raise HttpError(resp.status, error)
self.cache.set('request', url, data, 'str')
self.cache.set("request", url, data, "str")
return data
@ -172,13 +173,13 @@ class HttpClient:
old_algo: bool = True) -> T | str | None:
if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"')
raise TypeError("cls must be a sub-class of \"blib.JsonBase\"")
data = await self._get(url, sign_headers, force, old_algo)
if cls is not None:
if data is None:
# this shouldn't actually get raised, but keeping just in case
# this shouldn"t actually get raised, but keeping just in case
raise EmptyBodyError(f"GET {url}")
return cls.parse(data)
@ -188,7 +189,7 @@ class HttpClient:
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
if not self._session:
raise RuntimeError('Client not open')
raise RuntimeError("Client not open")
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance is not None and instance.software in SUPPORTS_HS2019:
@ -210,14 +211,14 @@ class HttpClient:
mtype = message.type.value if isinstance(message.type, ObjectType) else message.type
headers = self.signer.sign_headers(
'POST',
"POST",
url,
body,
headers = {'Content-Type': 'application/activity+json'},
headers = {"Content-Type": "application/activity+json"},
algorithm = algorithm
)
logging.verbose('Sending "%s" to %s', mtype, url)
logging.verbose("Sending \"%s\" to %s", mtype, url)
async with self._session.post(url, headers = headers, data = body) as resp:
if resp.status not in (200, 202):
@ -231,10 +232,10 @@ class HttpClient:
async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:
nodeinfo_url = None
wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force
f"https://{domain}/.well-known/nodeinfo", False, WellKnownNodeinfo, force
)
for version in ('20', '21'):
for version in ("20", "21"):
try:
nodeinfo_url = wk_nodeinfo.get_url(version)
@ -242,7 +243,7 @@ class HttpClient:
pass
if nodeinfo_url is None:
raise ValueError(f'Failed to fetch nodeinfo url for {domain}')
raise ValueError(f"Failed to fetch nodeinfo url for {domain}")
return await self.get(nodeinfo_url, False, Nodeinfo, force)

View file

@ -54,7 +54,7 @@ class LogLevel(IntEnum):
except ValueError:
pass
raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}')
raise AttributeError(f"Invalid enum property for {cls.__name__}: {data}")
def get_level() -> LogLevel:
@ -80,7 +80,7 @@ critical: LoggingMethod = logging.critical
try:
env_log_file: Path | None = Path(os.environ['LOG_FILE']).expanduser().resolve()
env_log_file: Path | None = Path(os.environ["LOG_FILE"]).expanduser().resolve()
except KeyError:
env_log_file = None
@ -90,16 +90,16 @@ handlers: list[Any] = [logging.StreamHandler()]
if env_log_file:
handlers.append(logging.FileHandler(env_log_file))
if os.environ.get('IS_SYSTEMD'):
logging_format = '%(levelname)s: %(message)s'
if os.environ.get("IS_SYSTEMD"):
logging_format = "%(levelname)s: %(message)s"
else:
logging_format = '[%(asctime)s] %(levelname)s: %(message)s'
logging_format = "[%(asctime)s] %(levelname)s: %(message)s"
logging.addLevelName(LogLevel.VERBOSE, 'VERBOSE')
logging.addLevelName(LogLevel.VERBOSE, "VERBOSE")
logging.basicConfig(
level = LogLevel.INFO,
format = logging_format,
datefmt = '%Y-%m-%d %H:%M:%S',
datefmt = "%Y-%m-%d %H:%M:%S",
handlers = handlers
)

File diff suppressed because it is too large

View file

@ -16,63 +16,63 @@ if TYPE_CHECKING:
from .application import Application
T = TypeVar('T')
ResponseType = TypedDict('ResponseType', {
'status': int,
'headers': dict[str, Any] | None,
'content_type': str,
'body': bytes | None,
'text': str | None
T = TypeVar("T")
ResponseType = TypedDict("ResponseType", {
"status": int,
"headers": dict[str, Any] | None,
"content_type": str,
"body": bytes | None,
"text": str | None
})
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
IS_WINDOWS = platform.system() == 'Windows'
IS_DOCKER = bool(os.environ.get("DOCKER_RUNNING"))
IS_WINDOWS = platform.system() == "Windows"
MIMETYPES = {
'activity': 'application/activity+json',
'css': 'text/css',
'html': 'text/html',
'json': 'application/json',
'text': 'text/plain',
'webmanifest': 'application/manifest+json'
"activity": "application/activity+json",
"css": "text/css",
"html": "text/html",
"json": "application/json",
"text": "text/plain",
"webmanifest": "application/manifest+json"
}
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
"mastodon": "https://{domain}/actor",
"akkoma": "https://{domain}/relay",
"pleroma": "https://{domain}/relay"
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
"mastodon",
"akkoma",
"pleroma",
"misskey",
"friendica",
"hubzilla",
"firefish",
"gotosocial"
)
JSON_PATHS: tuple[str, ...] = (
'/api/v1',
'/actor',
'/inbox',
'/outbox',
'/following',
'/followers',
'/.well-known',
'/nodeinfo',
'/oauth/token',
'/oauth/revoke'
"/api/v1",
"/actor",
"/inbox",
"/outbox",
"/following",
"/followers",
"/.well-known",
"/nodeinfo",
"/oauth/token",
"/oauth/revoke"
)
TOKEN_PATHS: tuple[str, ...] = (
'/logout',
'/admin',
'/api',
'/oauth/authorize',
'/oauth/revoke'
"/logout",
"/admin",
"/api",
"/oauth/authorize",
"/oauth/revoke"
)
@ -80,7 +80,7 @@ def get_app() -> Application:
from .application import Application
if not Application.DEFAULT:
raise ValueError('No default application set')
raise ValueError("No default application set")
return Application.DEFAULT
@ -136,23 +136,23 @@ class Message(aputils.Message):
approves: bool = False) -> Self:
return cls.new(aputils.ObjectType.APPLICATION, {
'id': f'https://{host}/actor',
'preferredUsername': 'relay',
'name': 'ActivityRelay',
'summary': description or 'ActivityRelay bot',
'manuallyApprovesFollowers': approves,
'followers': f'https://{host}/followers',
'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox',
'outbox': f'https://{host}/outbox',
'url': f'https://{host}/',
'endpoints': {
'sharedInbox': f'https://{host}/inbox'
"id": f"https://{host}/actor",
"preferredUsername": "relay",
"name": "ActivityRelay",
"summary": description or "ActivityRelay bot",
"manuallyApprovesFollowers": approves,
"followers": f"https://{host}/followers",
"following": f"https://{host}/following",
"inbox": f"https://{host}/inbox",
"outbox": f"https://{host}/outbox",
"url": f"https://{host}/",
"endpoints": {
"sharedInbox": f"https://{host}/inbox"
},
'publicKey': {
'id': f'https://{host}/actor#main-key',
'owner': f'https://{host}/actor',
'publicKeyPem': pubkey
"publicKey": {
"id": f"https://{host}/actor#main-key",
"owner": f"https://{host}/actor",
"publicKeyPem": pubkey
}
})
@ -160,44 +160,44 @@ class Message(aputils.Message):
@classmethod
def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self:
return cls.new(aputils.ObjectType.ANNOUNCE, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [f'https://{host}/followers'],
'actor': f'https://{host}/actor',
'object': obj
"id": f"https://{host}/activities/{uuid4()}",
"to": [f"https://{host}/followers"],
"actor": f"https://{host}/actor",
"object": obj
})
@classmethod
def new_follow(cls: type[Self], host: str, actor: str) -> Self:
return cls.new(aputils.ObjectType.FOLLOW, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [actor],
'object': actor,
'actor': f'https://{host}/actor'
"id": f"https://{host}/activities/{uuid4()}",
"to": [actor],
"object": actor,
"actor": f"https://{host}/actor"
})
@classmethod
def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self:
return cls.new(aputils.ObjectType.UNDO, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [actor],
'actor': f'https://{host}/actor',
'object': follow
"id": f"https://{host}/activities/{uuid4()}",
"to": [actor],
"actor": f"https://{host}/actor",
"object": follow
})
@classmethod
def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self:
return cls.new(aputils.ObjectType.ACCEPT if accept else aputils.ObjectType.REJECT, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [actor],
'actor': f'https://{host}/actor',
'object': {
'id': followid,
'type': 'Follow',
'object': f'https://{host}/actor',
'actor': actor
"id": f"https://{host}/activities/{uuid4()}",
"to": [actor],
"actor": f"https://{host}/actor",
"object": {
"id": followid,
"type": "Follow",
"object": f"https://{host}/actor",
"actor": actor
}
})
@ -210,35 +210,35 @@ class Response(AiohttpResponse):
@classmethod
def new(cls: type[Self],
body: str | bytes | dict[str, Any] | Sequence[Any] = '',
body: str | bytes | dict[str, Any] | Sequence[Any] = "",
status: int = 200,
headers: dict[str, str] | None = None,
ctype: str = 'text') -> Self:
ctype: str = "text") -> Self:
kwargs: ResponseType = {
'status': status,
'headers': headers,
'content_type': MIMETYPES[ctype],
'body': None,
'text': None
"status": status,
"headers": headers,
"content_type": MIMETYPES[ctype],
"body": None,
"text": None
}
if isinstance(body, str):
kwargs['text'] = body
kwargs["text"] = body
elif isinstance(body, bytes):
kwargs['body'] = body
kwargs["body"] = body
elif isinstance(body, (dict, Sequence)):
kwargs['text'] = json.dumps(body, cls = JsonEncoder)
kwargs["text"] = json.dumps(body, cls = JsonEncoder)
return cls(**kwargs)
@classmethod
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, status, {'Location': path}, ctype = 'html')
body = f"Redirect to <a href=\"{path}\">{path}</a>"
return cls.new(body, status, {"Location": path}, ctype = "html")
@classmethod
@ -256,9 +256,9 @@ class Response(AiohttpResponse):
@property
def location(self) -> str:
return self.headers.get('Location', '')
return self.headers.get("Location", "")
@location.setter
def location(self, value: str) -> None:
self.headers['Location'] = value
self.headers["Location"] = value

View file

@ -12,11 +12,11 @@ if typing.TYPE_CHECKING:
def actor_type_check(actor: Message, software: str | None) -> bool:
if actor.type == 'Application':
if actor.type == "Application":
return True
# akkoma (< 3.6.0) and pleroma use Person for the actor type
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
if software in {"akkoma", "pleroma"} and actor.id == f"https://{actor.domain}/relay":
return True
return False
@ -24,38 +24,38 @@ def actor_type_check(actor: Message, software: str | None) -> bool:
async def handle_relay(app: Application, data: InboxData, conn: Connection) -> None:
try:
app.cache.get('handle-relay', data.message.object_id)
logging.verbose('already relayed %s', data.message.object_id)
app.cache.get("handle-relay", data.message.object_id)
logging.verbose("already relayed %s", data.message.object_id)
return
except KeyError:
pass
message = Message.new_announce(app.config.domain, data.message.object_id)
logging.debug('>> relay: %s', message)
logging.debug(">> relay: %s", message)
for instance in conn.distill_inboxes(data.message):
app.push_message(instance.inbox, message, instance)
app.cache.set('handle-relay', data.message.object_id, message.id, 'str')
app.cache.set("handle-relay", data.message.object_id, message.id, "str")
async def handle_forward(app: Application, data: InboxData, conn: Connection) -> None:
try:
app.cache.get('handle-relay', data.message.id)
logging.verbose('already forwarded %s', data.message.id)
app.cache.get("handle-relay", data.message.id)
logging.verbose("already forwarded %s", data.message.id)
return
except KeyError:
pass
message = Message.new_announce(app.config.domain, data.message)
logging.debug('>> forward: %s', message)
logging.debug(">> forward: %s", message)
for instance in conn.distill_inboxes(data.message):
app.push_message(instance.inbox, data.message, instance)
app.cache.set('handle-relay', data.message.id, message.id, 'str')
app.cache.set("handle-relay", data.message.id, message.id, "str")
async def handle_follow(app: Application, data: InboxData, conn: Connection) -> None:
@ -65,7 +65,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
# reject if software used by actor is banned
if software and conn.get_software_ban(software):
logging.verbose('Rejected banned actor: %s', data.actor.id)
logging.verbose("Rejected banned actor: %s", data.actor.id)
app.push_message(
data.actor.shared_inbox,
@ -79,7 +79,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
)
logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s',
"Rejected follow from actor for using specific software: actor=%s, software=%s",
data.actor.id,
software
)
@ -88,7 +88,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
# reject if the actor is not an instance actor
if actor_type_check(data.actor, software):
logging.verbose('Non-application actor tried to follow: %s', data.actor.id)
logging.verbose("Non-application actor tried to follow: %s", data.actor.id)
app.push_message(
data.actor.shared_inbox,
@ -106,7 +106,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
if not conn.get_domain_whitelist(data.actor.domain):
# add request if approval-required is enabled
if config.approval_required:
logging.verbose('New follow request fromm actor: %s', data.actor.id)
logging.verbose("New follow request fromm actor: %s", data.actor.id)
with conn.transaction():
data.instance = conn.put_inbox(
@ -120,9 +120,9 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
return
# reject if the actor isn't whitelisted while the whiltelist is enabled
# reject if the actor isn"t whitelisted while the whiltelist is enabled
if config.whitelist_enabled:
logging.verbose('Rejected actor for not being in the whitelist: %s', data.actor.id)
logging.verbose("Rejected actor for not being in the whitelist: %s", data.actor.id)
app.push_message(
data.actor.shared_inbox,
@ -160,7 +160,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
if software != "mastodon":
app.push_message(
data.actor.shared_inbox,
Message.new_follow(
@ -172,8 +172,8 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
async def handle_undo(app: Application, data: InboxData, conn: Connection) -> None:
if data.message.object['type'] != 'Follow':
# forwarding deletes does not work, so don't bother
if data.message.object["type"] != "Follow":
# forwarding deletes does not work, so don"t bother
# await handle_forward(app, data, conn)
return
@ -187,7 +187,7 @@ async def handle_undo(app: Application, data: InboxData, conn: Connection) -> No
with conn.transaction():
if not conn.del_inbox(data.actor.id):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
"Failed to delete \"%s\" with follow ID \"%s\"",
data.actor.id,
data.message.object_id
)
@ -204,19 +204,19 @@ async def handle_undo(app: Application, data: InboxData, conn: Connection) -> No
processors = {
'Announce': handle_relay,
'Create': handle_relay,
'Delete': handle_forward,
'Follow': handle_follow,
'Undo': handle_undo,
'Update': handle_forward,
"Announce": handle_relay,
"Create": handle_relay,
"Delete": handle_forward,
"Follow": handle_follow,
"Undo": handle_undo,
"Update": handle_forward,
}
async def run_processor(data: InboxData) -> None:
if data.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
"Message type \"%s\" from actor cannot be handled: %s",
data.message.type,
data.actor.id
)
@ -242,5 +242,5 @@ async def run_processor(data: InboxData) -> None:
actor = data.actor.id
)
logging.verbose('New "%s" from actor: %s', data.message.type, data.actor.id)
logging.verbose("New \"%s\" from actor: %s", data.message.type, data.actor.id)
await processors[data.message.type](app, data, conn)
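The processors mapping above is what run_processor dispatches on; unhandled activity types are logged and skipped. A hypothetical sketch of wiring in an additional type with the same handler signature:

async def handle_like(app: Application, data: InboxData, conn: Connection) -> None:
    # hypothetical handler: simply forward the activity, like Delete and Update do
    await handle_forward(app, data, conn)

processors["Like"] = handle_like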

View file

@ -34,8 +34,8 @@ class Template(Environment):
MarkdownExtension
],
loader = FileSystemLoader([
File.from_resource('relay', 'frontend'),
app.config.path.parent.joinpath('template')
File.from_resource("relay", "frontend"),
app.config.path.parent.joinpath("template")
])
)
@ -47,10 +47,10 @@ class Template(Environment):
config = conn.get_config_all()
new_context = {
'request': request,
'domain': self.app.config.domain,
'version': __version__,
'config': config,
"request": request,
"domain": self.app.config.domain,
"version": __version__,
"config": config,
**(context or {})
}
@ -58,11 +58,11 @@ class Template(Environment):
class MarkdownExtension(Extension):
tags = {'markdown'}
tags = {"markdown"}
extensions = (
'attr_list',
'smarty',
'tables'
"attr_list",
"smarty",
"tables"
)
@ -77,14 +77,14 @@ class MarkdownExtension(Extension):
def parse(self, parser: Parser) -> Node | list[Node]:
lineno = next(parser.stream).lineno
body = parser.parse_statements(
('name:endmarkdown',),
("name:endmarkdown",),
drop_needle = True
)
output = CallBlock(self.call_method('_render_markdown'), [], [], body)
output = CallBlock(self.call_method("_render_markdown"), [], [], body)
return output.set_lineno(lineno)
def _render_markdown(self, caller: Callable[[], str] | str) -> str:
text = caller if isinstance(caller, str) else caller()
return self._markdown.convert(textwrap.dedent(text.strip('\n')))
return self._markdown.convert(textwrap.dedent(text.strip("\n")))

View file

@ -44,61 +44,61 @@ class InboxData:
signer: Signer | None = None
try:
signature = Signature.parse(request.headers['signature'])
signature = Signature.parse(request.headers["signature"])
except KeyError:
logging.verbose('Missing signature header')
raise HttpError(400, 'missing signature header')
logging.verbose("Missing signature header")
raise HttpError(400, "missing signature header")
try:
message = await request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse message from actor: %s', signature.keyid)
raise HttpError(400, 'failed to parse message')
logging.verbose("Failed to parse message from actor: %s", signature.keyid)
raise HttpError(400, "failed to parse message")
if message is None:
logging.verbose('empty message')
raise HttpError(400, 'missing message')
logging.verbose("empty message")
raise HttpError(400, "missing message")
if 'actor' not in message:
logging.verbose('actor not in message')
raise HttpError(400, 'no actor in message')
if "actor" not in message:
logging.verbose("actor not in message")
raise HttpError(400, "no actor in message")
try:
actor = await app.client.get(signature.keyid, True, Message)
except HttpError as e:
# ld signatures aren't handled atm, so just ignore it
if message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '')
# ld signatures aren"t handled atm, so just ignore it
if message.type == "Delete":
logging.verbose("Instance sent a delete which cannot be handled")
raise HttpError(202, "")
logging.verbose('Failed to fetch actor: %s', signature.keyid)
logging.debug('HTTP Status %i: %s', e.status, e.message)
raise HttpError(400, 'failed to fetch actor')
logging.verbose("Failed to fetch actor: %s", signature.keyid)
logging.debug("HTTP Status %i: %s", e.status, e.message)
raise HttpError(400, "failed to fetch actor")
except ClientConnectorError as e:
logging.warning('Error when trying to fetch actor: %s, %s', signature.keyid, str(e))
raise HttpError(400, 'failed to fetch actor')
logging.warning("Error when trying to fetch actor: %s, %s", signature.keyid, str(e))
raise HttpError(400, "failed to fetch actor")
except Exception:
traceback.print_exc()
raise HttpError(500, 'unexpected error when fetching actor')
raise HttpError(500, "unexpected error when fetching actor")
try:
signer = actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', signature.keyid)
raise HttpError(400, 'actor missing public key')
logging.verbose("Actor missing public key: %s", signature.keyid)
raise HttpError(400, "actor missing public key")
try:
await signer.validate_request_async(request)
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', actor.id, e)
logging.verbose("signature validation failed for \"%s\": %s", actor.id, e)
raise HttpError(401, str(e))
return cls(signature, message, actor, signer, None)
@ -128,30 +128,30 @@ async def handle_inbox(app: Application, request: Request) -> Response:
# reject if actor is banned
if conn.get_domain_ban(data.actor.domain):
logging.verbose('Ignored request from banned actor: %s', data.actor.id)
raise HttpError(403, 'access denied')
logging.verbose("Ignored request from banned actor: %s", data.actor.id)
raise HttpError(403, "access denied")
# reject if activity type isn't 'Follow' and the actor isn't following
if data.message.type != 'Follow' and not data.instance:
# reject if activity type isn"t "Follow" and the actor isn"t following
if data.message.type != "Follow" and not data.instance:
logging.verbose(
'Rejected actor for trying to post while not following: %s',
"Rejected actor for trying to post while not following: %s",
data.actor.id
)
raise HttpError(401, 'access denied')
raise HttpError(401, "access denied")
logging.debug('>> payload %s', data.message.to_json(4))
logging.debug(">> payload %s", data.message.to_json(4))
await run_processor(data)
return Response.new(status = 202)
@register_route(HttpMethod.GET, '/outbox')
@register_route(HttpMethod.GET, "/outbox")
async def handle_outbox(app: Application, request: Request) -> Response:
msg = aputils.Message.new(
aputils.ObjectType.ORDERED_COLLECTION,
{
"id": f'https://{app.config.domain}/outbox',
"id": f"https://{app.config.domain}/outbox",
"totalItems": 0,
"orderedItems": []
}
@ -160,15 +160,15 @@ async def handle_outbox(app: Application, request: Request) -> Response:
return Response.new(msg, ctype = "activity")
@register_route(HttpMethod.GET, '/following', '/followers')
@register_route(HttpMethod.GET, "/following", "/followers")
async def handle_follow(app: Application, request: Request) -> Response:
with app.database.session(False) as s:
inboxes = [row['actor'] for row in s.get_inboxes()]
inboxes = [row["actor"] for row in s.get_inboxes()]
msg = aputils.Message.new(
aputils.ObjectType.COLLECTION,
{
"id": f'https://{app.config.domain}{request.path}',
"id": f"https://{app.config.domain}{request.path}",
"totalItems": len(inboxes),
"items": inboxes
}
@ -177,21 +177,21 @@ async def handle_follow(app: Application, request: Request) -> Response:
return Response.new(msg, ctype = "activity")
@register_route(HttpMethod.GET, '/.well-known/webfinger')
@register_route(HttpMethod.GET, "/.well-known/webfinger")
async def get(app: Application, request: Request) -> Response:
try:
subject = request.query['resource']
subject = request.query["resource"]
except KeyError:
raise HttpError(400, 'missing "resource" query key')
raise HttpError(400, "missing \"resource\" query key")
if subject != f'acct:relay@{app.config.domain}':
raise HttpError(404, 'user not found')
if subject != f"acct:relay@{app.config.domain}":
raise HttpError(404, "user not found")
data = aputils.Webfinger.new(
handle = 'relay',
handle = "relay",
domain = app.config.domain,
actor = app.config.actor
)
return Response.new(data, ctype = 'json')
return Response.new(data, ctype = "json")

View file

@ -343,7 +343,7 @@ async def handle_instance_add(
with app.database.session(False) as s:
if s.get_inbox(domain) is not None:
raise HttpError(404, 'Instance already in database')
raise HttpError(404, "Instance already in database")
if inbox is None:
try:
@ -396,7 +396,7 @@ async def handle_instance_update(
with app.database.session(False) as s:
if (instance := s.get_inbox(domain)) is None:
raise HttpError(404, 'Instance with domain not found')
raise HttpError(404, "Instance with domain not found")
row = s.put_inbox(
instance.domain,

View file

@ -31,11 +31,11 @@ if TYPE_CHECKING:
METHODS: dict[str, Method] = {}
ROUTES: list[tuple[str, str, HandlerCallback]] = []
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
DEFAULT_REDIRECT: str = "urn:ietf:wg:oauth:2.0:oob"
ALLOWED_HEADERS: set[str] = {
'accept',
'authorization',
'content-type'
"accept",
"authorization",
"content-type"
}
@ -100,14 +100,14 @@ class Method:
return_type = get_origin(return_type)
if not issubclass(return_type, (Response, ApiObject, list)):
raise ValueError(f"Invalid return type '{return_type.__name__}' for {func.__name__}")
raise ValueError(f"Invalid return type \"{return_type.__name__}\" for {func.__name__}")
args = {key: value for key, value in inspect.signature(func).parameters.items()}
docstring, paramdocs = parse_docstring(func.__doc__ or "")
params = []
if func.__doc__ is None:
logging.warning(f"Missing docstring for '{func.__name__}'")
logging.warning(f"Missing docstring for \"{func.__name__}\"")
for key, value in args.items():
types: list[type[Any]] = []
@ -134,7 +134,7 @@ class Method:
))
if not paramdocs.get(key):
logging.warning(f"Missing docs for '{key}' parameter in '{func.__name__}'")
logging.warning(f"Missing docs for \"{key}\" parameter in \"{func.__name__}\"")
rtype = annotations.get("return") or type(None)
return cls(func.__name__, category, docstring, method, path, rtype, tuple(params))
@ -222,7 +222,7 @@ class Route:
if request.method != "OPTIONS" and self.require_token:
if (auth := request.headers.getone("Authorization", None)) is None:
raise HttpError(401, 'Missing token')
raise HttpError(401, "Missing token")
try:
authtype, code = auth.split(" ", 1)
@ -245,15 +245,15 @@ class Route:
request["application"] = application
if request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
if request.content_type in {"application/x-www-form-urlencoded", "multipart/form-data"}:
post_data = {key: value for key, value in (await request.post()).items()}
elif request.content_type == 'application/json':
elif request.content_type == "application/json":
try:
post_data = await request.json()
except JSONDecodeError:
raise HttpError(400, 'Invalid JSON data')
raise HttpError(400, "Invalid JSON data")
else:
post_data = {key: str(value) for key, value in request.query.items()}
@ -262,7 +262,7 @@ class Route:
response = await self.handler(get_app(), request, **post_data)
except HttpError as error:
return Response.new({'error': error.message}, error.status, ctype = "json")
return Response.new({"error": error.message}, error.status, ctype = "json")
headers = {
"Access-Control-Allow-Origin": "*",

View file

@ -19,7 +19,7 @@ if TYPE_CHECKING:
async def handle_home(app: Application, request: Request) -> Response:
with app.database.session() as conn:
context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes())
"instances": tuple(conn.get_inboxes())
}
return Response.new_template(200, "page/home.haml", request, context)
@ -34,28 +34,28 @@ async def handle_api_doc(app: Application, request: Request) -> Response:
return Response.new_template(200, "page/docs.haml", request, context)
@register_route(HttpMethod.GET, '/login')
@register_route(HttpMethod.GET, "/login")
async def handle_login(app: Application, request: Request) -> Response:
context = {"redir": unquote(request.query.get("redir", "/"))}
return Response.new_template(200, "page/login.haml", request, context)
@register_route(HttpMethod.GET, '/logout')
@register_route(HttpMethod.GET, "/logout")
async def handle_logout(app: Application, request: Request) -> Response:
with app.database.session(True) as conn:
conn.del_app(request['token'].client_id, request['token'].client_secret)
conn.del_app(request["token"].client_id, request["token"].client_secret)
resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = app.config.domain, path = '/')
resp = Response.new_redir("/")
resp.del_cookie("user-token", domain = app.config.domain, path = "/")
return resp
@register_route(HttpMethod.GET, '/admin')
@register_route(HttpMethod.GET, "/admin")
async def handle_admin(app: Application, request: Request) -> Response:
return Response.new_redir(f'/login?redir={request.path}', 301)
return Response.new_redir(f"/login?redir={request.path}", 301)
@register_route(HttpMethod.GET, '/admin/instances')
@register_route(HttpMethod.GET, "/admin/instances")
async def handle_admin_instances(
app: Application,
request: Request,
@ -64,20 +64,20 @@ async def handle_admin_instances(
with app.database.session() as conn:
context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes()),
'requests': tuple(conn.get_requests())
"instances": tuple(conn.get_inboxes()),
"requests": tuple(conn.get_requests())
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-instances.haml", request, context)
@register_route(HttpMethod.GET, '/admin/whitelist')
@register_route(HttpMethod.GET, "/admin/whitelist")
async def handle_admin_whitelist(
app: Application,
request: Request,
@ -86,19 +86,19 @@ async def handle_admin_whitelist(
with app.database.session() as conn:
context: dict[str, Any] = {
'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC'))
"whitelist": tuple(conn.execute("SELECT * FROM whitelist ORDER BY domain ASC"))
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-whitelist.haml", request, context)
@register_route(HttpMethod.GET, '/admin/domain_bans')
@register_route(HttpMethod.GET, "/admin/domain_bans")
async def handle_admin_instance_bans(
app: Application,
request: Request,
@ -107,19 +107,19 @@ async def handle_admin_instance_bans(
with app.database.session() as conn:
context: dict[str, Any] = {
'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC'))
"bans": tuple(conn.execute("SELECT * FROM domain_bans ORDER BY domain ASC"))
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-domain_bans.haml", request, context)
@register_route(HttpMethod.GET, '/admin/software_bans')
@register_route(HttpMethod.GET, "/admin/software_bans")
async def handle_admin_software_bans(
app: Application,
request: Request,
@ -128,19 +128,19 @@ async def handle_admin_software_bans(
with app.database.session() as conn:
context: dict[str, Any] = {
'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC'))
"bans": tuple(conn.execute("SELECT * FROM software_bans ORDER BY name ASC"))
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-software_bans.haml", request, context)
@register_route(HttpMethod.GET, '/admin/users')
@register_route(HttpMethod.GET, "/admin/users")
async def handle_admin_users(
app: Application,
request: Request,
@ -149,29 +149,29 @@ async def handle_admin_users(
with app.database.session() as conn:
context: dict[str, Any] = {
'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC'))
"users": tuple(conn.execute("SELECT * FROM users ORDER BY username ASC"))
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-users.haml", request, context)
@register_route(HttpMethod.GET, '/admin/config')
@register_route(HttpMethod.GET, "/admin/config")
async def handle_admin_config(
app: Application,
request: Request,
message: str | None = None) -> Response:
context: dict[str, Any] = {
'themes': tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel),
'message': message,
'desc': {
"themes": tuple(THEMES.keys()),
"levels": tuple(level.name for level in LogLevel),
"message": message,
"desc": {
"name": "Name of the relay to be displayed in the header of the pages and in " +
"the actor endpoint.", # noqa: E131
"note": "Description of the relay to be displayed on the front page and as the " +
@ -187,36 +187,36 @@ async def handle_admin_config(
return Response.new_template(200, "page/admin-config.haml", request, context)
@register_route(HttpMethod.GET, '/manifest.json')
@register_route(HttpMethod.GET, "/manifest.json")
async def handle_manifest(app: Application, request: Request) -> Response:
with app.database.session(False) as conn:
config = conn.get_config_all()
theme = THEMES[config.theme]
data = {
'background_color': theme['background'],
'categories': ['activitypub'],
'description': 'Message relay for the ActivityPub network',
'display': 'standalone',
'name': config['name'],
'orientation': 'portrait',
'scope': f"https://{app.config.domain}/",
'short_name': 'ActivityRelay',
'start_url': f"https://{app.config.domain}/",
'theme_color': theme['primary']
"background_color": theme["background"],
"categories": ["activitypub"],
"description": "Message relay for the ActivityPub network",
"display": "standalone",
"name": config["name"],
"orientation": "portrait",
"scope": f"https://{app.config.domain}/",
"short_name": "ActivityRelay",
"start_url": f"https://{app.config.domain}/",
"theme_color": theme["primary"]
}
return Response.new(data, ctype = 'webmanifest')
return Response.new(data, ctype = "webmanifest")
@register_route(HttpMethod.GET, '/theme/{theme}.css') # type: ignore[arg-type]
@register_route(HttpMethod.GET, "/theme/{theme}.css") # type: ignore[arg-type]
async def handle_theme(app: Application, request: Request, theme: str) -> Response:
try:
context: dict[str, Any] = {
'theme': THEMES[theme]
"theme": THEMES[theme]
}
except KeyError:
return Response.new('Invalid theme', 404)
return Response.new("Invalid theme", 404)
return Response.new_template(200, "variables.css", request, context, ctype = "css")

View file

@ -21,16 +21,18 @@ VERSION = __version__
if File(__file__).join("../../../.git").resolve().exists:
try:
commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii')
VERSION = f'{__version__} {commit_label}'
commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("ascii")
VERSION = f"{__version__} {commit_label}"
del commit_label
except Exception:
pass
NODEINFO_PATHS = [
'/nodeinfo/{niversion:\\d.\\d}.json',
'/nodeinfo/{niversion:\\d.\\d}'
"/nodeinfo/{niversion:\\d.\\d}.json",
"/nodeinfo/{niversion:\\d.\\d}"
]
@ -40,23 +42,23 @@ async def handle_nodeinfo(app: Application, request: Request, niversion: str) ->
inboxes = conn.get_inboxes()
nodeinfo = aputils.Nodeinfo.new(
name = 'activityrelay',
name = "activityrelay",
version = VERSION,
protocols = ['activitypub'],
open_regs = not conn.get_config('whitelist-enabled'),
protocols = ["activitypub"],
open_regs = not conn.get_config("whitelist-enabled"),
users = 1,
repo = 'https://git.pleroma.social/pleroma/relay' if niversion == '2.1' else None,
repo = "https://codeberg.org/barkshark/activityrelay" if niversion == "2.1" else None,
metadata = {
'approval_required': conn.get_config('approval-required'),
'peers': [inbox['domain'] for inbox in inboxes]
"approval_required": conn.get_config("approval-required"),
"peers": [inbox["domain"] for inbox in inboxes]
}
)
return Response.new(nodeinfo, ctype = 'json')
return Response.new(nodeinfo, ctype = "json")
@register_route(HttpMethod.GET, '/.well-known/nodeinfo')
@register_route(HttpMethod.GET, "/.well-known/nodeinfo")
async def handle_wk_nodeinfo(app: Application, request: Request) -> Response:
data = aputils.WellKnownNodeinfo.new_template(app.config.domain)
return Response.new(data, ctype = 'json')
return Response.new(data, ctype = "json")

View file

@ -96,16 +96,16 @@ class PushWorker(Process):
await self.client.post(item.inbox, item.message, item.instance)
except HttpError as e:
logging.error('HTTP Error when pushing to %s: %i %s', item.inbox, e.status, e.message)
logging.error("HTTP Error when pushing to %s: %i %s", item.inbox, e.status, e.message)
except AsyncTimeoutError:
logging.error('Timeout when pushing to %s', item.domain)
logging.error("Timeout when pushing to %s", item.domain)
except ClientConnectionError as e:
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
logging.error("Failed to connect to %s for message push: %s", item.domain, str(e))
except ClientSSLError as e:
logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
logging.error("SSL error when pushing to %s: %s", item.domain, str(e))
class PushWorkers(list[PushWorker]):