diff --git a/dev.py b/dev.py index 114073f..5ef6f53 100755 --- a/dev.py +++ b/dev.py @@ -37,32 +37,32 @@ from watchdog.events import FileSystemEvent, PatternMatchingEventHandler REPO = Path(__file__).parent IGNORE_EXT = { - '.py', - '.pyc' + ".py", + ".pyc" } -@click.group('cli') +@click.group("cli") def cli() -> None: - 'Useful commands for development' + "Useful commands for development" -@cli.command('install') -@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies') +@cli.command("install") +@click.option("--no-dev", "-d", is_flag = True, help = "Do not install development dependencies") def cli_install(no_dev: bool) -> None: - with open('pyproject.toml', 'r', encoding = 'utf-8') as fd: + with open("pyproject.toml", "r", encoding = "utf-8") as fd: data = tomllib.loads(fd.read()) - deps = data['project']['dependencies'] - deps.extend(data['project']['optional-dependencies']['dev']) + deps = data["project"]["dependencies"] + deps.extend(data["project"]["optional-dependencies"]["dev"]) - subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False) + subprocess.run([sys.executable, "-m", "pip", "install", "-U", *deps], check = False) -@cli.command('lint') -@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay')) -@click.option('--watch', '-w', is_flag = True, - help = 'Automatically, re-run the linters on source change') +@cli.command("lint") +@click.argument("path", required = False, type = Path, default = REPO.joinpath("relay")) +@click.option("--watch", "-w", is_flag = True, + help = "Automatically, re-run the linters on source change") def cli_lint(path: Path, watch: bool) -> None: path = path.expanduser().resolve() @@ -70,74 +70,73 @@ def cli_lint(path: Path, watch: bool) -> None: handle_run_watcher([sys.executable, "dev.py", "lint", str(path)], wait = True) return - flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] - mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)] + flake8 = [sys.executable, "-m", "flake8", "dev.py", str(path)] + mypy = [sys.executable, "-m", "mypy", "--python-version", "3.12", "dev.py", str(path)] - click.echo('----- flake8 -----') + click.echo("----- flake8 -----") subprocess.run(flake8) - click.echo('\n\n----- mypy -----') + click.echo("\n\n----- mypy -----") subprocess.run(mypy) -@cli.command('clean') +@cli.command("clean") def cli_clean() -> None: dirs = { - 'dist', - 'build', - 'dist-pypi' + "dist", + "build", + "dist-pypi" } for directory in dirs: shutil.rmtree(directory, ignore_errors = True) - for path in REPO.glob('*.egg-info'): + for path in REPO.glob("*.egg-info"): shutil.rmtree(path) - for path in REPO.glob('*.spec'): + for path in REPO.glob("*.spec"): path.unlink() -@cli.command('build') +@cli.command("build") def cli_build() -> None: from relay import __version__ with TemporaryDirectory() as tmp: - arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' + arch = "amd64" if sys.maxsize >= 2**32 else "i386" cmd = [ - sys.executable, '-m', 'PyInstaller', - '--collect-data', 'relay', - '--collect-data', 'aiohttp_swagger', - '--hidden-import', 'pg8000', - '--hidden-import', 'sqlite3', - '--name', f'activityrelay-{__version__}-{platform.system().lower()}-{arch}', - '--workpath', tmp, - '--onefile', 'relay/__main__.py', + sys.executable, "-m", "PyInstaller", + "--collect-data", "relay", + "--hidden-import", "pg8000", + "--hidden-import", "sqlite3", + "--name", f"activityrelay-{__version__}-{platform.system().lower()}-{arch}", + "--workpath", tmp, + "--onefile", "relay/__main__.py", ] - if platform.system() == 'Windows': - cmd.append('--console') + if platform.system() == "Windows": + cmd.append("--console") # putting the spec path on a different drive than the source dir breaks if str(REPO)[0] == tmp[0]: - cmd.extend(['--specpath', tmp]) + cmd.extend(["--specpath", tmp]) else: - cmd.append('--strip') - cmd.extend(['--specpath', tmp]) + cmd.append("--strip") + cmd.extend(["--specpath", tmp]) subprocess.run(cmd, check = False) -@cli.command('run') -@click.option('--dev', '-d', is_flag = True) +@cli.command("run") +@click.option("--dev", "-d", is_flag = True) def cli_run(dev: bool) -> None: - print('Starting process watcher') + print("Starting process watcher") - cmd = [sys.executable, '-m', 'relay', 'run'] + cmd = [sys.executable, "-m", "relay", "run"] if dev: - cmd.append('-d') + cmd.append("-d") handle_run_watcher(cmd, watch_path = REPO.joinpath("relay")) @@ -167,7 +166,7 @@ def handle_run_watcher( class WatchHandler(PatternMatchingEventHandler): - patterns = ['*.py'] + patterns = ["*.py"] def __init__(self, *commands: Sequence[str], wait: bool = False) -> None: @@ -184,7 +183,7 @@ class WatchHandler(PatternMatchingEventHandler): if proc.poll() is not None: continue - print(f'Terminating process {proc.pid}') + print(f"Terminating process {proc.pid}") proc.terminate() sec = 0.0 @@ -193,11 +192,11 @@ class WatchHandler(PatternMatchingEventHandler): sec += 0.1 if sec >= 5: - print('Failed to terminate. Killing process...') + print("Failed to terminate. Killing process...") proc.kill() break - print('Process terminated') + print("Process terminated") def run_procs(self, restart: bool = False) -> None: @@ -213,21 +212,21 @@ class WatchHandler(PatternMatchingEventHandler): self.procs = [] for cmd in self.commands: - print('Running command:', ' '.join(cmd)) + print("Running command:", " ".join(cmd)) subprocess.run(cmd) else: self.procs = list(subprocess.Popen(cmd) for cmd in self.commands) pids = (str(proc.pid) for proc in self.procs) - print('Started processes with PIDs:', ', '.join(pids)) + print("Started processes with PIDs:", ", ".join(pids)) def on_any_event(self, event: FileSystemEvent) -> None: - if event.event_type not in ['modified', 'created', 'deleted']: + if event.event_type not in ["modified", "created", "deleted"]: return self.run_procs(restart = True) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/relay/__init__.py b/relay/__init__.py index 80eb7f9..e19434e 100644 --- a/relay/__init__.py +++ b/relay/__init__.py @@ -1 +1 @@ -__version__ = '0.3.3' +__version__ = "0.3.3" diff --git a/relay/__main__.py b/relay/__main__.py index d3d7c18..d4146b5 100644 --- a/relay/__main__.py +++ b/relay/__main__.py @@ -1,8 +1,5 @@ -import multiprocessing - from relay.manage import main -if __name__ == '__main__': - multiprocessing.freeze_support() +if __name__ == "__main__": main() diff --git a/relay/application.py b/relay/application.py index 6180792..7082c74 100644 --- a/relay/application.py +++ b/relay/application.py @@ -41,10 +41,10 @@ def get_csp(request: web.Request) -> str: "img-src 'self'", "object-src 'none'", "frame-ancestors 'none'", - f"manifest-src 'self' https://{request.app['config'].domain}" + f"manifest-src 'self' https://{request.app["config"].domain}" ] - return '; '.join(data) + ';' + return "; ".join(data) + ";" class Application(web.Application): @@ -61,20 +61,20 @@ class Application(web.Application): Application.DEFAULT = self - self['running'] = False - self['signer'] = None - self['start_time'] = None - self['cleanup_thread'] = None - self['dev'] = dev + self["running"] = False + self["signer"] = None + self["start_time"] = None + self["cleanup_thread"] = None + self["dev"] = dev - self['config'] = Config(cfgpath, load = True) - self['database'] = get_database(self.config) - self['client'] = HttpClient() - self['cache'] = get_cache(self) - self['cache'].setup() - self['template'] = Template(self) - self['push_queue'] = multiprocessing.Queue() - self['workers'] = PushWorkers(self.config.workers) + self["config"] = Config(cfgpath, load = True) + self["database"] = get_database(self.config) + self["client"] = HttpClient() + self["cache"] = get_cache(self) + self["cache"].setup() + self["template"] = Template(self) + self["push_queue"] = multiprocessing.Queue() + self["workers"] = PushWorkers(self.config.workers) self.cache.setup() self.on_cleanup.append(handle_cleanup) # type: ignore @@ -82,69 +82,69 @@ class Application(web.Application): @property def cache(self) -> Cache: - return cast(Cache, self['cache']) + return cast(Cache, self["cache"]) @property def client(self) -> HttpClient: - return cast(HttpClient, self['client']) + return cast(HttpClient, self["client"]) @property def config(self) -> Config: - return cast(Config, self['config']) + return cast(Config, self["config"]) @property def database(self) -> Database[Connection]: - return cast(Database[Connection], self['database']) + return cast(Database[Connection], self["database"]) @property def signer(self) -> Signer: - return cast(Signer, self['signer']) + return cast(Signer, self["signer"]) @signer.setter def signer(self, value: Signer | str) -> None: if isinstance(value, Signer): - self['signer'] = value + self["signer"] = value return - self['signer'] = Signer(value, self.config.keyid) + self["signer"] = Signer(value, self.config.keyid) @property def template(self) -> Template: - return cast(Template, self['template']) + return cast(Template, self["template"]) @property def uptime(self) -> timedelta: - if not self['start_time']: + if not self["start_time"]: return timedelta(seconds=0) - uptime = datetime.now() - self['start_time'] + uptime = datetime.now() - self["start_time"] return timedelta(seconds=uptime.seconds) @property def workers(self) -> PushWorkers: - return cast(PushWorkers, self['workers']) + return cast(PushWorkers, self["workers"]) def push_message(self, inbox: str, message: Message, instance: Instance) -> None: - self['workers'].push_message(inbox, message, instance) + self["workers"].push_message(inbox, message, instance) def register_static_routes(self) -> None: - if self['dev']: - static = StaticResource('/static', File.from_resource('relay', 'frontend/static')) + if self["dev"]: + static = StaticResource("/static", File.from_resource("relay", "frontend/static")) else: static = CachedStaticResource( - '/static', Path(File.from_resource('relay', 'frontend/static')) + "/static", Path(File.from_resource("relay", "frontend/static")) ) self.router.register_resource(static) @@ -158,18 +158,18 @@ class Application(web.Application): host = self.config.listen port = self.config.port - if port_check(port, '127.0.0.1' if host == '0.0.0.0' else host): - logging.error(f'A server is already running on {host}:{port}') + if port_check(port, "127.0.0.1" if host == "0.0.0.0" else host): + logging.error(f"A server is already running on {host}:{port}") return self.register_static_routes() - logging.info(f'Starting webserver at {domain} ({host}:{port})') + logging.info(f"Starting webserver at {domain} ({host}:{port})") asyncio.run(self.handle_run()) def set_signal_handler(self, startup: bool) -> None: - for sig in ('SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'): + for sig in ("SIGHUP", "SIGINT", "SIGQUIT", "SIGTERM"): try: signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL) @@ -179,22 +179,25 @@ class Application(web.Application): def stop(self, *_: Any) -> None: - self['running'] = False + self["running"] = False async def handle_run(self) -> None: - self['running'] = True + self["running"] = True self.set_signal_handler(True) - self['client'].open() - self['database'].connect() - self['cache'].setup() - self['cleanup_thread'] = CacheCleanupThread(self) - self['cleanup_thread'].start() - self['workers'].start() + self["client"].open() + self["database"].connect() + self["cache"].setup() + self["cleanup_thread"] = CacheCleanupThread(self) + self["cleanup_thread"].start() + self["workers"].start() + + runner = web.AppRunner( + self, access_log_format = "%{X-Forwarded-For}i \"%r\" %s %b \"%{User-Agent}i\"" + ) - runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() site = web.TCPSite( @@ -205,22 +208,22 @@ class Application(web.Application): ) await site.start() - self['starttime'] = datetime.now() + self["starttime"] = datetime.now() - while self['running']: + while self["running"]: await asyncio.sleep(0.25) await site.stop() - self['workers'].stop() + self["workers"].stop() self.set_signal_handler(False) - self['starttime'] = None - self['running'] = False - self['cleanup_thread'].stop() - self['database'].disconnect() - self['cache'].close() + self["starttime"] = None + self["running"] = False + self["cleanup_thread"].stop() + self["database"].disconnect() + self["cache"].close() class CachedStaticResource(StaticResource): @@ -229,19 +232,19 @@ class CachedStaticResource(StaticResource): self.cache: dict[str, bytes] = {} - for filename in path.rglob('*'): + for filename in path.rglob("*"): if filename.is_dir(): continue rel_path = str(filename.relative_to(path)) - with filename.open('rb') as fd: - logging.debug('Loading static resource "%s"', rel_path) + with filename.open("rb") as fd: + logging.debug("Loading static resource \"%s\"", rel_path) self.cache[rel_path] = fd.read() async def _handle(self, request: web.Request) -> web.StreamResponse: - rel_url = request.match_info['filename'] + rel_url = request.match_info["filename"] if Path(rel_url).anchor: raise web.HTTPForbidden() @@ -281,8 +284,8 @@ class CacheCleanupThread(Thread): def format_error(request: web.Request, error: HttpError) -> Response: - if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''): - return Response.new({'error': error.message}, error.status, ctype = 'json') + if request.path.startswith(JSON_PATHS) or "json" in request.headers.get("accept", ""): + return Response.new({"error": error.message}, error.status, ctype = "json") else: context = {"e": error} @@ -294,27 +297,27 @@ async def handle_response_headers( request: web.Request, handler: Callable[[web.Request], Awaitable[Response]]) -> Response: - request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') - request['token'] = None - request['user'] = None + request["hash"] = b64encode(get_random_bytes(16)).decode("ascii") + request["token"] = None + request["user"] = None app: Application = request.app # type: ignore[assignment] if request.path in {"/", "/docs"} or request.path.startswith(TOKEN_PATHS): with app.database.session() as conn: tokens = ( - request.headers.get('Authorization', '').replace('Bearer', '').strip(), - request.cookies.get('user-token') + request.headers.get("Authorization", "").replace("Bearer", "").strip(), + request.cookies.get("user-token") ) for token in tokens: if not token: continue - request['token'] = conn.get_app_by_token(token) + request["token"] = conn.get_app_by_token(token) - if request['token'] is not None: - request['user'] = conn.get_user(request['token'].user) + if request["token"] is not None: + request["user"] = conn.get_user(request["token"].user) break @@ -338,21 +341,21 @@ async def handle_response_headers( raise except Exception: - resp = format_error(request, HttpError(500, 'Internal server error')) + resp = format_error(request, HttpError(500, "Internal server error")) traceback.print_exc() - resp.headers['Server'] = 'ActivityRelay' + resp.headers["Server"] = "ActivityRelay" # Still have to figure out how csp headers work - if resp.content_type == 'text/html' and not request.path.startswith("/api"): - resp.headers['Content-Security-Policy'] = get_csp(request) + if resp.content_type == "text/html" and not request.path.startswith("/api"): + resp.headers["Content-Security-Policy"] = get_csp(request) - if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')): + if not request.app["dev"] and request.path.endswith((".css", ".js", ".woff2")): # cache for 2 weeks - resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' + resp.headers["Cache-Control"] = "public,max-age=1209600,immutable" else: - resp.headers['Cache-Control'] = 'no-store' + resp.headers["Cache-Control"] = "no-store" return resp @@ -362,25 +365,25 @@ async def handle_frontend_path( request: web.Request, handler: Callable[[web.Request], Awaitable[Response]]) -> Response: - if request['user'] is not None and request.path == '/login': - return Response.new_redir('/') + if request["user"] is not None and request.path == "/login": + return Response.new_redir("/") - if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None: - if request.path == '/logout': - return Response.new_redir('/') + if request.path.startswith(TOKEN_PATHS[:2]) and request["user"] is None: + if request.path == "/logout": + return Response.new_redir("/") - response = Response.new_redir(f'/login?redir={request.path}') + response = Response.new_redir(f"/login?redir={request.path}") - if request['token'] is not None: - response.del_cookie('user-token') + if request["token"] is not None: + response.del_cookie("user-token") return response response = await handler(request) - if not request.path.startswith('/api'): - if request['user'] is None and request['token'] is not None: - response.del_cookie('user-token') + if not request.path.startswith("/api"): + if request["user"] is None and request["token"] is not None: + response.del_cookie("user-token") return response diff --git a/relay/cache.py b/relay/cache.py index 4a6b7dd..5b3627a 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -24,11 +24,11 @@ DeserializerCallback = Callable[[str], Any] BACKENDS: dict[str, type[Cache]] = {} CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = { - 'str': (str, str), - 'int': (str, int), - 'bool': (str, convert_to_boolean), - 'json': (json.dumps, json.loads), - 'message': (lambda x: x.to_json(), Message.parse) + "str": (str, str), + "int": (str, int), + "bool": (str, convert_to_boolean), + "json": (json.dumps, json.loads), + "message": (lambda x: x.to_json(), Message.parse) } @@ -49,14 +49,14 @@ def register_cache(backend: type[Cache]) -> type[Cache]: return backend -def serialize_value(value: Any, value_type: str = 'str') -> str: +def serialize_value(value: Any, value_type: str = "str") -> str: if isinstance(value, str): return value return CONVERTERS[value_type][0](value) -def deserialize_value(value: str, value_type: str = 'str') -> Any: +def deserialize_value(value: str, value_type: str = "str") -> Any: return CONVERTERS[value_type][1](value) @@ -93,7 +93,7 @@ class Item: class Cache(ABC): - name: str = 'null' + name: str def __init__(self, app: Application): @@ -116,7 +116,7 @@ class Cache(ABC): @abstractmethod - def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: + def set(self, namespace: str, key: str, value: Any, value_type: str = "key") -> Item: ... @@ -160,7 +160,7 @@ class Cache(ABC): @register_cache class SqlCache(Cache): - name: str = 'database' + name: str = "database" def __init__(self, app: Application): @@ -173,16 +173,16 @@ class SqlCache(Cache): raise RuntimeError("Database has not been setup") params = { - 'namespace': namespace, - 'key': key + "namespace": namespace, + "key": key } with self._db.session(False) as conn: - with conn.run('get-cache-item', params) as cur: + with conn.run("get-cache-item", params) as cur: if not (row := cur.one(Row)): - raise KeyError(f'{namespace}:{key}') + raise KeyError(f"{namespace}:{key}") - row.pop('id', None) + row.pop("id", None) return Item.from_data(*tuple(row.values())) @@ -191,8 +191,8 @@ class SqlCache(Cache): raise RuntimeError("Database has not been setup") with self._db.session(False) as conn: - for row in conn.run('get-cache-keys', {'namespace': namespace}): - yield row['key'] + for row in conn.run("get-cache-keys", {"namespace": namespace}): + yield row["key"] def get_namespaces(self) -> Iterator[str]: @@ -200,28 +200,28 @@ class SqlCache(Cache): raise RuntimeError("Database has not been setup") with self._db.session(False) as conn: - for row in conn.run('get-cache-namespaces', None): - yield row['namespace'] + for row in conn.run("get-cache-namespaces", None): + yield row["namespace"] - def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item: + def set(self, namespace: str, key: str, value: Any, value_type: str = "str") -> Item: if self._db is None: raise RuntimeError("Database has not been setup") params = { - 'namespace': namespace, - 'key': key, - 'value': serialize_value(value, value_type), - 'type': value_type, - 'date': Date.new_utc() + "namespace": namespace, + "key": key, + "value": serialize_value(value, value_type), + "type": value_type, + "date": Date.new_utc() } with self._db.session(True) as conn: - with conn.run('set-cache-item', params) as cur: + with conn.run("set-cache-item", params) as cur: if (row := cur.one(Row)) is None: raise RuntimeError("Cache item not set") - row.pop('id', None) + row.pop("id", None) return Item.from_data(*tuple(row.values())) @@ -230,12 +230,12 @@ class SqlCache(Cache): raise RuntimeError("Database has not been setup") params = { - 'namespace': namespace, - 'key': key + "namespace": namespace, + "key": key } with self._db.session(True) as conn: - with conn.run('del-cache-item', params): + with conn.run("del-cache-item", params): pass @@ -267,7 +267,7 @@ class SqlCache(Cache): self._db.connect() with self._db.session(True) as conn: - with conn.run(f'create-cache-table-{self._db.backend_type.value}', None): + with conn.run(f"create-cache-table-{self._db.backend_type.value}", None): pass @@ -281,7 +281,7 @@ class SqlCache(Cache): @register_cache class RedisCache(Cache): - name: str = 'redis' + name: str = "redis" def __init__(self, app: Application): @@ -295,7 +295,7 @@ class RedisCache(Cache): def get_key_name(self, namespace: str, key: str) -> str: - return f'{self.prefix}:{namespace}:{key}' + return f"{self.prefix}:{namespace}:{key}" def get(self, namespace: str, key: str) -> Item: @@ -305,9 +305,9 @@ class RedisCache(Cache): key_name = self.get_key_name(namespace, key) if not (raw_value := self._rd.get(key_name)): - raise KeyError(f'{namespace}:{key}') + raise KeyError(f"{namespace}:{key}") - value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr] + value_type, updated, value = raw_value.split(":", 2) # type: ignore[union-attr] return Item.from_data( namespace, @@ -322,8 +322,8 @@ class RedisCache(Cache): if self._rd is None: raise ConnectionError("Not connected") - for key in self._rd.scan_iter(self.get_key_name(namespace, '*')): - *_, key_name = key.split(':', 2) + for key in self._rd.scan_iter(self.get_key_name(namespace, "*")): + *_, key_name = key.split(":", 2) yield key_name @@ -333,15 +333,15 @@ class RedisCache(Cache): namespaces = [] - for key in self._rd.scan_iter(f'{self.prefix}:*'): - _, namespace, _ = key.split(':', 2) + for key in self._rd.scan_iter(f"{self.prefix}:*"): + _, namespace, _ = key.split(":", 2) if namespace not in namespaces: namespaces.append(namespace) yield namespace - def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: + def set(self, namespace: str, key: str, value: Any, value_type: str = "key") -> Item: if self._rd is None: raise ConnectionError("Not connected") @@ -350,7 +350,7 @@ class RedisCache(Cache): self._rd.set( self.get_key_name(namespace, key), - f'{value_type}:{date}:{value}' + f"{value_type}:{date}:{value}" ) return self.get(namespace, key) @@ -369,8 +369,8 @@ class RedisCache(Cache): limit = Date.new_utc() - timedelta(days = days) - for full_key in self._rd.scan_iter(f'{self.prefix}:*'): - _, namespace, key = full_key.split(':', 2) + for full_key in self._rd.scan_iter(f"{self.prefix}:*"): + _, namespace, key = full_key.split(":", 2) item = self.get(namespace, key) if item.updated < limit: @@ -389,11 +389,11 @@ class RedisCache(Cache): return options: RedisConnectType = { - 'client_name': f'ActivityRelay_{self.app.config.domain}', - 'decode_responses': True, - 'username': self.app.config.rd_user, - 'password': self.app.config.rd_pass, - 'db': self.app.config.rd_database + "client_name": f"ActivityRelay_{self.app.config.domain}", + "decode_responses": True, + "username": self.app.config.rd_user, + "password": self.app.config.rd_pass, + "db": self.app.config.rd_database } if os.path.exists(self.app.config.rd_host): diff --git a/relay/compat.py b/relay/compat.py index 1f81296..ef0ebd8 100644 --- a/relay/compat.py +++ b/relay/compat.py @@ -14,21 +14,21 @@ class RelayConfig(dict[str, Any]): dict.__init__(self, {}) if self.is_docker: - path = '/data/config.yaml' + path = "/data/config.yaml" self._path = Path(path).expanduser().resolve() self.reset() def __setitem__(self, key: str, value: Any) -> None: - if key in {'blocked_instances', 'blocked_software', 'whitelist'}: + if key in {"blocked_instances", "blocked_software", "whitelist"}: assert isinstance(value, (list, set, tuple)) - elif key in {'port', 'workers', 'json_cache', 'timeout'}: + elif key in {"port", "workers", "json_cache", "timeout"}: if not isinstance(value, int): value = int(value) - elif key == 'whitelist_enabled': + elif key == "whitelist_enabled": if not isinstance(value, bool): value = convert_to_boolean(value) @@ -37,45 +37,45 @@ class RelayConfig(dict[str, Any]): @property def db(self) -> Path: - return Path(self['db']).expanduser().resolve() + return Path(self["db"]).expanduser().resolve() @property def actor(self) -> str: - return f'https://{self["host"]}/actor' + return f"https://{self['host']}/actor" @property def inbox(self) -> str: - return f'https://{self["host"]}/inbox' + return f"https://{self['host']}/inbox" @property def keyid(self) -> str: - return f'{self.actor}#main-key' + return f"{self.actor}#main-key" @cached_property def is_docker(self) -> bool: - return bool(os.environ.get('DOCKER_RUNNING')) + return bool(os.environ.get("DOCKER_RUNNING")) def reset(self) -> None: self.clear() self.update({ - 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')), - 'listen': '0.0.0.0', - 'port': 8080, - 'note': 'Make a note about your instance here.', - 'push_limit': 512, - 'json_cache': 1024, - 'timeout': 10, - 'workers': 0, - 'host': 'relay.example.com', - 'whitelist_enabled': False, - 'blocked_software': [], - 'blocked_instances': [], - 'whitelist': [] + "db": str(self._path.parent.joinpath(f"{self._path.stem}.jsonld")), + "listen": "0.0.0.0", + "port": 8080, + "note": "Make a note about your instance here.", + "push_limit": 512, + "json_cache": 1024, + "timeout": 10, + "workers": 0, + "host": "relay.example.com", + "whitelist_enabled": False, + "blocked_software": [], + "blocked_instances": [], + "whitelist": [] }) @@ -85,13 +85,13 @@ class RelayConfig(dict[str, Any]): options = {} try: - options['Loader'] = yaml.FullLoader + options["Loader"] = yaml.FullLoader except AttributeError: pass try: - with self._path.open('r', encoding = 'UTF-8') as fd: + with self._path.open("r", encoding = "UTF-8") as fd: config = yaml.load(fd, **options) except FileNotFoundError: @@ -101,7 +101,7 @@ class RelayConfig(dict[str, Any]): return for key, value in config.items(): - if key == 'ap': + if key == "ap": for k, v in value.items(): if k not in self: continue @@ -119,10 +119,10 @@ class RelayConfig(dict[str, Any]): class RelayDatabase(dict[str, Any]): def __init__(self, config: RelayConfig): dict.__init__(self, { - 'relay-list': {}, - 'private-key': None, - 'follow-requests': {}, - 'version': 1 + "relay-list": {}, + "private-key": None, + "follow-requests": {}, + "version": 1 }) self.config = config @@ -131,12 +131,12 @@ class RelayDatabase(dict[str, Any]): @property def hostnames(self) -> tuple[str]: - return tuple(self['relay-list'].keys()) + return tuple(self["relay-list"].keys()) @property def inboxes(self) -> tuple[dict[str, str]]: - return tuple(data['inbox'] for data in self['relay-list'].values()) + return tuple(data["inbox"] for data in self["relay-list"].values()) def load(self) -> None: @@ -144,29 +144,29 @@ class RelayDatabase(dict[str, Any]): with self.config.db.open() as fd: data = json.load(fd) - self['version'] = data.get('version', None) - self['private-key'] = data.get('private-key') + self["version"] = data.get("version", None) + self["private-key"] = data.get("private-key") - if self['version'] is None: - self['version'] = 1 + if self["version"] is None: + self["version"] = 1 - if 'actorKeys' in data: - self['private-key'] = data['actorKeys']['privateKey'] + if "actorKeys" in data: + self["private-key"] = data["actorKeys"]["privateKey"] - for item in data.get('relay-list', []): + for item in data.get("relay-list", []): domain = urlparse(item).hostname - self['relay-list'][domain] = { - 'domain': domain, - 'inbox': item, - 'followid': None + self["relay-list"][domain] = { + "domain": domain, + "inbox": item, + "followid": None } else: - self['relay-list'] = data.get('relay-list', {}) + self["relay-list"] = data.get("relay-list", {}) - for domain, instance in self['relay-list'].items(): - if not instance.get('domain'): - instance['domain'] = domain + for domain, instance in self["relay-list"].items(): + if not instance.get("domain"): + instance["domain"] = domain except FileNotFoundError: pass diff --git a/relay/config.py b/relay/config.py index e40cf70..6f31e9a 100644 --- a/relay/config.py +++ b/relay/config.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from typing import Self -if platform.system() == 'Windows': +if platform.system() == "Windows": import multiprocessing CORE_COUNT = multiprocessing.cpu_count() @@ -25,9 +25,9 @@ else: DOCKER_VALUES = { - 'listen': '0.0.0.0', - 'port': 8080, - 'sq_path': '/data/relay.sqlite3' + "listen": "0.0.0.0", + "port": 8080, + "sq_path": "/data/relay.sqlite3" } @@ -37,26 +37,26 @@ class NOVALUE: @dataclass(init = False) class Config: - listen: str = '0.0.0.0' + listen: str = "0.0.0.0" port: int = 8080 - domain: str = 'relay.example.com' + domain: str = "relay.example.com" workers: int = CORE_COUNT - db_type: str = 'sqlite' - ca_type: str = 'database' - sq_path: str = 'relay.sqlite3' + db_type: str = "sqlite" + ca_type: str = "database" + sq_path: str = "relay.sqlite3" - pg_host: str = '/var/run/postgresql' + pg_host: str = "/var/run/postgresql" pg_port: int = 5432 pg_user: str = getpass.getuser() pg_pass: str | None = None - pg_name: str = 'activityrelay' + pg_name: str = "activityrelay" - rd_host: str = 'localhost' + rd_host: str = "localhost" rd_port: int = 6470 rd_user: str | None = None rd_pass: str | None = None rd_database: int = 0 - rd_prefix: str = 'activityrelay' + rd_prefix: str = "activityrelay" def __init__(self, path: Path | None = None, load: bool = False): @@ -116,17 +116,17 @@ class Config: @property def actor(self) -> str: - return f'https://{self.domain}/actor' + return f"https://{self.domain}/actor" @property def inbox(self) -> str: - return f'https://{self.domain}/inbox' + return f"https://{self.domain}/inbox" @property def keyid(self) -> str: - return f'{self.actor}#main-key' + return f"{self.actor}#main-key" def load(self) -> None: @@ -134,43 +134,43 @@ class Config: options = {} try: - options['Loader'] = yaml.FullLoader + options["Loader"] = yaml.FullLoader except AttributeError: pass - with self.path.open('r', encoding = 'UTF-8') as fd: + with self.path.open("r", encoding = "UTF-8") as fd: config = yaml.load(fd, **options) if not config: - raise ValueError('Config is empty') + raise ValueError("Config is empty") - pgcfg = config.get('postgres', {}) - rdcfg = config.get('redis', {}) + pgcfg = config.get("postgres", {}) + rdcfg = config.get("redis", {}) for key in type(self).KEYS(): - if IS_DOCKER and key in {'listen', 'port', 'sq_path'}: + if IS_DOCKER and key in {"listen", "port", "sq_path"}: self.set(key, DOCKER_VALUES[key]) continue - if key.startswith('pg'): + if key.startswith("pg"): self.set(key, pgcfg.get(key[3:], NOVALUE)) continue - elif key.startswith('rd'): + elif key.startswith("rd"): self.set(key, rdcfg.get(key[3:], NOVALUE)) continue cfgkey = key - if key == 'db_type': - cfgkey = 'database_type' + if key == "db_type": + cfgkey = "database_type" - elif key == 'ca_type': - cfgkey = 'cache_type' + elif key == "ca_type": + cfgkey = "cache_type" - elif key == 'sq_path': - cfgkey = 'sqlite_path' + elif key == "sq_path": + cfgkey = "sqlite_path" self.set(key, config.get(cfgkey, NOVALUE)) @@ -186,32 +186,32 @@ class Config: data: dict[str, Any] = {} for key, value in asdict(self).items(): - if key.startswith('pg_'): - if 'postgres' not in data: - data['postgres'] = {} + if key.startswith("pg_"): + if "postgres" not in data: + data["postgres"] = {} - data['postgres'][key[3:]] = value + data["postgres"][key[3:]] = value continue - if key.startswith('rd_'): - if 'redis' not in data: - data['redis'] = {} + if key.startswith("rd_"): + if "redis" not in data: + data["redis"] = {} - data['redis'][key[3:]] = value + data["redis"][key[3:]] = value continue - if key == 'db_type': - key = 'database_type' + if key == "db_type": + key = "database_type" - elif key == 'ca_type': - key = 'cache_type' + elif key == "ca_type": + key = "cache_type" - elif key == 'sq_path': - key = 'sqlite_path' + elif key == "sq_path": + key = "sqlite_path" data[key] = value - with self.path.open('w', encoding = 'utf-8') as fd: + with self.path.open("w", encoding = "utf-8") as fd: yaml.dump(data, fd, sort_keys = False) diff --git a/relay/database/__init__.py b/relay/database/__init__.py index db5120b..17f5a40 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -16,17 +16,17 @@ sqlite3.register_adapter(Date, Date.timestamp) def get_database(config: Config, migrate: bool = True) -> Database[Connection]: options = { - 'connection_class': Connection, - 'pool_size': 5, - 'tables': TABLES + "connection_class": Connection, + "pool_size": 5, + "tables": TABLES } db: Database[Connection] - if config.db_type == 'sqlite': + if config.db_type == "sqlite": db = Database.sqlite(config.sqlite_path, **options) - elif config.db_type == 'postgres': + elif config.db_type == "postgres": db = Database.postgresql( config.pg_name, config.pg_host, @@ -36,30 +36,30 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]: **options ) - db.load_prepared_statements(File.from_resource('relay', 'data/statements.sql')) + db.load_prepared_statements(File.from_resource("relay", "data/statements.sql")) db.connect() if not migrate: return db with db.session(True) as conn: - if 'config' not in conn.get_tables(): + if "config" not in conn.get_tables(): logging.info("Creating database tables") migrate_0(conn) return db - if (schema_ver := conn.get_config('schema-version')) < ConfigData.DEFAULT('schema-version'): + if (schema_ver := conn.get_config("schema-version")) < ConfigData.DEFAULT("schema-version"): logging.info("Migrating database from version '%i'", schema_ver) for ver, func in VERSIONS.items(): if schema_ver < ver: func(conn) - conn.put_config('schema-version', ver) + conn.put_config("schema-version", ver) logging.info("Updated database to %i", ver) - if (privkey := conn.get_config('private-key')): + if (privkey := conn.get_config("private-key")): conn.app.signer = privkey - logging.set_level(conn.get_config('log-level')) + logging.set_level(conn.get_config("log-level")) return db diff --git a/relay/database/config.py b/relay/database/config.py index acb6f1e..3884813 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -15,76 +15,76 @@ if TYPE_CHECKING: THEMES = { - 'default': { - 'text': '#DDD', - 'background': '#222', - 'primary': '#D85', - 'primary-hover': '#DA8', - 'section-background': '#333', - 'table-background': '#444', - 'border': '#444', - 'message-text': '#DDD', - 'message-background': '#335', - 'message-border': '#446', - 'error-text': '#DDD', - 'error-background': '#533', - 'error-border': '#644' + "default": { + "text": "#DDD", + "background": "#222", + "primary": "#D85", + "primary-hover": "#DA8", + "section-background": "#333", + "table-background": "#444", + "border": "#444", + "message-text": "#DDD", + "message-background": "#335", + "message-border": "#446", + "error-text": "#DDD", + "error-background": "#533", + "error-border": "#644" }, - 'pink': { - 'text': '#DDD', - 'background': '#222', - 'primary': '#D69', - 'primary-hover': '#D36', - 'section-background': '#333', - 'table-background': '#444', - 'border': '#444', - 'message-text': '#DDD', - 'message-background': '#335', - 'message-border': '#446', - 'error-text': '#DDD', - 'error-background': '#533', - 'error-border': '#644' + "pink": { + "text": "#DDD", + "background": "#222", + "primary": "#D69", + "primary-hover": "#D36", + "section-background": "#333", + "table-background": "#444", + "border": "#444", + "message-text": "#DDD", + "message-background": "#335", + "message-border": "#446", + "error-text": "#DDD", + "error-background": "#533", + "error-border": "#644" }, - 'blue': { - 'text': '#DDD', - 'background': '#222', - 'primary': '#69D', - 'primary-hover': '#36D', - 'section-background': '#333', - 'table-background': '#444', - 'border': '#444', - 'message-text': '#DDD', - 'message-background': '#335', - 'message-border': '#446', - 'error-text': '#DDD', - 'error-background': '#533', - 'error-border': '#644' + "blue": { + "text": "#DDD", + "background": "#222", + "primary": "#69D", + "primary-hover": "#36D", + "section-background": "#333", + "table-background": "#444", + "border": "#444", + "message-text": "#DDD", + "message-background": "#335", + "message-border": "#446", + "error-text": "#DDD", + "error-background": "#533", + "error-border": "#644" } } # serializer | deserializer CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { - 'str': (str, str), - 'int': (str, int), - 'bool': (str, convert_to_boolean), - 'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse) + "str": (str, str), + "int": (str, int), + "bool": (str, convert_to_boolean), + "logging.LogLevel": (lambda x: x.name, logging.LogLevel.parse) } @dataclass() class ConfigData: schema_version: int = 20240625 - private_key: str = '' + private_key: str = "" approval_required: bool = False log_level: logging.LogLevel = logging.LogLevel.INFO - name: str = 'ActivityRelay' - note: str = '' - theme: str = 'default' + name: str = "ActivityRelay" + note: str = "" + theme: str = "default" whitelist_enabled: bool = False def __getitem__(self, key: str) -> Any: - if (value := getattr(self, key.replace('-', '_'), None)) is None: + if (value := getattr(self, key.replace("-", "_"), None)) is None: raise KeyError(key) return value @@ -101,7 +101,7 @@ class ConfigData: @staticmethod def SYSTEM_KEYS() -> Sequence[str]: - return ('schema-version', 'schema_version', 'private-key', 'private_key') + return ("schema-version", "schema_version", "private-key", "private_key") @classmethod @@ -111,12 +111,12 @@ class ConfigData: @classmethod def DEFAULT(cls: type[Self], key: str) -> str | int | bool: - return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value] + return cls.FIELD(key.replace("-", "_")).default # type: ignore[return-value] @classmethod def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]: - parsed_key = key.replace('-', '_') + parsed_key = key.replace("-", "_") for field in fields(cls): if field.name == parsed_key: @@ -131,9 +131,9 @@ class ConfigData: set_schema_version = False for row in rows: - data.set(row['key'], row['value']) + data.set(row["key"], row["value"]) - if row['key'] == 'schema-version': + if row["key"] == "schema-version": set_schema_version = True if not set_schema_version: @@ -161,4 +161,4 @@ class ConfigData: def to_dict(self) -> dict[str, Any]: - return {key.replace('_', '-'): value for key, value in asdict(self).items()} + return {key.replace("_", "-"): value for key, value in asdict(self).items()} diff --git a/relay/database/connection.py b/relay/database/connection.py index 8e94627..eb66973 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -23,16 +23,16 @@ if TYPE_CHECKING: from ..application import Application RELAY_SOFTWARE = [ - 'activityrelay', # https://git.pleroma.social/pleroma/relay - 'activity-relay', # https://github.com/yukimochi/Activity-Relay - 'aoderelay', # https://git.asonix.dog/asonix/relay - 'feditools-relay' # https://git.ptzo.gdn/feditools/relay + "activityrelay", # https://git.pleroma.social/pleroma/relay + "activity-relay", # https://github.com/yukimochi/Activity-Relay + "aoderelay", # https://git.asonix.dog/asonix/relay + "feditools-relay" # https://git.ptzo.gdn/feditools/relay ] class Connection(SqlConnection): hasher = PasswordHasher( - encoding = 'utf-8' + encoding = "utf-8" ) @property @@ -63,49 +63,49 @@ class Connection(SqlConnection): def fix_timestamps(self) -> None: - for app in self.select('apps').all(schema.App): - data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()} - self.update('apps', data, client_id = app.client_id) + for app in self.select("apps").all(schema.App): + data = {"created": app.created.timestamp(), "accessed": app.accessed.timestamp()} + self.update("apps", data, client_id = app.client_id) - for item in self.select('cache'): - data = {'updated': Date.parse(item['updated']).timestamp()} - self.update('cache', data, id = item['id']) + for item in self.select("cache"): + data = {"updated": Date.parse(item["updated"]).timestamp()} + self.update("cache", data, id = item["id"]) - for dban in self.select('domain_bans').all(schema.DomainBan): - data = {'created': dban.created.timestamp()} - self.update('domain_bans', data, domain = dban.domain) + for dban in self.select("domain_bans").all(schema.DomainBan): + data = {"created": dban.created.timestamp()} + self.update("domain_bans", data, domain = dban.domain) - for instance in self.select('inboxes').all(schema.Instance): - data = {'created': instance.created.timestamp()} - self.update('inboxes', data, domain = instance.domain) + for instance in self.select("inboxes").all(schema.Instance): + data = {"created": instance.created.timestamp()} + self.update("inboxes", data, domain = instance.domain) - for sban in self.select('software_bans').all(schema.SoftwareBan): - data = {'created': sban.created.timestamp()} - self.update('software_bans', data, name = sban.name) + for sban in self.select("software_bans").all(schema.SoftwareBan): + data = {"created": sban.created.timestamp()} + self.update("software_bans", data, name = sban.name) - for user in self.select('users').all(schema.User): - data = {'created': user.created.timestamp()} - self.update('users', data, username = user.username) + for user in self.select("users").all(schema.User): + data = {"created": user.created.timestamp()} + self.update("users", data, username = user.username) - for wlist in self.select('whitelist').all(schema.Whitelist): - data = {'created': wlist.created.timestamp()} - self.update('whitelist', data, domain = wlist.domain) + for wlist in self.select("whitelist").all(schema.Whitelist): + data = {"created": wlist.created.timestamp()} + self.update("whitelist", data, domain = wlist.domain) def get_config(self, key: str) -> Any: - key = key.replace('_', '-') + key = key.replace("_", "-") - with self.run('get-config', {'key': key}) as cur: + with self.run("get-config", {"key": key}) as cur: if (row := cur.one(Row)) is None: return ConfigData.DEFAULT(key) data = ConfigData() - data.set(row['key'], row['value']) + data.set(row["key"], row["value"]) return data.get(key) def get_config_all(self) -> ConfigData: - rows = tuple(self.run('get-config-all', None).all(schema.Row)) + rows = tuple(self.run("get-config-all", None).all(schema.Row)) return ConfigData.from_rows(rows) @@ -119,7 +119,7 @@ class Connection(SqlConnection): case "log_level": value = logging.LogLevel.parse(value) logging.set_level(value) - self.app['workers'].set_log_level(value) + self.app["workers"].set_log_level(value) case "approval_required": value = convert_to_boolean(value) @@ -129,25 +129,25 @@ class Connection(SqlConnection): case "theme": if value not in THEMES: - raise ValueError(f'"{value}" is not a valid theme') + raise ValueError(f"\"{value}\" is not a valid theme") data = ConfigData() data.set(key, value) params = { - 'key': key, - 'value': data.get(key, serialize = True), - 'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore + "key": key, + "value": data.get(key, serialize = True), + "type": "LogLevel" if field.type == "logging.LogLevel" else field.type # type: ignore } - with self.run('put-config', params): + with self.run("put-config", params): pass return data.get(key) def get_inbox(self, value: str) -> schema.Instance | None: - with self.run('get-inbox', {'value': value}) as cur: + with self.run("get-inbox", {"value": value}) as cur: return cur.one(schema.Instance) @@ -165,21 +165,21 @@ class Connection(SqlConnection): accepted: bool = True) -> schema.Instance: params: dict[str, Any] = { - 'inbox': inbox, - 'actor': actor, - 'followid': followid, - 'software': software, - 'accepted': accepted + "inbox": inbox, + "actor": actor, + "followid": followid, + "software": software, + "accepted": accepted } if self.get_inbox(domain) is None: if not inbox: raise ValueError("Missing inbox") - params['domain'] = domain - params['created'] = datetime.now(tz = timezone.utc) + params["domain"] = domain + params["created"] = datetime.now(tz = timezone.utc) - with self.run('put-inbox', params) as cur: + with self.run("put-inbox", params) as cur: if (row := cur.one(schema.Instance)) is None: raise RuntimeError(f"Failed to insert instance: {domain}") @@ -189,7 +189,7 @@ class Connection(SqlConnection): if value is None: del params[key] - with self.update('inboxes', params, domain = domain) as cur: + with self.update("inboxes", params, domain = domain) as cur: if (row := cur.one(schema.Instance)) is None: raise RuntimeError(f"Failed to update instance: {domain}") @@ -197,20 +197,20 @@ class Connection(SqlConnection): def del_inbox(self, value: str) -> bool: - with self.run('del-inbox', {'value': value}) as cur: + with self.run("del-inbox", {"value": value}) as cur: if cur.row_count > 1: - raise ValueError('More than one row was modified') + raise ValueError("More than one row was modified") return cur.row_count == 1 def get_request(self, domain: str) -> schema.Instance | None: - with self.run('get-request', {'domain': domain}) as cur: + with self.run("get-request", {"domain": domain}) as cur: return cur.one(schema.Instance) def get_requests(self) -> Iterator[schema.Instance]: - return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance) + return self.execute("SELECT * FROM inboxes WHERE accepted = false").all(schema.Instance) def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: @@ -219,16 +219,16 @@ class Connection(SqlConnection): if not accepted: if not self.del_inbox(domain): - raise RuntimeError(f'Failed to delete request: {domain}') + raise RuntimeError(f"Failed to delete request: {domain}") return instance params = { - 'domain': domain, - 'accepted': accepted + "domain": domain, + "accepted": accepted } - with self.run('put-inbox-accept', params) as cur: + with self.run("put-inbox-accept", params) as cur: if (row := cur.one(schema.Instance)) is None: raise RuntimeError(f"Failed to insert response for domain: {domain}") @@ -236,12 +236,12 @@ class Connection(SqlConnection): def get_user(self, value: str) -> schema.User | None: - with self.run('get-user', {'value': value}) as cur: + with self.run("get-user", {"value": value}) as cur: return cur.one(schema.User) def get_user_by_token(self, token: str) -> schema.User | None: - with self.run('get-user-by-token', {'token': token}) as cur: + with self.run("get-user-by-token", {"token": token}) as cur: return cur.one(schema.User) @@ -254,10 +254,10 @@ class Connection(SqlConnection): data: dict[str, str | datetime | None] = {} if password: - data['hash'] = self.hasher.hash(password) + data["hash"] = self.hasher.hash(password) if handle: - data['handle'] = handle + data["handle"] = handle stmt = Update("users", data) stmt.set_where("username", username) @@ -269,16 +269,16 @@ class Connection(SqlConnection): return row if password is None: - raise ValueError('Password cannot be empty') + raise ValueError("Password cannot be empty") data = { - 'username': username, - 'hash': self.hasher.hash(password), - 'handle': handle, - 'created': datetime.now(tz = timezone.utc) + "username": username, + "hash": self.hasher.hash(password), + "handle": handle, + "created": datetime.now(tz = timezone.utc) } - with self.run('put-user', data) as cur: + with self.run("put-user", data) as cur: if (row := cur.one(schema.User)) is None: raise RuntimeError(f"Failed to insert user: {username}") @@ -289,10 +289,10 @@ class Connection(SqlConnection): if (user := self.get_user(username)) is None: raise KeyError(username) - with self.run('del-token-user', {'username': user.username}): + with self.run("del-token-user", {"username": user.username}): pass - with self.run('del-user', {'username': user.username}): + with self.run("del-user", {"username": user.username}): pass @@ -302,61 +302,61 @@ class Connection(SqlConnection): token: str | None = None) -> schema.App | None: params = { - 'id': client_id, - 'secret': client_secret + "id": client_id, + "secret": client_secret } if token is not None: - command = 'get-app-with-token' - params['token'] = token + command = "get-app-with-token" + params["token"] = token else: - command = 'get-app' + command = "get-app" with self.run(command, params) as cur: return cur.one(schema.App) def get_app_by_token(self, token: str) -> schema.App | None: - with self.run('get-app-by-token', {'token': token}) as cur: + with self.run("get-app-by-token", {"token": token}) as cur: return cur.one(schema.App) def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App: params = { - 'name': name, - 'redirect_uri': redirect_uri, - 'website': website, - 'client_id': secrets.token_hex(20), - 'client_secret': secrets.token_hex(20), - 'created': Date.new_utc(), - 'accessed': Date.new_utc() + "name": name, + "redirect_uri": redirect_uri, + "website": website, + "client_id": secrets.token_hex(20), + "client_secret": secrets.token_hex(20), + "created": Date.new_utc(), + "accessed": Date.new_utc() } - with self.insert('apps', params) as cur: + with self.insert("apps", params) as cur: if (row := cur.one(schema.App)) is None: - raise RuntimeError(f'Failed to insert app: {name}') + raise RuntimeError(f"Failed to insert app: {name}") return row def put_app_login(self, user: schema.User) -> schema.App: params = { - 'name': 'Web', - 'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob', - 'website': None, - 'user': user.username, - 'client_id': secrets.token_hex(20), - 'client_secret': secrets.token_hex(20), - 'auth_code': None, - 'token': secrets.token_hex(20), - 'created': Date.new_utc(), - 'accessed': Date.new_utc() + "name": "Web", + "redirect_uri": "urn:ietf:wg:oauth:2.0:oob", + "website": None, + "user": user.username, + "client_id": secrets.token_hex(20), + "client_secret": secrets.token_hex(20), + "auth_code": None, + "token": secrets.token_hex(20), + "created": Date.new_utc(), + "accessed": Date.new_utc() } - with self.insert('apps', params) as cur: + with self.insert("apps", params) as cur: if (row := cur.one(schema.App)) is None: - raise RuntimeError(f'Failed to create app for "{user.username}"') + raise RuntimeError(f"Failed to create app for \"{user.username}\"") return row @@ -365,52 +365,52 @@ class Connection(SqlConnection): data: dict[str, str | None] = {} if user is not None: - data['user'] = user.username + data["user"] = user.username if set_auth: - data['auth_code'] = secrets.token_hex(20) + data["auth_code"] = secrets.token_hex(20) else: - data['token'] = secrets.token_hex(20) - data['auth_code'] = None + data["token"] = secrets.token_hex(20) + data["auth_code"] = None params = { - 'client_id': app.client_id, - 'client_secret': app.client_secret + "client_id": app.client_id, + "client_secret": app.client_secret } - with self.update('apps', data, **params) as cur: # type: ignore[arg-type] + with self.update("apps", data, **params) as cur: # type: ignore[arg-type] if (row := cur.one(schema.App)) is None: - raise RuntimeError('Failed to update row') + raise RuntimeError("Failed to update row") return row def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool: params = { - 'id': client_id, - 'secret': client_secret + "id": client_id, + "secret": client_secret } if token is not None: - command = 'del-app-with-token' - params['token'] = token + command = "del-app-with-token" + params["token"] = token else: - command = 'del-app' + command = "del-app" with self.run(command, params) as cur: if cur.row_count > 1: - raise RuntimeError('More than 1 row was deleted') + raise RuntimeError("More than 1 row was deleted") return cur.row_count == 0 def get_domain_ban(self, domain: str) -> schema.DomainBan | None: - if domain.startswith('http'): + if domain.startswith("http"): domain = urlparse(domain).netloc - with self.run('get-domain-ban', {'domain': domain}) as cur: + with self.run("get-domain-ban", {"domain": domain}) as cur: return cur.one(schema.DomainBan) @@ -424,13 +424,13 @@ class Connection(SqlConnection): note: str | None = None) -> schema.DomainBan: params = { - 'domain': domain, - 'reason': reason, - 'note': note, - 'created': datetime.now(tz = timezone.utc) + "domain": domain, + "reason": reason, + "note": note, + "created": datetime.now(tz = timezone.utc) } - with self.run('put-domain-ban', params) as cur: + with self.run("put-domain-ban", params) as cur: if (row := cur.one(schema.DomainBan)) is None: raise RuntimeError(f"Failed to insert domain ban: {domain}") @@ -443,22 +443,22 @@ class Connection(SqlConnection): note: str | None = None) -> schema.DomainBan: if not (reason or note): - raise ValueError('"reason" and/or "note" must be specified') + raise ValueError("\"reason\" and/or \"note\" must be specified") params = {} if reason is not None: - params['reason'] = reason + params["reason"] = reason if note is not None: - params['note'] = note + params["note"] = note - statement = Update('domain_bans', params) + statement = Update("domain_bans", params) statement.set_where("domain", domain) with self.query(statement) as cur: if cur.row_count > 1: - raise ValueError('More than one row was modified') + raise ValueError("More than one row was modified") if (row := cur.one(schema.DomainBan)) is None: raise RuntimeError(f"Failed to update domain ban: {domain}") @@ -467,20 +467,20 @@ class Connection(SqlConnection): def del_domain_ban(self, domain: str) -> bool: - with self.run('del-domain-ban', {'domain': domain}) as cur: + with self.run("del-domain-ban", {"domain": domain}) as cur: if cur.row_count > 1: - raise ValueError('More than one row was modified') + raise ValueError("More than one row was modified") return cur.row_count == 1 def get_software_ban(self, name: str) -> schema.SoftwareBan | None: - with self.run('get-software-ban', {'name': name}) as cur: + with self.run("get-software-ban", {"name": name}) as cur: return cur.one(schema.SoftwareBan) def get_software_bans(self) -> Iterator[schema.SoftwareBan,]: - return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan) + return self.execute("SELECT * FROM software_bans").all(schema.SoftwareBan) def put_software_ban(self, @@ -489,15 +489,15 @@ class Connection(SqlConnection): note: str | None = None) -> schema.SoftwareBan: params = { - 'name': name, - 'reason': reason, - 'note': note, - 'created': datetime.now(tz = timezone.utc) + "name": name, + "reason": reason, + "note": note, + "created": datetime.now(tz = timezone.utc) } - with self.run('put-software-ban', params) as cur: + with self.run("put-software-ban", params) as cur: if (row := cur.one(schema.SoftwareBan)) is None: - raise RuntimeError(f'Failed to insert software ban: {name}') + raise RuntimeError(f"Failed to insert software ban: {name}") return row @@ -508,39 +508,39 @@ class Connection(SqlConnection): note: str | None = None) -> schema.SoftwareBan: if not (reason or note): - raise ValueError('"reason" and/or "note" must be specified') + raise ValueError("\"reason\" and/or \"note\" must be specified") params = {} if reason is not None: - params['reason'] = reason + params["reason"] = reason if note is not None: - params['note'] = note + params["note"] = note - statement = Update('software_bans', params) + statement = Update("software_bans", params) statement.set_where("name", name) with self.query(statement) as cur: if cur.row_count > 1: - raise ValueError('More than one row was modified') + raise ValueError("More than one row was modified") if (row := cur.one(schema.SoftwareBan)) is None: - raise RuntimeError(f'Failed to update software ban: {name}') + raise RuntimeError(f"Failed to update software ban: {name}") return row def del_software_ban(self, name: str) -> bool: - with self.run('del-software-ban', {'name': name}) as cur: + with self.run("del-software-ban", {"name": name}) as cur: if cur.row_count > 1: - raise ValueError('More than one row was modified') + raise ValueError("More than one row was modified") return cur.row_count == 1 def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None: - with self.run('get-domain-whitelist', {'domain': domain}) as cur: + with self.run("get-domain-whitelist", {"domain": domain}) as cur: return cur.one() @@ -550,20 +550,20 @@ class Connection(SqlConnection): def put_domain_whitelist(self, domain: str) -> schema.Whitelist: params = { - 'domain': domain, - 'created': datetime.now(tz = timezone.utc) + "domain": domain, + "created": datetime.now(tz = timezone.utc) } - with self.run('put-domain-whitelist', params) as cur: + with self.run("put-domain-whitelist", params) as cur: if (row := cur.one(schema.Whitelist)) is None: - raise RuntimeError(f'Failed to insert whitelisted domain: {domain}') + raise RuntimeError(f"Failed to insert whitelisted domain: {domain}") return row def del_domain_whitelist(self, domain: str) -> bool: - with self.run('del-domain-whitelist', {'domain': domain}) as cur: + with self.run("del-domain-whitelist", {"domain": domain}) as cur: if cur.row_count > 1: - raise ValueError('More than one row was modified') + raise ValueError("More than one row was modified") return cur.row_count == 1 diff --git a/relay/database/schema.py b/relay/database/schema.py index 55ca608..4b06aad 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -32,113 +32,113 @@ def deserialize_timestamp(value: Any) -> Date: @TABLES.add_row class Config(Row): - key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) - value: Column[str] = Column('value', 'text') - type: Column[str] = Column('type', 'text', default = 'str') + key: Column[str] = Column("key", "text", primary_key = True, unique = True, nullable = False) + value: Column[str] = Column("value", "text") + type: Column[str] = Column("type", "text", default = "str") @TABLES.add_row class Instance(Row): - table_name: str = 'inboxes' + table_name: str = "inboxes" domain: Column[str] = Column( - 'domain', 'text', primary_key = True, unique = True, nullable = False) - actor: Column[str] = Column('actor', 'text', unique = True) - inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) - followid: Column[str] = Column('followid', 'text') - software: Column[str] = Column('software', 'text') - accepted: Column[Date] = Column('accepted', 'boolean') + "domain", "text", primary_key = True, unique = True, nullable = False) + actor: Column[str] = Column("actor", "text", unique = True) + inbox: Column[str] = Column("inbox", "text", unique = True, nullable = False) + followid: Column[str] = Column("followid", "text") + software: Column[str] = Column("software", "text") + accepted: Column[Date] = Column("accepted", "boolean") created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + "created", "timestamp", nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row class Whitelist(Row): domain: Column[str] = Column( - 'domain', 'text', primary_key = True, unique = True, nullable = True) + "domain", "text", primary_key = True, unique = True, nullable = True) created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + "created", "timestamp", nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row class DomainBan(Row): - table_name: str = 'domain_bans' + table_name: str = "domain_bans" domain: Column[str] = Column( - 'domain', 'text', primary_key = True, unique = True, nullable = True) - reason: Column[str] = Column('reason', 'text') - note: Column[str] = Column('note', 'text') + "domain", "text", primary_key = True, unique = True, nullable = True) + reason: Column[str] = Column("reason", "text") + note: Column[str] = Column("note", "text") created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + "created", "timestamp", nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row class SoftwareBan(Row): - table_name: str = 'software_bans' + table_name: str = "software_bans" - name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) - reason: Column[str] = Column('reason', 'text') - note: Column[str] = Column('note', 'text') + name: Column[str] = Column("name", "text", primary_key = True, unique = True, nullable = True) + reason: Column[str] = Column("reason", "text") + note: Column[str] = Column("note", "text") created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + "created", "timestamp", nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row class User(Row): - table_name: str = 'users' + table_name: str = "users" username: Column[str] = Column( - 'username', 'text', primary_key = True, unique = True, nullable = False) - hash: Column[str] = Column('hash', 'text', nullable = False) - handle: Column[str] = Column('handle', 'text') + "username", "text", primary_key = True, unique = True, nullable = False) + hash: Column[str] = Column("hash", "text", nullable = False) + handle: Column[str] = Column("handle", "text") created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + "created", "timestamp", nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row class App(Row): - table_name: str = 'apps' + table_name: str = "apps" client_id: Column[str] = Column( - 'client_id', 'text', primary_key = True, unique = True, nullable = False) - client_secret: Column[str] = Column('client_secret', 'text', nullable = False) - name: Column[str] = Column('name', 'text') - website: Column[str] = Column('website', 'text') - redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False) - token: Column[str | None] = Column('token', 'text') - auth_code: Column[str | None] = Column('auth_code', 'text') - user: Column[str | None] = Column('user', 'text') + "client_id", "text", primary_key = True, unique = True, nullable = False) + client_secret: Column[str] = Column("client_secret", "text", nullable = False) + name: Column[str] = Column("name", "text") + website: Column[str] = Column("website", "text") + redirect_uri: Column[str] = Column("redirect_uri", "text", nullable = False) + token: Column[str | None] = Column("token", "text") + auth_code: Column[str | None] = Column("auth_code", "text") + user: Column[str | None] = Column("user", "text") created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + "created", "timestamp", nullable = False, deserializer = deserialize_timestamp) accessed: Column[Date] = Column( - 'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + "accessed", "timestamp", nullable = False, deserializer = deserialize_timestamp) def get_api_data(self, include_token: bool = False) -> dict[str, Any]: data = deepcopy(self) - data.pop('user') - data.pop('auth_code') + data.pop("user") + data.pop("auth_code") if not include_token: - data.pop('token') + data.pop("token") return data def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: - ver = int(func.__name__.replace('migrate_', '')) + ver = int(func.__name__.replace("migrate_", "")) VERSIONS[ver] = func return func def migrate_0(conn: Connection) -> None: conn.create_tables() - conn.put_config('schema-version', ConfigData.DEFAULT('schema-version')) + conn.put_config("schema-version", ConfigData.DEFAULT("schema-version")) @migration @@ -148,11 +148,11 @@ def migrate_20240206(conn: Connection) -> None: @migration def migrate_20240310(conn: Connection) -> None: - conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close() - conn.execute('UPDATE "inboxes" SET "accepted" = true').close() + conn.execute("ALTER TABLE \"inboxes\" ADD COLUMN \"accepted\" BOOLEAN").close() + conn.execute("UPDATE \"inboxes\" SET \"accepted\" = true").close() @migration def migrate_20240625(conn: Connection) -> None: conn.create_tables() - conn.execute('DROP TABLE "tokens"').close() + conn.execute("DROP TABLE \"tokens\"").close() diff --git a/relay/http_client.py b/relay/http_client.py index 8548dcc..3c707bc 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -17,6 +17,13 @@ if TYPE_CHECKING: from .application import Application +T = TypeVar("T", bound = JsonBase[Any]) + +HEADERS = { + "Accept": f"{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9", + "User-Agent": f"ActivityRelay/{__version__}" +} + SUPPORTS_HS2019 = { 'friendica', 'gotosocial', @@ -32,12 +39,6 @@ SUPPORTS_HS2019 = { 'sharkey' } -T = TypeVar('T', bound = JsonBase[Any]) -HEADERS = { - 'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9', - 'User-Agent': f'ActivityRelay/{__version__}' -} - class HttpClient: def __init__(self, limit: int = 100, timeout: int = 10): @@ -106,25 +107,25 @@ class HttpClient: old_algo: bool) -> str | None: if not self._session: - raise RuntimeError('Client not open') + raise RuntimeError("Client not open") url = url.split("#", 1)[0] if not force: try: - if not (item := self.cache.get('request', url)).older_than(48): + if not (item := self.cache.get("request", url)).older_than(48): return item.value # type: ignore [no-any-return] except KeyError: - logging.verbose('No cached data for url: %s', url) + logging.verbose("No cached data for url: %s", url) headers = {} if sign_headers: algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019 - headers = self.signer.sign_headers('GET', url, algorithm = algo) + headers = self.signer.sign_headers("GET", url, algorithm = algo) - logging.debug('Fetching resource: %s', url) + logging.debug("Fetching resource: %s", url) async with self._session.get(url, headers = headers) as resp: # Not expecting a response with 202s, so just return @@ -142,7 +143,7 @@ class HttpClient: raise HttpError(resp.status, error) - self.cache.set('request', url, data, 'str') + self.cache.set("request", url, data, "str") return data @@ -172,13 +173,13 @@ class HttpClient: old_algo: bool = True) -> T | str | None: if cls is not None and not issubclass(cls, JsonBase): - raise TypeError('cls must be a sub-class of "blib.JsonBase"') + raise TypeError("cls must be a sub-class of \"blib.JsonBase\"") data = await self._get(url, sign_headers, force, old_algo) if cls is not None: if data is None: - # this shouldn't actually get raised, but keeping just in case + # this shouldn"t actually get raised, but keeping just in case raise EmptyBodyError(f"GET {url}") return cls.parse(data) @@ -188,7 +189,7 @@ class HttpClient: async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None: if not self._session: - raise RuntimeError('Client not open') + raise RuntimeError("Client not open") # akkoma and pleroma do not support HS2019 and other software still needs to be tested if instance is not None and instance.software in SUPPORTS_HS2019: @@ -210,14 +211,14 @@ class HttpClient: mtype = message.type.value if isinstance(message.type, ObjectType) else message.type headers = self.signer.sign_headers( - 'POST', + "POST", url, body, - headers = {'Content-Type': 'application/activity+json'}, + headers = {"Content-Type": "application/activity+json"}, algorithm = algorithm ) - logging.verbose('Sending "%s" to %s', mtype, url) + logging.verbose("Sending \"%s\" to %s", mtype, url) async with self._session.post(url, headers = headers, data = body) as resp: if resp.status not in (200, 202): @@ -231,10 +232,10 @@ class HttpClient: async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo: nodeinfo_url = None wk_nodeinfo = await self.get( - f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force + f"https://{domain}/.well-known/nodeinfo", False, WellKnownNodeinfo, force ) - for version in ('20', '21'): + for version in ("20", "21"): try: nodeinfo_url = wk_nodeinfo.get_url(version) @@ -242,7 +243,7 @@ class HttpClient: pass if nodeinfo_url is None: - raise ValueError(f'Failed to fetch nodeinfo url for {domain}') + raise ValueError(f"Failed to fetch nodeinfo url for {domain}") return await self.get(nodeinfo_url, False, Nodeinfo, force) diff --git a/relay/logger.py b/relay/logger.py index 7caac9f..bc55fa6 100644 --- a/relay/logger.py +++ b/relay/logger.py @@ -54,7 +54,7 @@ class LogLevel(IntEnum): except ValueError: pass - raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}') + raise AttributeError(f"Invalid enum property for {cls.__name__}: {data}") def get_level() -> LogLevel: @@ -80,7 +80,7 @@ critical: LoggingMethod = logging.critical try: - env_log_file: Path | None = Path(os.environ['LOG_FILE']).expanduser().resolve() + env_log_file: Path | None = Path(os.environ["LOG_FILE"]).expanduser().resolve() except KeyError: env_log_file = None @@ -90,16 +90,16 @@ handlers: list[Any] = [logging.StreamHandler()] if env_log_file: handlers.append(logging.FileHandler(env_log_file)) -if os.environ.get('IS_SYSTEMD'): - logging_format = '%(levelname)s: %(message)s' +if os.environ.get("IS_SYSTEMD"): + logging_format = "%(levelname)s: %(message)s" else: - logging_format = '[%(asctime)s] %(levelname)s: %(message)s' + logging_format = "[%(asctime)s] %(levelname)s: %(message)s" -logging.addLevelName(LogLevel.VERBOSE, 'VERBOSE') +logging.addLevelName(LogLevel.VERBOSE, "VERBOSE") logging.basicConfig( level = LogLevel.INFO, format = logging_format, - datefmt = '%Y-%m-%d %H:%M:%S', + datefmt = "%Y-%m-%d %H:%M:%S", handlers = handlers ) diff --git a/relay/manage.py b/relay/manage.py index 7322d44..0b17712 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -4,6 +4,7 @@ import aputils import asyncio import click import json +import multiprocessing import os from pathlib import Path @@ -24,26 +25,26 @@ from .views import ROUTES def check_alphanumeric(text: str) -> str: if not text.isalnum(): - raise click.BadParameter('String not alphanumeric') + raise click.BadParameter("String not alphanumeric") return text -@click.group('cli', context_settings = {'show_default': True}) -@click.option('--config', '-c', type = Path, help = 'path to the relay\'s config') -@click.version_option(version = __version__, prog_name = 'ActivityRelay') +@click.group("cli", context_settings = {"show_default": True}) +@click.option("--config", "-c", type = Path, help = "path to the relay config") +@click.version_option(version = __version__, prog_name = "ActivityRelay") @click.pass_context def cli(ctx: click.Context, config: Path | None) -> None: if IS_DOCKER: config = Path("/data/relay.yaml") - # The database was named "relay.jsonld" even though it's an sqlite file. Fix it. - db = Path('/data/relay.sqlite3') - wrongdb = Path('/data/relay.jsonld') + # The database was named "relay.jsonld" even though it"s an sqlite file. Fix it. + db = Path("/data/relay.sqlite3") + wrongdb = Path("/data/relay.jsonld") if wrongdb.exists() and not db.exists(): try: - with wrongdb.open('rb') as fd: + with wrongdb.open("rb") as fd: json.load(fd) except json.JSONDecodeError: @@ -52,97 +53,97 @@ def cli(ctx: click.Context, config: Path | None) -> None: ctx.obj = Application(config) -@cli.command('setup') -@click.option('--skip-questions', '-s', is_flag = True, help = 'Just setup the database') +@cli.command("setup") +@click.option("--skip-questions", "-s", is_flag = True, help = "Just setup the database") @click.pass_context def cli_setup(ctx: click.Context, skip_questions: bool) -> None: - 'Generate a new config and create the database' + "Generate a new config and create the database" if ctx.obj.signer is not None: - if not click.prompt('The database is already setup. Are you sure you want to continue?'): + if not click.prompt("The database is already setup. Are you sure you want to continue?"): return - if skip_questions and ctx.obj.config.domain.endswith('example.com'): - click.echo('You cannot skip the questions if the relay is not configured yet') + if skip_questions and ctx.obj.config.domain.endswith("example.com"): + click.echo("You cannot skip the questions if the relay is not configured yet") return if not skip_questions: while True: ctx.obj.config.domain = click.prompt( - 'What domain will the relay be hosted on?', + "What domain will the relay be hosted on?", default = ctx.obj.config.domain ) - if not ctx.obj.config.domain.endswith('example.com'): + if not ctx.obj.config.domain.endswith("example.com"): break - click.echo('The domain must not end with "example.com"') + click.echo("The domain must not end with \"example.com\"") if not IS_DOCKER: ctx.obj.config.listen = click.prompt( - 'Which address should the relay listen on?', + "Which address should the relay listen on?", default = ctx.obj.config.listen ) ctx.obj.config.port = click.prompt( - 'What TCP port should the relay listen on?', + "What TCP port should the relay listen on?", default = ctx.obj.config.port, type = int ) ctx.obj.config.db_type = click.prompt( - 'Which database backend will be used?', + "Which database backend will be used?", default = ctx.obj.config.db_type, - type = click.Choice(['postgres', 'sqlite'], case_sensitive = False) + type = click.Choice(["postgres", "sqlite"], case_sensitive = False) ) - if ctx.obj.config.db_type == 'sqlite' and not IS_DOCKER: + if ctx.obj.config.db_type == "sqlite" and not IS_DOCKER: ctx.obj.config.sq_path = click.prompt( - 'Where should the database be stored?', + "Where should the database be stored?", default = ctx.obj.config.sq_path ) - elif ctx.obj.config.db_type == 'postgres': + elif ctx.obj.config.db_type == "postgres": config_postgresql(ctx.obj.config) ctx.obj.config.ca_type = click.prompt( - 'Which caching backend?', + "Which caching backend?", default = ctx.obj.config.ca_type, - type = click.Choice(['database', 'redis'], case_sensitive = False) + type = click.Choice(["database", "redis"], case_sensitive = False) ) - if ctx.obj.config.ca_type == 'redis': + if ctx.obj.config.ca_type == "redis": ctx.obj.config.rd_host = click.prompt( - 'What IP address, hostname, or unix socket does the server listen on?', + "What IP address, hostname, or unix socket does the server listen on?", default = ctx.obj.config.rd_host ) ctx.obj.config.rd_port = click.prompt( - 'What port does the server listen on?', + "What port does the server listen on?", default = ctx.obj.config.rd_port, type = int ) ctx.obj.config.rd_user = click.prompt( - 'Which user will authenticate with the server', + "Which user will authenticate with the server", default = ctx.obj.config.rd_user ) ctx.obj.config.rd_pass = click.prompt( - 'User password', + "User password", hide_input = True, show_default = False, default = ctx.obj.config.rd_pass or "" ) or None ctx.obj.config.rd_database = click.prompt( - 'Which database number to use?', + "Which database number to use?", default = ctx.obj.config.rd_database, type = int ) ctx.obj.config.rd_prefix = click.prompt( - 'What text should each cache key be prefixed with?', + "What text should each cache key be prefixed with?", default = ctx.obj.config.rd_database, type = check_alphanumeric ) @@ -150,7 +151,7 @@ def cli_setup(ctx: click.Context, skip_questions: bool) -> None: ctx.obj.config.save() config = { - 'private-key': aputils.Signer.new('n/a').export() + "private-key": aputils.Signer.new("n/a").export() } with ctx.obj.database.session() as conn: @@ -161,19 +162,19 @@ def cli_setup(ctx: click.Context, skip_questions: bool) -> None: click.echo("Relay all setup! Start the container to run the relay.") return - if click.confirm('Relay all setup! Would you like to run it now?'): + if click.confirm("Relay all setup! Would you like to run it now?"): cli_run.callback() # type: ignore -@cli.command('run') -@click.option('--dev', '-d', is_flag=True, help='Enable developer mode') +@cli.command("run") +@click.option("--dev", "-d", is_flag=True, help="Enable developer mode") @click.pass_context def cli_run(ctx: click.Context, dev: bool = False) -> None: - 'Run the relay' + "Run the relay" - if ctx.obj.config.domain.endswith('example.com') or ctx.obj.signer is None: + if ctx.obj.config.domain.endswith("example.com") or ctx.obj.signer is None: if not IS_DOCKER: - click.echo('Relay is not set up. Please run "activityrelay setup".') + click.echo("Relay is not set up. Please run \"activityrelay setup\"") return @@ -183,17 +184,17 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None: for method, path, handler in ROUTES: ctx.obj.router.add_route(method, path, handler) - ctx.obj['dev'] = dev + ctx.obj["dev"] = dev ctx.obj.run() - # todo: figure out why the relay doesn't quit properly without this + # todo: figure out why the relay doesn"t quit properly without this os._exit(0) -@cli.command('db-maintenance') +@cli.command("db-maintenance") @click.pass_context def cli_db_maintenance(ctx: click.Context) -> None: - 'Perform maintenance tasks on the database' + "Perform maintenance tasks on the database" if ctx.obj.config.db_type == "postgres": return @@ -206,17 +207,17 @@ def cli_db_maintenance(ctx: click.Context) -> None: pass -@cli.command('convert') -@click.option('--old-config', '-o', help = 'Path to the config file to convert from') +@cli.command("convert") +@click.option("--old-config", "-o", help = "Path to the config file to convert from") @click.pass_context def cli_convert(ctx: click.Context, old_config: str) -> None: - 'Convert an old config and jsonld database to the new format.' + "Convert an old config and jsonld database to the new format." old_config = Path(old_config).expanduser().resolve() if old_config else ctx.obj.config.path - backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml') + backup = ctx.obj.config.path.parent.joinpath(f"{ctx.obj.config.path.stem}.backup.yaml") if str(old_config) == str(ctx.obj.config.path) and not backup.exists(): - logging.info('Created backup config @ %s', backup) + logging.info("Created backup config @ %s", backup) copyfile(ctx.obj.config.path, backup) config = RelayConfig(old_config) @@ -225,58 +226,58 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: database = RelayDatabase(config) database.load() - ctx.obj.config.set('listen', config['listen']) - ctx.obj.config.set('port', config['port']) - ctx.obj.config.set('workers', config['workers']) - ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3')) - ctx.obj.config.set('domain', config['host']) + ctx.obj.config.set("listen", config["listen"]) + ctx.obj.config.set("port", config["port"]) + ctx.obj.config.set("workers", config["workers"]) + ctx.obj.config.set("sq_path", config["db"].replace("jsonld", "sqlite3")) + ctx.obj.config.set("domain", config["host"]) ctx.obj.config.save() # fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7 with get_database(ctx.obj.config) as db: with db.session(True) as conn: - conn.put_config('private-key', database['private-key']) - conn.put_config('note', config['note']) - conn.put_config('whitelist-enabled', config['whitelist_enabled']) + conn.put_config("private-key", database["private-key"]) + conn.put_config("note", config["note"]) + conn.put_config("whitelist-enabled", config["whitelist_enabled"]) with click.progressbar( - database['relay-list'].values(), - label = 'Inboxes'.ljust(15), + database["relay-list"].values(), + label = "Inboxes".ljust(15), width = 0 ) as inboxes: for inbox in inboxes: - if inbox['software'] in {'akkoma', 'pleroma'}: - actor = f'https://{inbox["domain"]}/relay' + if inbox["software"] in {"akkoma", "pleroma"}: + actor = f"https://{inbox["domain"]}/relay" - elif inbox['software'] == 'mastodon': - actor = f'https://{inbox["domain"]}/actor' + elif inbox["software"] == "mastodon": + actor = f"https://{inbox["domain"]}/actor" else: actor = None conn.put_inbox( - inbox['domain'], - inbox['inbox'], + inbox["domain"], + inbox["inbox"], actor = actor, - followid = inbox['followid'], - software = inbox['software'] + followid = inbox["followid"], + software = inbox["software"] ) with click.progressbar( - config['blocked_software'], - label = 'Banned software'.ljust(15), + config["blocked_software"], + label = "Banned software".ljust(15), width = 0 ) as banned_software: for software in banned_software: conn.put_software_ban( software, - reason = 'relay' if software in RELAY_SOFTWARE else None + reason = "relay" if software in RELAY_SOFTWARE else None ) with click.progressbar( - config['blocked_instances'], - label = 'Banned domains'.ljust(15), + config["blocked_instances"], + label = "Banned domains".ljust(15), width = 0 ) as banned_software: @@ -284,22 +285,22 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: conn.put_domain_ban(domain) with click.progressbar( - config['whitelist'], - label = 'Whitelist'.ljust(15), + config["whitelist"], + label = "Whitelist".ljust(15), width = 0 ) as whitelist: for instance in whitelist: conn.put_domain_whitelist(instance) - click.echo('Finished converting old config and database :3') + click.echo("Finished converting old config and database :3") -@cli.command('edit-config') -@click.option('--editor', '-e', help = 'Text editor to use') +@cli.command("edit-config") +@click.option("--editor", "-e", help = "Text editor to use") @click.pass_context def cli_editconfig(ctx: click.Context, editor: str) -> None: - 'Edit the config file' + "Edit the config file" click.edit( editor = editor, @@ -307,7 +308,7 @@ def cli_editconfig(ctx: click.Context, editor: str) -> None: ) -@cli.command('switch-backend') +@cli.command("switch-backend") @click.pass_context def cli_switchbackend(ctx: click.Context) -> None: """ @@ -349,17 +350,17 @@ def cli_switchbackend(ctx: click.Context) -> None: click.echo("Done!") -@cli.group('config') +@cli.group("config") def cli_config() -> None: - 'Manage the relay settings stored in the database' + "Manage the relay settings stored in the database" -@cli_config.command('list') +@cli_config.command("list") @click.pass_context def cli_config_list(ctx: click.Context) -> None: - 'List the current relay config' + "List the current relay config" - click.echo('Relay Config:') + click.echo("Relay Config:") with ctx.obj.database.session() as conn: config = conn.get_config_all() @@ -368,176 +369,176 @@ def cli_config_list(ctx: click.Context) -> None: if key in type(config).SYSTEM_KEYS(): continue - if key == 'log-level': + if key == "log-level": value = value.name - key_str = f'{key}:'.ljust(20) - click.echo(f'- {key_str} {repr(value)}') + key_str = f"{key}:".ljust(20) + click.echo(f"- {key_str} {repr(value)}") -@cli_config.command('set') -@click.argument('key') -@click.argument('value') +@cli_config.command("set") +@click.argument("key") +@click.argument("value") @click.pass_context def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: - 'Set a config value' + "Set a config value" try: with ctx.obj.database.session() as conn: new_value = conn.put_config(key, value) except Exception: - click.echo(f'Invalid config name: {key}') + click.echo(f"Invalid config name: {key}") return - click.echo(f'{key}: {repr(new_value)}') + click.echo(f"{key}: {repr(new_value)}") -@cli.group('user') +@cli.group("user") def cli_user() -> None: - 'Manage local users' + "Manage local users" -@cli_user.command('list') +@cli_user.command("list") @click.pass_context def cli_user_list(ctx: click.Context) -> None: - 'List all local users' + "List all local users" - click.echo('Users:') + click.echo("Users:") with ctx.obj.database.session() as conn: for row in conn.get_users(): - click.echo(f'- {row.username}') + click.echo(f"- {row.username}") -@cli_user.command('create') -@click.argument('username') -@click.argument('handle', required = False) +@cli_user.command("create") +@click.argument("username") +@click.argument("handle", required = False) @click.pass_context def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: - 'Create a new local user' + "Create a new local user" with ctx.obj.database.session() as conn: if conn.get_user(username) is not None: - click.echo(f'User already exists: {username}') + click.echo(f"User already exists: {username}") return while True: - if not (password := click.prompt('New password', hide_input = True)): - click.echo('No password provided') + if not (password := click.prompt("New password", hide_input = True)): + click.echo("No password provided") continue - if password != click.prompt('New password again', hide_input = True): - click.echo('Passwords do not match') + if password != click.prompt("New password again", hide_input = True): + click.echo("Passwords do not match") continue break conn.put_user(username, password, handle) - click.echo(f'Created user "{username}"') + click.echo(f"Created user {username}") -@cli_user.command('delete') -@click.argument('username') +@cli_user.command("delete") +@click.argument("username") @click.pass_context def cli_user_delete(ctx: click.Context, username: str) -> None: - 'Delete a local user' + "Delete a local user" with ctx.obj.database.session() as conn: if conn.get_user(username) is None: - click.echo(f'User does not exist: {username}') + click.echo(f"User does not exist: {username}") return conn.del_user(username) - click.echo(f'Deleted user "{username}"') + click.echo(f"Deleted user {username}") -@cli_user.command('list-tokens') -@click.argument('username') +@cli_user.command("list-tokens") +@click.argument("username") @click.pass_context def cli_user_list_tokens(ctx: click.Context, username: str) -> None: - 'List all API tokens for a user' + "List all API tokens for a user" - click.echo(f'Tokens for "{username}":') + click.echo(f"Tokens for {username}:") with ctx.obj.database.session() as conn: for row in conn.get_tokens(username): - click.echo(f'- {row.code}') + click.echo(f"- {row.code}") -@cli_user.command('create-token') -@click.argument('username') +@cli_user.command("create-token") +@click.argument("username") @click.pass_context def cli_user_create_token(ctx: click.Context, username: str) -> None: - 'Create a new API token for a user' + "Create a new API token for a user" with ctx.obj.database.session() as conn: if (user := conn.get_user(username)) is None: - click.echo(f'User does not exist: {username}') + click.echo(f"User does not exist: {username}") return token = conn.put_token(user.username) - click.echo(f'New token for "{username}": {token.code}') + click.echo(f"New token for {username}: {token.code}") -@cli_user.command('delete-token') -@click.argument('code') +@cli_user.command("delete-token") +@click.argument("code") @click.pass_context def cli_user_delete_token(ctx: click.Context, code: str) -> None: - 'Delete an API token' + "Delete an API token" with ctx.obj.database.session() as conn: if conn.get_token(code) is None: - click.echo('Token does not exist') + click.echo("Token does not exist") return conn.del_token(code) - click.echo('Deleted token') + click.echo("Deleted token") -@cli.group('inbox') +@cli.group("inbox") def cli_inbox() -> None: - 'Manage the inboxes in the database' + "Manage the inboxes in the database" -@cli_inbox.command('list') +@cli_inbox.command("list") @click.pass_context def cli_inbox_list(ctx: click.Context) -> None: - 'List the connected instances or relays' + "List the connected instances or relays" - click.echo('Connected to the following instances or relays:') + click.echo("Connected to the following instances or relays:") with ctx.obj.database.session() as conn: for row in conn.get_inboxes(): - click.echo(f'- {row.inbox}') + click.echo(f"- {row.inbox}") -@cli_inbox.command('follow') -@click.argument('actor') +@cli_inbox.command("follow") +@click.argument("actor") @click.pass_context def cli_inbox_follow(ctx: click.Context, actor: str) -> None: - 'Follow an actor (Relay must be running)' + "Follow an actor (Relay must be running)" instance: schema.Instance | None = None with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): - click.echo(f'Error: Refusing to follow banned actor: {actor}') + click.echo(f"Error: Refusing to follow banned actor: {actor}") return if (instance := conn.get_inbox(actor)) is not None: inbox = instance.inbox else: - if not actor.startswith('http'): - actor = f'https://{actor}/actor' + if not actor.startswith("http"): + actor = f"https://{actor}/actor" if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None: - click.echo(f'Failed to fetch actor: {actor}') + click.echo(f"Failed to fetch actor: {actor}") return inbox = actor_data.shared_inbox @@ -548,20 +549,20 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None: ) asyncio.run(http.post(inbox, message, instance)) - click.echo(f'Sent follow message to actor: {actor}') + click.echo(f"Sent follow message to actor: {actor}") -@cli_inbox.command('unfollow') -@click.argument('actor') +@cli_inbox.command("unfollow") +@click.argument("actor") @click.pass_context def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: - 'Unfollow an actor (Relay must be running)' + "Unfollow an actor (Relay must be running)" instance: schema.Instance | None = None with ctx.obj.database.session() as conn: if conn.get_domain_ban(actor): - click.echo(f'Error: Refusing to follow banned actor: {actor}') + click.echo(f"Error: Refusing to follow banned actor: {actor}") return if (instance := conn.get_inbox(actor)): @@ -573,8 +574,8 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: ) else: - if not actor.startswith('http'): - actor = f'https://{actor}/actor' + if not actor.startswith("http"): + actor = f"https://{actor}/actor" actor_data = asyncio.run(http.get(actor, sign_headers = True)) @@ -587,23 +588,23 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: host = ctx.obj.config.domain, actor = actor, follow = { - 'type': 'Follow', - 'object': actor, - 'actor': f'https://{ctx.obj.config.domain}/actor' + "type": "Follow", + "object": actor, + "actor": f"https://{ctx.obj.config.domain}/actor" } ) asyncio.run(http.post(inbox, message, instance)) - click.echo(f'Sent unfollow message to: {actor}') + click.echo(f"Sent unfollow message to: {actor}") -@cli_inbox.command('add') -@click.argument('inbox') -@click.option('--actor', '-a', help = 'Actor url for the inbox') -@click.option('--followid', '-f', help = 'Url for the follow activity') -@click.option('--software', '-s', +@cli_inbox.command("add") +@click.argument("inbox") +@click.option("--actor", "-a", help = "Actor url for the inbox") +@click.option("--followid", "-f", help = "Url for the follow activity") +@click.option("--software", "-s", type = click.Choice(SOFTWARE), - help = 'Nodeinfo software name of the instance' + help = "Nodeinfo software name of the instance" ) # noqa: E124 @click.pass_context def cli_inbox_add( @@ -612,11 +613,11 @@ def cli_inbox_add( actor: str | None = None, followid: str | None = None, software: str | None = None) -> None: - 'Add an inbox to the database' + "Add an inbox to the database" - if not inbox.startswith('http'): + if not inbox.startswith("http"): domain = inbox - inbox = f'https://{inbox}/inbox' + inbox = f"https://{inbox}/inbox" else: domain = urlparse(inbox).netloc @@ -634,62 +635,62 @@ def cli_inbox_add( with ctx.obj.database.session() as conn: if conn.get_domain_ban(domain): - click.echo(f'Refusing to add banned inbox: {inbox}') + click.echo(f"Refusing to add banned inbox: {inbox}") return if conn.get_inbox(inbox): - click.echo(f'Error: Inbox already in database: {inbox}') + click.echo(f"Error: Inbox already in database: {inbox}") return conn.put_inbox(domain, inbox, actor, followid, software) - click.echo(f'Added inbox to the database: {inbox}') + click.echo(f"Added inbox to the database: {inbox}") -@cli_inbox.command('remove') -@click.argument('inbox') +@cli_inbox.command("remove") +@click.argument("inbox") @click.pass_context def cli_inbox_remove(ctx: click.Context, inbox: str) -> None: - 'Remove an inbox from the database' + "Remove an inbox from the database" with ctx.obj.database.session() as conn: if not conn.del_inbox(inbox): - click.echo(f'Inbox not in database: {inbox}') + click.echo(f"Inbox not in database: {inbox}") return - click.echo(f'Removed inbox from the database: {inbox}') + click.echo(f"Removed inbox from the database: {inbox}") -@cli.group('request') +@cli.group("request") def cli_request() -> None: - 'Manage follow requests' + "Manage follow requests" -@cli_request.command('list') +@cli_request.command("list") @click.pass_context def cli_request_list(ctx: click.Context) -> None: - 'List all current follow requests' + "List all current follow requests" - click.echo('Follow requests:') + click.echo("Follow requests:") with ctx.obj.database.session() as conn: for row in conn.get_requests(): - date = row.created.strftime('%Y-%m-%d') - click.echo(f'- [{date}] {row.domain}') + date = row.created.strftime("%Y-%m-%d") + click.echo(f"- [{date}] {row.domain}") -@cli_request.command('accept') -@click.argument('domain') +@cli_request.command("accept") +@click.argument("domain") @click.pass_context def cli_request_accept(ctx: click.Context, domain: str) -> None: - 'Accept a follow request' + "Accept a follow request" try: with ctx.obj.database.session() as conn: instance = conn.put_request_response(domain, True) except KeyError: - click.echo('Request not found') + click.echo("Request not found") return message = Message.new_response( @@ -701,7 +702,7 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None: asyncio.run(http.post(instance.inbox, message, instance)) - if instance.software != 'mastodon': + if instance.software != "mastodon": message = Message.new_follow( host = ctx.obj.config.domain, actor = instance.actor @@ -710,18 +711,18 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None: asyncio.run(http.post(instance.inbox, message, instance)) -@cli_request.command('deny') -@click.argument('domain') +@cli_request.command("deny") +@click.argument("domain") @click.pass_context def cli_request_deny(ctx: click.Context, domain: str) -> None: - 'Accept a follow request' + "Accept a follow request" try: with ctx.obj.database.session() as conn: instance = conn.put_request_response(domain, False) except KeyError: - click.echo('Request not found') + click.echo("Request not found") return response = Message.new_response( @@ -734,113 +735,113 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None: asyncio.run(http.post(instance.inbox, response, instance)) -@cli.group('instance') +@cli.group("instance") def cli_instance() -> None: - 'Manage instance bans' + "Manage instance bans" -@cli_instance.command('list') +@cli_instance.command("list") @click.pass_context def cli_instance_list(ctx: click.Context) -> None: - 'List all banned instances' + "List all banned instances" - click.echo('Banned domains:') + click.echo("Banned domains:") with ctx.obj.database.session() as conn: for row in conn.get_domain_bans(): if row.reason is not None: - click.echo(f'- {row.domain} ({row.reason})') + click.echo(f"- {row.domain} ({row.reason})") else: - click.echo(f'- {row.domain}') + click.echo(f"- {row.domain}") -@cli_instance.command('ban') -@click.argument('domain') -@click.option('--reason', '-r', help = 'Public note about why the domain is banned') -@click.option('--note', '-n', help = 'Internal note that will only be seen by admins and mods') +@cli_instance.command("ban") +@click.argument("domain") +@click.option("--reason", "-r", help = "Public note about why the domain is banned") +@click.option("--note", "-n", help = "Internal note that will only be seen by admins and mods") @click.pass_context def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None: - 'Ban an instance and remove the associated inbox if it exists' + "Ban an instance and remove the associated inbox if it exists" with ctx.obj.database.session() as conn: if conn.get_domain_ban(domain) is not None: - click.echo(f'Domain already banned: {domain}') + click.echo(f"Domain already banned: {domain}") return conn.put_domain_ban(domain, reason, note) conn.del_inbox(domain) - click.echo(f'Banned instance: {domain}') + click.echo(f"Banned instance: {domain}") -@cli_instance.command('unban') -@click.argument('domain') +@cli_instance.command("unban") +@click.argument("domain") @click.pass_context def cli_instance_unban(ctx: click.Context, domain: str) -> None: - 'Unban an instance' + "Unban an instance" with ctx.obj.database.session() as conn: if conn.del_domain_ban(domain) is None: - click.echo(f'Instance wasn\'t banned: {domain}') + click.echo(f"Instance wasn\"t banned: {domain}") return - click.echo(f'Unbanned instance: {domain}') + click.echo(f"Unbanned instance: {domain}") -@cli_instance.command('update') -@click.argument('domain') -@click.option('--reason', '-r') -@click.option('--note', '-n') +@cli_instance.command("update") +@click.argument("domain") +@click.option("--reason", "-r") +@click.option("--note", "-n") @click.pass_context def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) -> None: - 'Update the public reason or internal note for a domain ban' + "Update the public reason or internal note for a domain ban" if not (reason or note): - ctx.fail('Must pass --reason or --note') + ctx.fail("Must pass --reason or --note") with ctx.obj.database.session() as conn: if not (row := conn.update_domain_ban(domain, reason, note)): - click.echo(f'Failed to update domain ban: {domain}') + click.echo(f"Failed to update domain ban: {domain}") return - click.echo(f'Updated domain ban: {domain}') + click.echo(f"Updated domain ban: {domain}") if row.reason: - click.echo(f'- {row.domain} ({row.reason})') + click.echo(f"- {row.domain} ({row.reason})") else: - click.echo(f'- {row.domain}') + click.echo(f"- {row.domain}") -@cli.group('software') +@cli.group("software") def cli_software() -> None: - 'Manage banned software' + "Manage banned software" -@cli_software.command('list') +@cli_software.command("list") @click.pass_context def cli_software_list(ctx: click.Context) -> None: - 'List all banned software' + "List all banned software" - click.echo('Banned software:') + click.echo("Banned software:") with ctx.obj.database.session() as conn: for row in conn.get_software_bans(): if row.reason: - click.echo(f'- {row.name} ({row.reason})') + click.echo(f"- {row.name} ({row.reason})") else: - click.echo(f'- {row.name}') + click.echo(f"- {row.name}") -@cli_software.command('ban') -@click.argument('name') -@click.option('--reason', '-r') -@click.option('--note', '-n') +@cli_software.command("ban") +@click.argument("name") +@click.option("--reason", "-r") +@click.option("--note", "-n") @click.option( - '--fetch-nodeinfo', '-f', + "--fetch-nodeinfo", "-f", is_flag = True, - help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo' + help = "Treat NAME like a domain and try to fetch the software name from nodeinfo" ) @click.pass_context def cli_software_ban(ctx: click.Context, @@ -848,189 +849,189 @@ def cli_software_ban(ctx: click.Context, reason: str, note: str, fetch_nodeinfo: bool) -> None: - 'Ban software. Use RELAYS for NAME to ban relays' + "Ban software. Use RELAYS for NAME to ban relays" with ctx.obj.database.session() as conn: - if name == 'RELAYS': + if name == "RELAYS": for item in RELAY_SOFTWARE: if conn.get_software_ban(item): - click.echo(f'Relay already banned: {item}') + click.echo(f"Relay already banned: {item}") continue - conn.put_software_ban(item, reason or 'relay', note) + conn.put_software_ban(item, reason or "relay", note) - click.echo('Banned all relay software') + click.echo("Banned all relay software") return if fetch_nodeinfo: if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))): - click.echo(f'Failed to fetch software name from domain: {name}') + click.echo(f"Failed to fetch software name from domain: {name}") return name = nodeinfo.sw_name if conn.get_software_ban(name): - click.echo(f'Software already banned: {name}') + click.echo(f"Software already banned: {name}") return if not conn.put_software_ban(name, reason, note): - click.echo(f'Failed to ban software: {name}') + click.echo(f"Failed to ban software: {name}") return - click.echo(f'Banned software: {name}') + click.echo(f"Banned software: {name}") -@cli_software.command('unban') -@click.argument('name') -@click.option('--reason', '-r') -@click.option('--note', '-n') +@cli_software.command("unban") +@click.argument("name") +@click.option("--reason", "-r") +@click.option("--note", "-n") @click.option( - '--fetch-nodeinfo', '-f', + "--fetch-nodeinfo", "-f", is_flag = True, - help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo' + help = "Treat NAME like a domain and try to fetch the software name from nodeinfo" ) @click.pass_context def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None: - 'Ban software. Use RELAYS for NAME to unban relays' + "Ban software. Use RELAYS for NAME to unban relays" with ctx.obj.database.session() as conn: - if name == 'RELAYS': + if name == "RELAYS": for software in RELAY_SOFTWARE: if not conn.del_software_ban(software): - click.echo(f'Relay was not banned: {software}') + click.echo(f"Relay was not banned: {software}") - click.echo('Unbanned all relay software') + click.echo("Unbanned all relay software") return if fetch_nodeinfo: if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))): - click.echo(f'Failed to fetch software name from domain: {name}') + click.echo(f"Failed to fetch software name from domain: {name}") return name = nodeinfo.sw_name if not conn.del_software_ban(name): - click.echo(f'Software was not banned: {name}') + click.echo(f"Software was not banned: {name}") return - click.echo(f'Unbanned software: {name}') + click.echo(f"Unbanned software: {name}") -@cli_software.command('update') -@click.argument('name') -@click.option('--reason', '-r') -@click.option('--note', '-n') +@cli_software.command("update") +@click.argument("name") +@click.option("--reason", "-r") +@click.option("--note", "-n") @click.pass_context def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -> None: - 'Update the public reason or internal note for a software ban' + "Update the public reason or internal note for a software ban" if not (reason or note): - ctx.fail('Must pass --reason or --note') + ctx.fail("Must pass --reason or --note") with ctx.obj.database.session() as conn: if not (row := conn.update_software_ban(name, reason, note)): - click.echo(f'Failed to update software ban: {name}') + click.echo(f"Failed to update software ban: {name}") return - click.echo(f'Updated software ban: {name}') + click.echo(f"Updated software ban: {name}") if row.reason: - click.echo(f'- {row.name} ({row.reason})') + click.echo(f"- {row.name} ({row.reason})") else: - click.echo(f'- {row.name}') + click.echo(f"- {row.name}") -@cli.group('whitelist') +@cli.group("whitelist") def cli_whitelist() -> None: - 'Manage the instance whitelist' + "Manage the instance whitelist" -@cli_whitelist.command('list') +@cli_whitelist.command("list") @click.pass_context def cli_whitelist_list(ctx: click.Context) -> None: - 'List all the instances in the whitelist' + "List all the instances in the whitelist" - click.echo('Current whitelisted domains:') + click.echo("Current whitelisted domains:") with ctx.obj.database.session() as conn: for row in conn.get_domain_whitelist(): - click.echo(f'- {row.domain}') + click.echo(f"- {row.domain}") -@cli_whitelist.command('add') -@click.argument('domain') +@cli_whitelist.command("add") +@click.argument("domain") @click.pass_context def cli_whitelist_add(ctx: click.Context, domain: str) -> None: - 'Add a domain to the whitelist' + "Add a domain to the whitelist" with ctx.obj.database.session() as conn: if conn.get_domain_whitelist(domain): - click.echo(f'Instance already in the whitelist: {domain}') + click.echo(f"Instance already in the whitelist: {domain}") return conn.put_domain_whitelist(domain) - click.echo(f'Instance added to the whitelist: {domain}') + click.echo(f"Instance added to the whitelist: {domain}") -@cli_whitelist.command('remove') -@click.argument('domain') +@cli_whitelist.command("remove") +@click.argument("domain") @click.pass_context def cli_whitelist_remove(ctx: click.Context, domain: str) -> None: - 'Remove an instance from the whitelist' + "Remove an instance from the whitelist" with ctx.obj.database.session() as conn: if not conn.del_domain_whitelist(domain): - click.echo(f'Domain not in the whitelist: {domain}') + click.echo(f"Domain not in the whitelist: {domain}") return - if conn.get_config('whitelist-enabled'): + if conn.get_config("whitelist-enabled"): if conn.del_inbox(domain): - click.echo(f'Removed inbox for domain: {domain}') + click.echo(f"Removed inbox for domain: {domain}") - click.echo(f'Removed domain from the whitelist: {domain}') + click.echo(f"Removed domain from the whitelist: {domain}") -@cli_whitelist.command('import') +@cli_whitelist.command("import") @click.pass_context def cli_whitelist_import(ctx: click.Context) -> None: - 'Add all current instances to the whitelist' + "Add all current instances to the whitelist" with ctx.obj.database.session() as conn: for row in conn.get_inboxes(): if conn.get_domain_whitelist(row.domain) is not None: - click.echo(f'Domain already in whitelist: {row.domain}') + click.echo(f"Domain already in whitelist: {row.domain}") continue conn.put_domain_whitelist(row.domain) - click.echo('Imported whitelist from inboxes') + click.echo("Imported whitelist from inboxes") def config_postgresql(config: Config) -> None: config.pg_name = click.prompt( - 'What is the name of the database?', + "What is the name of the database?", default = config.pg_name ) config.pg_host = click.prompt( - 'What IP address, hostname, or unix socket does the server listen on?', + "What IP address, hostname, or unix socket does the server listen on?", default = config.pg_host, ) config.pg_port = click.prompt( - 'What port does the server listen on?', + "What port does the server listen on?", default = config.pg_port, type = int ) config.pg_user = click.prompt( - 'Which user will authenticate with the server?', + "Which user will authenticate with the server?", default = config.pg_user ) config.pg_pass = click.prompt( - 'User password', + "User password", hide_input = True, show_default = False, default = config.pg_pass or "" @@ -1038,4 +1039,5 @@ def config_postgresql(config: Config) -> None: def main() -> None: - cli(prog_name='activityrelay') + multiprocessing.freeze_support() + cli(prog_name="activityrelay") diff --git a/relay/misc.py b/relay/misc.py index 4d250e5..fbab355 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -16,63 +16,63 @@ if TYPE_CHECKING: from .application import Application -T = TypeVar('T') -ResponseType = TypedDict('ResponseType', { - 'status': int, - 'headers': dict[str, Any] | None, - 'content_type': str, - 'body': bytes | None, - 'text': str | None +T = TypeVar("T") +ResponseType = TypedDict("ResponseType", { + "status": int, + "headers": dict[str, Any] | None, + "content_type": str, + "body": bytes | None, + "text": str | None }) -IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING')) -IS_WINDOWS = platform.system() == 'Windows' +IS_DOCKER = bool(os.environ.get("DOCKER_RUNNING")) +IS_WINDOWS = platform.system() == "Windows" MIMETYPES = { - 'activity': 'application/activity+json', - 'css': 'text/css', - 'html': 'text/html', - 'json': 'application/json', - 'text': 'text/plain', - 'webmanifest': 'application/manifest+json' + "activity": "application/activity+json", + "css": "text/css", + "html": "text/html", + "json": "application/json", + "text": "text/plain", + "webmanifest": "application/manifest+json" } ACTOR_FORMATS = { - 'mastodon': 'https://{domain}/actor', - 'akkoma': 'https://{domain}/relay', - 'pleroma': 'https://{domain}/relay' + "mastodon": "https://{domain}/actor", + "akkoma": "https://{domain}/relay", + "pleroma": "https://{domain}/relay" } SOFTWARE = ( - 'mastodon', - 'akkoma', - 'pleroma', - 'misskey', - 'friendica', - 'hubzilla', - 'firefish', - 'gotosocial' + "mastodon", + "akkoma", + "pleroma", + "misskey", + "friendica", + "hubzilla", + "firefish", + "gotosocial" ) JSON_PATHS: tuple[str, ...] = ( - '/api/v1', - '/actor', - '/inbox', - '/outbox', - '/following', - '/followers', - '/.well-known', - '/nodeinfo', - '/oauth/token', - '/oauth/revoke' + "/api/v1", + "/actor", + "/inbox", + "/outbox", + "/following", + "/followers", + "/.well-known", + "/nodeinfo", + "/oauth/token", + "/oauth/revoke" ) TOKEN_PATHS: tuple[str, ...] = ( - '/logout', - '/admin', - '/api', - '/oauth/authorize', - '/oauth/revoke' + "/logout", + "/admin", + "/api", + "/oauth/authorize", + "/oauth/revoke" ) @@ -80,7 +80,7 @@ def get_app() -> Application: from .application import Application if not Application.DEFAULT: - raise ValueError('No default application set') + raise ValueError("No default application set") return Application.DEFAULT @@ -136,23 +136,23 @@ class Message(aputils.Message): approves: bool = False) -> Self: return cls.new(aputils.ObjectType.APPLICATION, { - 'id': f'https://{host}/actor', - 'preferredUsername': 'relay', - 'name': 'ActivityRelay', - 'summary': description or 'ActivityRelay bot', - 'manuallyApprovesFollowers': approves, - 'followers': f'https://{host}/followers', - 'following': f'https://{host}/following', - 'inbox': f'https://{host}/inbox', - 'outbox': f'https://{host}/outbox', - 'url': f'https://{host}/', - 'endpoints': { - 'sharedInbox': f'https://{host}/inbox' + "id": f"https://{host}/actor", + "preferredUsername": "relay", + "name": "ActivityRelay", + "summary": description or "ActivityRelay bot", + "manuallyApprovesFollowers": approves, + "followers": f"https://{host}/followers", + "following": f"https://{host}/following", + "inbox": f"https://{host}/inbox", + "outbox": f"https://{host}/outbox", + "url": f"https://{host}/", + "endpoints": { + "sharedInbox": f"https://{host}/inbox" }, - 'publicKey': { - 'id': f'https://{host}/actor#main-key', - 'owner': f'https://{host}/actor', - 'publicKeyPem': pubkey + "publicKey": { + "id": f"https://{host}/actor#main-key", + "owner": f"https://{host}/actor", + "publicKeyPem": pubkey } }) @@ -160,44 +160,44 @@ class Message(aputils.Message): @classmethod def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self: return cls.new(aputils.ObjectType.ANNOUNCE, { - 'id': f'https://{host}/activities/{uuid4()}', - 'to': [f'https://{host}/followers'], - 'actor': f'https://{host}/actor', - 'object': obj + "id": f"https://{host}/activities/{uuid4()}", + "to": [f"https://{host}/followers"], + "actor": f"https://{host}/actor", + "object": obj }) @classmethod def new_follow(cls: type[Self], host: str, actor: str) -> Self: return cls.new(aputils.ObjectType.FOLLOW, { - 'id': f'https://{host}/activities/{uuid4()}', - 'to': [actor], - 'object': actor, - 'actor': f'https://{host}/actor' + "id": f"https://{host}/activities/{uuid4()}", + "to": [actor], + "object": actor, + "actor": f"https://{host}/actor" }) @classmethod def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self: return cls.new(aputils.ObjectType.UNDO, { - 'id': f'https://{host}/activities/{uuid4()}', - 'to': [actor], - 'actor': f'https://{host}/actor', - 'object': follow + "id": f"https://{host}/activities/{uuid4()}", + "to": [actor], + "actor": f"https://{host}/actor", + "object": follow }) @classmethod def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self: return cls.new(aputils.ObjectType.ACCEPT if accept else aputils.ObjectType.REJECT, { - 'id': f'https://{host}/activities/{uuid4()}', - 'to': [actor], - 'actor': f'https://{host}/actor', - 'object': { - 'id': followid, - 'type': 'Follow', - 'object': f'https://{host}/actor', - 'actor': actor + "id": f"https://{host}/activities/{uuid4()}", + "to": [actor], + "actor": f"https://{host}/actor", + "object": { + "id": followid, + "type": "Follow", + "object": f"https://{host}/actor", + "actor": actor } }) @@ -210,35 +210,35 @@ class Response(AiohttpResponse): @classmethod def new(cls: type[Self], - body: str | bytes | dict[str, Any] | Sequence[Any] = '', + body: str | bytes | dict[str, Any] | Sequence[Any] = "", status: int = 200, headers: dict[str, str] | None = None, - ctype: str = 'text') -> Self: + ctype: str = "text") -> Self: kwargs: ResponseType = { - 'status': status, - 'headers': headers, - 'content_type': MIMETYPES[ctype], - 'body': None, - 'text': None + "status": status, + "headers": headers, + "content_type": MIMETYPES[ctype], + "body": None, + "text": None } if isinstance(body, str): - kwargs['text'] = body + kwargs["text"] = body elif isinstance(body, bytes): - kwargs['body'] = body + kwargs["body"] = body elif isinstance(body, (dict, Sequence)): - kwargs['text'] = json.dumps(body, cls = JsonEncoder) + kwargs["text"] = json.dumps(body, cls = JsonEncoder) return cls(**kwargs) @classmethod def new_redir(cls: type[Self], path: str, status: int = 307) -> Self: - body = f'Redirect to {path}' - return cls.new(body, status, {'Location': path}, ctype = 'html') + body = f"Redirect to {path}" + return cls.new(body, status, {"Location": path}, ctype = "html") @classmethod @@ -256,9 +256,9 @@ class Response(AiohttpResponse): @property def location(self) -> str: - return self.headers.get('Location', '') + return self.headers.get("Location", "") @location.setter def location(self, value: str) -> None: - self.headers['Location'] = value + self.headers["Location"] = value diff --git a/relay/processors.py b/relay/processors.py index f56ba48..5615bee 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -12,11 +12,11 @@ if typing.TYPE_CHECKING: def actor_type_check(actor: Message, software: str | None) -> bool: - if actor.type == 'Application': + if actor.type == "Application": return True # akkoma (< 3.6.0) and pleroma use Person for the actor type - if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': + if software in {"akkoma", "pleroma"} and actor.id == f"https://{actor.domain}/relay": return True return False @@ -24,38 +24,38 @@ def actor_type_check(actor: Message, software: str | None) -> bool: async def handle_relay(app: Application, data: InboxData, conn: Connection) -> None: try: - app.cache.get('handle-relay', data.message.object_id) - logging.verbose('already relayed %s', data.message.object_id) + app.cache.get("handle-relay", data.message.object_id) + logging.verbose("already relayed %s", data.message.object_id) return except KeyError: pass message = Message.new_announce(app.config.domain, data.message.object_id) - logging.debug('>> relay: %s', message) + logging.debug(">> relay: %s", message) for instance in conn.distill_inboxes(data.message): app.push_message(instance.inbox, message, instance) - app.cache.set('handle-relay', data.message.object_id, message.id, 'str') + app.cache.set("handle-relay", data.message.object_id, message.id, "str") async def handle_forward(app: Application, data: InboxData, conn: Connection) -> None: try: - app.cache.get('handle-relay', data.message.id) - logging.verbose('already forwarded %s', data.message.id) + app.cache.get("handle-relay", data.message.id) + logging.verbose("already forwarded %s", data.message.id) return except KeyError: pass message = Message.new_announce(app.config.domain, data.message) - logging.debug('>> forward: %s', message) + logging.debug(">> forward: %s", message) for instance in conn.distill_inboxes(data.message): app.push_message(instance.inbox, data.message, instance) - app.cache.set('handle-relay', data.message.id, message.id, 'str') + app.cache.set("handle-relay", data.message.id, message.id, "str") async def handle_follow(app: Application, data: InboxData, conn: Connection) -> None: @@ -65,7 +65,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) -> # reject if software used by actor is banned if software and conn.get_software_ban(software): - logging.verbose('Rejected banned actor: %s', data.actor.id) + logging.verbose("Rejected banned actor: %s", data.actor.id) app.push_message( data.actor.shared_inbox, @@ -79,7 +79,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) -> ) logging.verbose( - 'Rejected follow from actor for using specific software: actor=%s, software=%s', + "Rejected follow from actor for using specific software: actor=%s, software=%s", data.actor.id, software ) @@ -88,7 +88,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) -> # reject if the actor is not an instance actor if actor_type_check(data.actor, software): - logging.verbose('Non-application actor tried to follow: %s', data.actor.id) + logging.verbose("Non-application actor tried to follow: %s", data.actor.id) app.push_message( data.actor.shared_inbox, @@ -106,7 +106,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) -> if not conn.get_domain_whitelist(data.actor.domain): # add request if approval-required is enabled if config.approval_required: - logging.verbose('New follow request fromm actor: %s', data.actor.id) + logging.verbose("New follow request fromm actor: %s", data.actor.id) with conn.transaction(): data.instance = conn.put_inbox( @@ -120,9 +120,9 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) -> return - # reject if the actor isn't whitelisted while the whiltelist is enabled + # reject if the actor isn"t whitelisted while the whiltelist is enabled if config.whitelist_enabled: - logging.verbose('Rejected actor for not being in the whitelist: %s', data.actor.id) + logging.verbose("Rejected actor for not being in the whitelist: %s", data.actor.id) app.push_message( data.actor.shared_inbox, @@ -160,7 +160,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) -> # Are Akkoma and Pleroma the only two that expect a follow back? # Ignoring only Mastodon for now - if software != 'mastodon': + if software != "mastodon": app.push_message( data.actor.shared_inbox, Message.new_follow( @@ -172,8 +172,8 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) -> async def handle_undo(app: Application, data: InboxData, conn: Connection) -> None: - if data.message.object['type'] != 'Follow': - # forwarding deletes does not work, so don't bother + if data.message.object["type"] != "Follow": + # forwarding deletes does not work, so don"t bother # await handle_forward(app, data, conn) return @@ -187,7 +187,7 @@ async def handle_undo(app: Application, data: InboxData, conn: Connection) -> No with conn.transaction(): if not conn.del_inbox(data.actor.id): logging.verbose( - 'Failed to delete "%s" with follow ID "%s"', + "Failed to delete \"%s\" with follow ID \"%s\"", data.actor.id, data.message.object_id ) @@ -204,19 +204,19 @@ async def handle_undo(app: Application, data: InboxData, conn: Connection) -> No processors = { - 'Announce': handle_relay, - 'Create': handle_relay, - 'Delete': handle_forward, - 'Follow': handle_follow, - 'Undo': handle_undo, - 'Update': handle_forward, + "Announce": handle_relay, + "Create": handle_relay, + "Delete": handle_forward, + "Follow": handle_follow, + "Undo": handle_undo, + "Update": handle_forward, } async def run_processor(data: InboxData) -> None: if data.message.type not in processors: logging.verbose( - 'Message type "%s" from actor cannot be handled: %s', + "Message type \"%s\" from actor cannot be handled: %s", data.message.type, data.actor.id ) @@ -242,5 +242,5 @@ async def run_processor(data: InboxData) -> None: actor = data.actor.id ) - logging.verbose('New "%s" from actor: %s', data.message.type, data.actor.id) + logging.verbose("New \"%s\" from actor: %s", data.message.type, data.actor.id) await processors[data.message.type](app, data, conn) diff --git a/relay/template.py b/relay/template.py index 472a691..d761ad9 100644 --- a/relay/template.py +++ b/relay/template.py @@ -34,8 +34,8 @@ class Template(Environment): MarkdownExtension ], loader = FileSystemLoader([ - File.from_resource('relay', 'frontend'), - app.config.path.parent.joinpath('template') + File.from_resource("relay", "frontend"), + app.config.path.parent.joinpath("template") ]) ) @@ -47,10 +47,10 @@ class Template(Environment): config = conn.get_config_all() new_context = { - 'request': request, - 'domain': self.app.config.domain, - 'version': __version__, - 'config': config, + "request": request, + "domain": self.app.config.domain, + "version": __version__, + "config": config, **(context or {}) } @@ -58,11 +58,11 @@ class Template(Environment): class MarkdownExtension(Extension): - tags = {'markdown'} + tags = {"markdown"} extensions = ( - 'attr_list', - 'smarty', - 'tables' + "attr_list", + "smarty", + "tables" ) @@ -77,14 +77,14 @@ class MarkdownExtension(Extension): def parse(self, parser: Parser) -> Node | list[Node]: lineno = next(parser.stream).lineno body = parser.parse_statements( - ('name:endmarkdown',), + ("name:endmarkdown",), drop_needle = True ) - output = CallBlock(self.call_method('_render_markdown'), [], [], body) + output = CallBlock(self.call_method("_render_markdown"), [], [], body) return output.set_lineno(lineno) def _render_markdown(self, caller: Callable[[], str] | str) -> str: text = caller if isinstance(caller, str) else caller() - return self._markdown.convert(textwrap.dedent(text.strip('\n'))) + return self._markdown.convert(textwrap.dedent(text.strip("\n"))) diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index 7f70901..f8f98e3 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -44,61 +44,61 @@ class InboxData: signer: Signer | None = None try: - signature = Signature.parse(request.headers['signature']) + signature = Signature.parse(request.headers["signature"]) except KeyError: - logging.verbose('Missing signature header') - raise HttpError(400, 'missing signature header') + logging.verbose("Missing signature header") + raise HttpError(400, "missing signature header") try: message = await request.json(loads = Message.parse) except Exception: traceback.print_exc() - logging.verbose('Failed to parse message from actor: %s', signature.keyid) - raise HttpError(400, 'failed to parse message') + logging.verbose("Failed to parse message from actor: %s", signature.keyid) + raise HttpError(400, "failed to parse message") if message is None: - logging.verbose('empty message') - raise HttpError(400, 'missing message') + logging.verbose("empty message") + raise HttpError(400, "missing message") - if 'actor' not in message: - logging.verbose('actor not in message') - raise HttpError(400, 'no actor in message') + if "actor" not in message: + logging.verbose("actor not in message") + raise HttpError(400, "no actor in message") try: actor = await app.client.get(signature.keyid, True, Message) except HttpError as e: - # ld signatures aren't handled atm, so just ignore it - if message.type == 'Delete': - logging.verbose('Instance sent a delete which cannot be handled') - raise HttpError(202, '') + # ld signatures aren"t handled atm, so just ignore it + if message.type == "Delete": + logging.verbose("Instance sent a delete which cannot be handled") + raise HttpError(202, "") - logging.verbose('Failed to fetch actor: %s', signature.keyid) - logging.debug('HTTP Status %i: %s', e.status, e.message) - raise HttpError(400, 'failed to fetch actor') + logging.verbose("Failed to fetch actor: %s", signature.keyid) + logging.debug("HTTP Status %i: %s", e.status, e.message) + raise HttpError(400, "failed to fetch actor") except ClientConnectorError as e: - logging.warning('Error when trying to fetch actor: %s, %s', signature.keyid, str(e)) - raise HttpError(400, 'failed to fetch actor') + logging.warning("Error when trying to fetch actor: %s, %s", signature.keyid, str(e)) + raise HttpError(400, "failed to fetch actor") except Exception: traceback.print_exc() - raise HttpError(500, 'unexpected error when fetching actor') + raise HttpError(500, "unexpected error when fetching actor") try: signer = actor.signer except KeyError: - logging.verbose('Actor missing public key: %s', signature.keyid) - raise HttpError(400, 'actor missing public key') + logging.verbose("Actor missing public key: %s", signature.keyid) + raise HttpError(400, "actor missing public key") try: await signer.validate_request_async(request) except SignatureFailureError as e: - logging.verbose('signature validation failed for "%s": %s', actor.id, e) + logging.verbose("signature validation failed for \"%s\": %s", actor.id, e) raise HttpError(401, str(e)) return cls(signature, message, actor, signer, None) @@ -128,30 +128,30 @@ async def handle_inbox(app: Application, request: Request) -> Response: # reject if actor is banned if conn.get_domain_ban(data.actor.domain): - logging.verbose('Ignored request from banned actor: %s', data.actor.id) - raise HttpError(403, 'access denied') + logging.verbose("Ignored request from banned actor: %s", data.actor.id) + raise HttpError(403, "access denied") - # reject if activity type isn't 'Follow' and the actor isn't following - if data.message.type != 'Follow' and not data.instance: + # reject if activity type isn"t "Follow" and the actor isn"t following + if data.message.type != "Follow" and not data.instance: logging.verbose( - 'Rejected actor for trying to post while not following: %s', + "Rejected actor for trying to post while not following: %s", data.actor.id ) - raise HttpError(401, 'access denied') + raise HttpError(401, "access denied") - logging.debug('>> payload %s', data.message.to_json(4)) + logging.debug(">> payload %s", data.message.to_json(4)) await run_processor(data) return Response.new(status = 202) -@register_route(HttpMethod.GET, '/outbox') +@register_route(HttpMethod.GET, "/outbox") async def handle_outbox(app: Application, request: Request) -> Response: msg = aputils.Message.new( aputils.ObjectType.ORDERED_COLLECTION, { - "id": f'https://{app.config.domain}/outbox', + "id": f"https://{app.config.domain}/outbox", "totalItems": 0, "orderedItems": [] } @@ -160,15 +160,15 @@ async def handle_outbox(app: Application, request: Request) -> Response: return Response.new(msg, ctype = "activity") -@register_route(HttpMethod.GET, '/following', '/followers') +@register_route(HttpMethod.GET, "/following", "/followers") async def handle_follow(app: Application, request: Request) -> Response: with app.database.session(False) as s: - inboxes = [row['actor'] for row in s.get_inboxes()] + inboxes = [row["actor"] for row in s.get_inboxes()] msg = aputils.Message.new( aputils.ObjectType.COLLECTION, { - "id": f'https://{app.config.domain}{request.path}', + "id": f"https://{app.config.domain}{request.path}", "totalItems": len(inboxes), "items": inboxes } @@ -177,21 +177,21 @@ async def handle_follow(app: Application, request: Request) -> Response: return Response.new(msg, ctype = "activity") -@register_route(HttpMethod.GET, '/.well-known/webfinger') +@register_route(HttpMethod.GET, "/.well-known/webfinger") async def get(app: Application, request: Request) -> Response: try: - subject = request.query['resource'] + subject = request.query["resource"] except KeyError: - raise HttpError(400, 'missing "resource" query key') + raise HttpError(400, "missing \"resource\" query key") - if subject != f'acct:relay@{app.config.domain}': - raise HttpError(404, 'user not found') + if subject != f"acct:relay@{app.config.domain}": + raise HttpError(404, "user not found") data = aputils.Webfinger.new( - handle = 'relay', + handle = "relay", domain = app.config.domain, actor = app.config.actor ) - return Response.new(data, ctype = 'json') + return Response.new(data, ctype = "json") diff --git a/relay/views/api.py b/relay/views/api.py index 95e3b73..089fcdf 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -343,7 +343,7 @@ async def handle_instance_add( with app.database.session(False) as s: if s.get_inbox(domain) is not None: - raise HttpError(404, 'Instance already in database') + raise HttpError(404, "Instance already in database") if inbox is None: try: @@ -396,7 +396,7 @@ async def handle_instance_update( with app.database.session(False) as s: if (instance := s.get_inbox(domain)) is None: - raise HttpError(404, 'Instance with domain not found') + raise HttpError(404, "Instance with domain not found") row = s.put_inbox( instance.domain, diff --git a/relay/views/base.py b/relay/views/base.py index 6dba873..6fe3469 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -31,11 +31,11 @@ if TYPE_CHECKING: METHODS: dict[str, Method] = {} ROUTES: list[tuple[str, str, HandlerCallback]] = [] -DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' +DEFAULT_REDIRECT: str = "urn:ietf:wg:oauth:2.0:oob" ALLOWED_HEADERS: set[str] = { - 'accept', - 'authorization', - 'content-type' + "accept", + "authorization", + "content-type" } @@ -100,14 +100,14 @@ class Method: return_type = get_origin(return_type) if not issubclass(return_type, (Response, ApiObject, list)): - raise ValueError(f"Invalid return type '{return_type.__name__}' for {func.__name__}") + raise ValueError(f"Invalid return type \"{return_type.__name__}\" for {func.__name__}") args = {key: value for key, value in inspect.signature(func).parameters.items()} docstring, paramdocs = parse_docstring(func.__doc__ or "") params = [] if func.__doc__ is None: - logging.warning(f"Missing docstring for '{func.__name__}'") + logging.warning(f"Missing docstring for \"{func.__name__}\"") for key, value in args.items(): types: list[type[Any]] = [] @@ -134,7 +134,7 @@ class Method: )) if not paramdocs.get(key): - logging.warning(f"Missing docs for '{key}' parameter in '{func.__name__}'") + logging.warning(f"Missing docs for \"{key}\" parameter in \"{func.__name__}\"") rtype = annotations.get("return") or type(None) return cls(func.__name__, category, docstring, method, path, rtype, tuple(params)) @@ -222,7 +222,7 @@ class Route: if request.method != "OPTIONS" and self.require_token: if (auth := request.headers.getone("Authorization", None)) is None: - raise HttpError(401, 'Missing token') + raise HttpError(401, "Missing token") try: authtype, code = auth.split(" ", 1) @@ -245,15 +245,15 @@ class Route: request["application"] = application - if request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}: + if request.content_type in {"application/x-www-form-urlencoded", "multipart/form-data"}: post_data = {key: value for key, value in (await request.post()).items()} - elif request.content_type == 'application/json': + elif request.content_type == "application/json": try: post_data = await request.json() except JSONDecodeError: - raise HttpError(400, 'Invalid JSON data') + raise HttpError(400, "Invalid JSON data") else: post_data = {key: str(value) for key, value in request.query.items()} @@ -262,7 +262,7 @@ class Route: response = await self.handler(get_app(), request, **post_data) except HttpError as error: - return Response.new({'error': error.message}, error.status, ctype = "json") + return Response.new({"error": error.message}, error.status, ctype = "json") headers = { "Access-Control-Allow-Origin": "*", diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 87c6424..0935974 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: async def handle_home(app: Application, request: Request) -> Response: with app.database.session() as conn: context: dict[str, Any] = { - 'instances': tuple(conn.get_inboxes()) + "instances": tuple(conn.get_inboxes()) } return Response.new_template(200, "page/home.haml", request, context) @@ -34,28 +34,28 @@ async def handle_api_doc(app: Application, request: Request) -> Response: return Response.new_template(200, "page/docs.haml", request, context) -@register_route(HttpMethod.GET, '/login') +@register_route(HttpMethod.GET, "/login") async def handle_login(app: Application, request: Request) -> Response: context = {"redir": unquote(request.query.get("redir", "/"))} return Response.new_template(200, "page/login.haml", request, context) -@register_route(HttpMethod.GET, '/logout') +@register_route(HttpMethod.GET, "/logout") async def handle_logout(app: Application, request: Request) -> Response: with app.database.session(True) as conn: - conn.del_app(request['token'].client_id, request['token'].client_secret) + conn.del_app(request["token"].client_id, request["token"].client_secret) - resp = Response.new_redir('/') - resp.del_cookie('user-token', domain = app.config.domain, path = '/') + resp = Response.new_redir("/") + resp.del_cookie("user-token", domain = app.config.domain, path = "/") return resp -@register_route(HttpMethod.GET, '/admin') +@register_route(HttpMethod.GET, "/admin") async def handle_admin(app: Application, request: Request) -> Response: - return Response.new_redir(f'/login?redir={request.path}', 301) + return Response.new_redir(f"/login?redir={request.path}", 301) -@register_route(HttpMethod.GET, '/admin/instances') +@register_route(HttpMethod.GET, "/admin/instances") async def handle_admin_instances( app: Application, request: Request, @@ -64,20 +64,20 @@ async def handle_admin_instances( with app.database.session() as conn: context: dict[str, Any] = { - 'instances': tuple(conn.get_inboxes()), - 'requests': tuple(conn.get_requests()) + "instances": tuple(conn.get_inboxes()), + "requests": tuple(conn.get_requests()) } if error: - context['error'] = error + context["error"] = error if message: - context['message'] = message + context["message"] = message return Response.new_template(200, "page/admin-instances.haml", request, context) -@register_route(HttpMethod.GET, '/admin/whitelist') +@register_route(HttpMethod.GET, "/admin/whitelist") async def handle_admin_whitelist( app: Application, request: Request, @@ -86,19 +86,19 @@ async def handle_admin_whitelist( with app.database.session() as conn: context: dict[str, Any] = { - 'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC')) + "whitelist": tuple(conn.execute("SELECT * FROM whitelist ORDER BY domain ASC")) } if error: - context['error'] = error + context["error"] = error if message: - context['message'] = message + context["message"] = message return Response.new_template(200, "page/admin-whitelist.haml", request, context) -@register_route(HttpMethod.GET, '/admin/domain_bans') +@register_route(HttpMethod.GET, "/admin/domain_bans") async def handle_admin_instance_bans( app: Application, request: Request, @@ -107,19 +107,19 @@ async def handle_admin_instance_bans( with app.database.session() as conn: context: dict[str, Any] = { - 'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC')) + "bans": tuple(conn.execute("SELECT * FROM domain_bans ORDER BY domain ASC")) } if error: - context['error'] = error + context["error"] = error if message: - context['message'] = message + context["message"] = message return Response.new_template(200, "page/admin-domain_bans.haml", request, context) -@register_route(HttpMethod.GET, '/admin/software_bans') +@register_route(HttpMethod.GET, "/admin/software_bans") async def handle_admin_software_bans( app: Application, request: Request, @@ -128,19 +128,19 @@ async def handle_admin_software_bans( with app.database.session() as conn: context: dict[str, Any] = { - 'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC')) + "bans": tuple(conn.execute("SELECT * FROM software_bans ORDER BY name ASC")) } if error: - context['error'] = error + context["error"] = error if message: - context['message'] = message + context["message"] = message return Response.new_template(200, "page/admin-software_bans.haml", request, context) -@register_route(HttpMethod.GET, '/admin/users') +@register_route(HttpMethod.GET, "/admin/users") async def handle_admin_users( app: Application, request: Request, @@ -149,29 +149,29 @@ async def handle_admin_users( with app.database.session() as conn: context: dict[str, Any] = { - 'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC')) + "users": tuple(conn.execute("SELECT * FROM users ORDER BY username ASC")) } if error: - context['error'] = error + context["error"] = error if message: - context['message'] = message + context["message"] = message return Response.new_template(200, "page/admin-users.haml", request, context) -@register_route(HttpMethod.GET, '/admin/config') +@register_route(HttpMethod.GET, "/admin/config") async def handle_admin_config( app: Application, request: Request, message: str | None = None) -> Response: context: dict[str, Any] = { - 'themes': tuple(THEMES.keys()), - 'levels': tuple(level.name for level in LogLevel), - 'message': message, - 'desc': { + "themes": tuple(THEMES.keys()), + "levels": tuple(level.name for level in LogLevel), + "message": message, + "desc": { "name": "Name of the relay to be displayed in the header of the pages and in " + "the actor endpoint.", # noqa: E131 "note": "Description of the relay to be displayed on the front page and as the " + @@ -187,36 +187,36 @@ async def handle_admin_config( return Response.new_template(200, "page/admin-config.haml", request, context) -@register_route(HttpMethod.GET, '/manifest.json') +@register_route(HttpMethod.GET, "/manifest.json") async def handle_manifest(app: Application, request: Request) -> Response: with app.database.session(False) as conn: config = conn.get_config_all() theme = THEMES[config.theme] data = { - 'background_color': theme['background'], - 'categories': ['activitypub'], - 'description': 'Message relay for the ActivityPub network', - 'display': 'standalone', - 'name': config['name'], - 'orientation': 'portrait', - 'scope': f"https://{app.config.domain}/", - 'short_name': 'ActivityRelay', - 'start_url': f"https://{app.config.domain}/", - 'theme_color': theme['primary'] + "background_color": theme["background"], + "categories": ["activitypub"], + "description": "Message relay for the ActivityPub network", + "display": "standalone", + "name": config["name"], + "orientation": "portrait", + "scope": f"https://{app.config.domain}/", + "short_name": "ActivityRelay", + "start_url": f"https://{app.config.domain}/", + "theme_color": theme["primary"] } - return Response.new(data, ctype = 'webmanifest') + return Response.new(data, ctype = "webmanifest") -@register_route(HttpMethod.GET, '/theme/{theme}.css') # type: ignore[arg-type] +@register_route(HttpMethod.GET, "/theme/{theme}.css") # type: ignore[arg-type] async def handle_theme(app: Application, request: Request, theme: str) -> Response: try: context: dict[str, Any] = { - 'theme': THEMES[theme] + "theme": THEMES[theme] } except KeyError: - return Response.new('Invalid theme', 404) + return Response.new("Invalid theme", 404) return Response.new_template(200, "variables.css", request, context, ctype = "css") diff --git a/relay/views/misc.py b/relay/views/misc.py index 015c274..c9bdc08 100644 --- a/relay/views/misc.py +++ b/relay/views/misc.py @@ -21,16 +21,18 @@ VERSION = __version__ if File(__file__).join("../../../.git").resolve().exists: try: - commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii') - VERSION = f'{__version__} {commit_label}' + commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("ascii") + VERSION = f"{__version__} {commit_label}" + + del commit_label except Exception: pass NODEINFO_PATHS = [ - '/nodeinfo/{niversion:\\d.\\d}.json', - '/nodeinfo/{niversion:\\d.\\d}' + "/nodeinfo/{niversion:\\d.\\d}.json", + "/nodeinfo/{niversion:\\d.\\d}" ] @@ -40,23 +42,23 @@ async def handle_nodeinfo(app: Application, request: Request, niversion: str) -> inboxes = conn.get_inboxes() nodeinfo = aputils.Nodeinfo.new( - name = 'activityrelay', + name = "activityrelay", version = VERSION, - protocols = ['activitypub'], - open_regs = not conn.get_config('whitelist-enabled'), + protocols = ["activitypub"], + open_regs = not conn.get_config("whitelist-enabled"), users = 1, - repo = 'https://git.pleroma.social/pleroma/relay' if niversion == '2.1' else None, + repo = "https://codeberg.org/barkshark/activityrelay" if niversion == "2.1" else None, metadata = { - 'approval_required': conn.get_config('approval-required'), - 'peers': [inbox['domain'] for inbox in inboxes] + "approval_required": conn.get_config("approval-required"), + "peers": [inbox["domain"] for inbox in inboxes] } ) - return Response.new(nodeinfo, ctype = 'json') + return Response.new(nodeinfo, ctype = "json") -@register_route(HttpMethod.GET, '/.well-known/nodeinfo') +@register_route(HttpMethod.GET, "/.well-known/nodeinfo") async def handle_wk_nodeinfo(app: Application, request: Request) -> Response: data = aputils.WellKnownNodeinfo.new_template(app.config.domain) - return Response.new(data, ctype = 'json') + return Response.new(data, ctype = "json") diff --git a/relay/workers.py b/relay/workers.py index 3a0a022..dd3c7e6 100644 --- a/relay/workers.py +++ b/relay/workers.py @@ -96,16 +96,16 @@ class PushWorker(Process): await self.client.post(item.inbox, item.message, item.instance) except HttpError as e: - logging.error('HTTP Error when pushing to %s: %i %s', item.inbox, e.status, e.message) + logging.error("HTTP Error when pushing to %s: %i %s", item.inbox, e.status, e.message) except AsyncTimeoutError: - logging.error('Timeout when pushing to %s', item.domain) + logging.error("Timeout when pushing to %s", item.domain) except ClientConnectionError as e: - logging.error('Failed to connect to %s for message push: %s', item.domain, str(e)) + logging.error("Failed to connect to %s for message push: %s", item.domain, str(e)) except ClientSSLError as e: - logging.error('SSL error when pushing to %s: %s', item.domain, str(e)) + logging.error("SSL error when pushing to %s: %s", item.domain, str(e)) class PushWorkers(list[PushWorker]):