Compare commits

..

14 commits

Author SHA1 Message Date
Izalia Mae
be21b30d61 version bump to 0.3.4 2025-02-11 13:20:21 -05:00
Izalia Mae
e4eadb69db fix linter warnings in dev script 2025-02-11 13:01:48 -05:00
Izalia Mae
f741017ef1 skip missing optional values in old config 2025-02-11 12:55:02 -05:00
Izalia Mae
760718b5dc fix actor type checking (fixes #45) 2025-02-11 12:39:18 -05:00
Izalia Mae
527afaca95 properly get actor and inbox urls
* get actor url from incoming message
* fall back to actor inbox if there is no shared inbox
2025-02-11 12:28:05 -05:00
Izalia Mae
ff275a5ba4 fix linter warnings 2025-02-11 12:26:57 -05:00
Izalia Mae
3e6a2a4f37 fix logins 2025-02-11 12:25:16 -05:00
Izalia Mae
5a1af750d5 fix spacing issue on home page 2025-02-11 12:15:21 -05:00
Izalia Mae
ade36d6e69 add buzzrelay to RELAY_SOFTWARE list 2025-01-29 06:11:25 -05:00
Izalia Mae
e47e69b14d add docs dependencies 2024-12-05 22:56:17 -05:00
Izalia Mae
f41b24406f remove unnecessary SOFTWARE const 2024-11-28 09:17:11 -05:00
Izalia Mae
338fd26688 replace single quotes with double quotes 2024-11-28 07:09:53 -05:00
Izalia Mae
29ebba7999 add more classiifiers and py.typed file 2024-11-28 06:10:35 -05:00
Izalia Mae
5131831363 sort out pyproject and replace jinja2-haml with hamlish 2024-11-28 05:43:39 -05:00
27 changed files with 1215 additions and 1189 deletions

111
dev.py
View file

@ -37,32 +37,32 @@ from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
REPO = Path(__file__).parent
IGNORE_EXT = {
'.py',
'.pyc'
".py",
".pyc"
}
@click.group('cli')
@click.group("cli")
def cli() -> None:
'Useful commands for development'
"Useful commands for development"
@cli.command('install')
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
@cli.command("install")
@click.option("--no-dev", "-d", is_flag = True, help = "Do not install development dependencies")
def cli_install(no_dev: bool) -> None:
with open('pyproject.toml', 'r', encoding = 'utf-8') as fd:
with open("pyproject.toml", "r", encoding = "utf-8") as fd:
data = tomllib.loads(fd.read())
deps = data['project']['dependencies']
deps.extend(data['project']['optional-dependencies']['dev'])
deps = data["project"]["dependencies"]
deps.extend(data["project"]["optional-dependencies"]["dev"])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
subprocess.run([sys.executable, "-m", "pip", "install", "-U", *deps], check = False)
@cli.command('lint')
@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay'))
@click.option('--watch', '-w', is_flag = True,
help = 'Automatically, re-run the linters on source change')
@cli.command("lint")
@click.argument("path", required = False, type = Path, default = REPO.joinpath("relay"))
@click.option("--watch", "-w", is_flag = True,
help = "Automatically, re-run the linters on source change")
def cli_lint(path: Path, watch: bool) -> None:
path = path.expanduser().resolve()
@ -70,74 +70,73 @@ def cli_lint(path: Path, watch: bool) -> None:
handle_run_watcher([sys.executable, "dev.py", "lint", str(path)], wait = True)
return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)]
flake8 = [sys.executable, "-m", "flake8", "dev.py", str(path)]
mypy = [sys.executable, "-m", "mypy", "--python-version", "3.12", "dev.py", str(path)]
click.echo('----- flake8 -----')
click.echo("----- flake8 -----")
subprocess.run(flake8)
click.echo('\n\n----- mypy -----')
click.echo("\n\n----- mypy -----")
subprocess.run(mypy)
@cli.command('clean')
@cli.command("clean")
def cli_clean() -> None:
dirs = {
'dist',
'build',
'dist-pypi'
"dist",
"build",
"dist-pypi"
}
for directory in dirs:
shutil.rmtree(directory, ignore_errors = True)
for path in REPO.glob('*.egg-info'):
for path in REPO.glob("*.egg-info"):
shutil.rmtree(path)
for path in REPO.glob('*.spec'):
for path in REPO.glob("*.spec"):
path.unlink()
@cli.command('build')
@cli.command("build")
def cli_build() -> None:
from relay import __version__
with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
arch = "amd64" if sys.maxsize >= 2**32 else "i386"
cmd = [
sys.executable, '-m', 'PyInstaller',
'--collect-data', 'relay',
'--collect-data', 'aiohttp_swagger',
'--hidden-import', 'pg8000',
'--hidden-import', 'sqlite3',
'--name', f'activityrelay-{__version__}-{platform.system().lower()}-{arch}',
'--workpath', tmp,
'--onefile', 'relay/__main__.py',
sys.executable, "-m", "PyInstaller",
"--collect-data", "relay",
"--hidden-import", "pg8000",
"--hidden-import", "sqlite3",
"--name", f"activityrelay-{__version__}-{platform.system().lower()}-{arch}",
"--workpath", tmp,
"--onefile", "relay/__main__.py",
]
if platform.system() == 'Windows':
cmd.append('--console')
if platform.system() == "Windows":
cmd.append("--console")
# putting the spec path on a different drive than the source dir breaks
if str(REPO)[0] == tmp[0]:
cmd.extend(['--specpath', tmp])
cmd.extend(["--specpath", tmp])
else:
cmd.append('--strip')
cmd.extend(['--specpath', tmp])
cmd.append("--strip")
cmd.extend(["--specpath", tmp])
subprocess.run(cmd, check = False)
@cli.command('run')
@click.option('--dev', '-d', is_flag = True)
@cli.command("run")
@click.option("--dev", "-d", is_flag = True)
def cli_run(dev: bool) -> None:
print('Starting process watcher')
print("Starting process watcher")
cmd = [sys.executable, '-m', 'relay', 'run']
cmd = [sys.executable, "-m", "relay", "run"]
if dev:
cmd.append('-d')
cmd.append("-d")
handle_run_watcher(cmd, watch_path = REPO.joinpath("relay"))
@ -151,8 +150,8 @@ def handle_run_watcher(
handler.run_procs()
watcher = Observer()
watcher.schedule(handler, str(watch_path), recursive=True) # type: ignore
watcher.start() # type: ignore
watcher.schedule(handler, str(watch_path), recursive=True)
watcher.start()
try:
while True:
@ -162,16 +161,16 @@ def handle_run_watcher(
pass
handler.kill_procs()
watcher.stop() # type: ignore
watcher.stop()
watcher.join()
class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py']
patterns = ["*.py"]
def __init__(self, *commands: Sequence[str], wait: bool = False) -> None:
PatternMatchingEventHandler.__init__(self) # type: ignore
PatternMatchingEventHandler.__init__(self)
self.commands: Sequence[Sequence[str]] = commands
self.wait: bool = wait
@ -184,7 +183,7 @@ class WatchHandler(PatternMatchingEventHandler):
if proc.poll() is not None:
continue
print(f'Terminating process {proc.pid}')
print(f"Terminating process {proc.pid}")
proc.terminate()
sec = 0.0
@ -193,11 +192,11 @@ class WatchHandler(PatternMatchingEventHandler):
sec += 0.1
if sec >= 5:
print('Failed to terminate. Killing process...')
print("Failed to terminate. Killing process...")
proc.kill()
break
print('Process terminated')
print("Process terminated")
def run_procs(self, restart: bool = False) -> None:
@ -213,21 +212,21 @@ class WatchHandler(PatternMatchingEventHandler):
self.procs = []
for cmd in self.commands:
print('Running command:', ' '.join(cmd))
print("Running command:", " ".join(cmd))
subprocess.run(cmd)
else:
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
pids = (str(proc.pid) for proc in self.procs)
print('Started processes with PIDs:', ', '.join(pids))
print("Started processes with PIDs:", ", ".join(pids))
def on_any_event(self, event: FileSystemEvent) -> None:
if event.event_type not in ['modified', 'created', 'deleted']:
if event.event_type not in ["modified", "created", "deleted"]:
return
self.run_procs(restart = True)
if __name__ == '__main__':
if __name__ == "__main__":
cli()

View file

@ -1,37 +1,50 @@
[build-system]
requires = ["setuptools>=61.2"]
requires = [
"setuptools>=61.2",
]
build-backend = "setuptools.build_meta"
[project]
name = "ActivityRelay"
description = "Generic LitePub relay (works with all LitePub consumers and Mastodon)"
license = {text = "AGPLv3"}
classifiers = [
"Environment :: Console",
"License :: OSI Approved :: GNU Affero General Public License v3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
"Development Status :: 4 - Beta",
"Environment :: Console",
"Framework :: aiohttp",
"Framework :: AsyncIO",
"License :: OSI Approved :: GNU Affero General Public License v3",
"Programming Language :: Python",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: SQL",
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
"Typing :: Typed",
]
dependencies = [
"activitypub-utils >= 0.3.2, < 0.4",
"aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0",
"barkshark-lib >= 0.2.3, < 0.3.0",
"barkshark-sql >= 0.2.0, < 0.3.0",
"click == 8.1.2",
"docstring-parser == 0.16",
"hamlish == 0.4.0",
"hiredis == 2.3.2",
"idna == 3.4",
"jinja2-haml == 0.3.5",
"markdown == 3.6",
"platformdirs == 4.2.2",
"pyyaml == 6.0.1",
"redis == 5.0.7"
"redis == 5.0.7",
]
requires-python = ">=3.10"
dynamic = ["version"]
dynamic = [
"version",
]
[project.license]
file = "LICENSE"
[project.readme]
file = "README.md"
@ -40,42 +53,46 @@ content-type = "text/markdown; charset=UTF-8"
[project.urls]
Documentation = "https://git.pleroma.social/pleroma/relay/-/blob/main/docs/index.md"
Source = "https://git.pleroma.social/pleroma/relay"
Tracker = "https://git.pleroma.social/pleroma/relay/-/issues"
[project.scripts]
activityrelay = "relay.manage:main"
[project.optional-dependencies]
dev = [
"flake8 == 7.1.0",
"mypy == 1.11.1",
"build == 1.2.2.post1",
"flake8 == 7.1.1",
"mypy == 1.13.0",
"pyinstaller == 6.10.0",
"watchdog == 4.0.2"
]
docs = [
"furo == 2024.1.29",
"sphinx == 7.2.6",
"sphinx-external-toc == 1.0.1",
]
[tool.setuptools]
zip-safe = false
packages = [
"relay",
"relay.database",
"relay.views",
"relay",
"relay.database",
"relay.views",
]
include-package-data = true
license-files = ["LICENSE"]
license-files = [
"LICENSE",
]
[tool.setuptools.package-data]
relay = [
"data/*",
"frontend/*",
"frontend/page/*",
"frontend/static/*"
"py.typed",
"data/*",
"frontend/*",
"frontend/page/*",
"frontend/static/*",
]
[tool.setuptools.dynamic]
version = {attr = "relay.__version__"}
[tool.setuptools.dynamic.optional-dependencies]
dev = {file = ["dev-requirements.txt"]}
[tool.setuptools.dynamic.version]
attr = "relay.__version__"
[tool.mypy]
show_traceback = true
@ -89,15 +106,3 @@ ignore_missing_imports = true
implicit_reexport = true
strict = true
follow_imports = "silent"
[[tool.mypy.overrides]]
module = "relay.database"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "aputils"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "blib"
implicit_reexport = true

View file

@ -1 +1 @@
__version__ = '0.3.3'
__version__ = "0.3.4"

View file

@ -1,8 +1,5 @@
import multiprocessing
from relay.manage import main
if __name__ == '__main__':
multiprocessing.freeze_support()
if __name__ == "__main__":
main()

View file

@ -41,10 +41,10 @@ def get_csp(request: web.Request) -> str:
"img-src 'self'",
"object-src 'none'",
"frame-ancestors 'none'",
f"manifest-src 'self' https://{request.app['config'].domain}"
f"manifest-src 'self' https://{request.app["config"].domain}"
]
return '; '.join(data) + ';'
return "; ".join(data) + ";"
class Application(web.Application):
@ -61,20 +61,20 @@ class Application(web.Application):
Application.DEFAULT = self
self['running'] = False
self['signer'] = None
self['start_time'] = None
self['cleanup_thread'] = None
self['dev'] = dev
self["running"] = False
self["signer"] = None
self["start_time"] = None
self["cleanup_thread"] = None
self["dev"] = dev
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['cache'] = get_cache(self)
self['cache'].setup()
self['template'] = Template(self)
self['push_queue'] = multiprocessing.Queue()
self['workers'] = PushWorkers(self.config.workers)
self["config"] = Config(cfgpath, load = True)
self["database"] = get_database(self.config)
self["client"] = HttpClient()
self["cache"] = get_cache(self)
self["cache"].setup()
self["template"] = Template(self)
self["push_queue"] = multiprocessing.Queue()
self["workers"] = PushWorkers(self.config.workers)
self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore
@ -82,69 +82,69 @@ class Application(web.Application):
@property
def cache(self) -> Cache:
return cast(Cache, self['cache'])
return cast(Cache, self["cache"])
@property
def client(self) -> HttpClient:
return cast(HttpClient, self['client'])
return cast(HttpClient, self["client"])
@property
def config(self) -> Config:
return cast(Config, self['config'])
return cast(Config, self["config"])
@property
def database(self) -> Database[Connection]:
return cast(Database[Connection], self['database'])
return cast(Database[Connection], self["database"])
@property
def signer(self) -> Signer:
return cast(Signer, self['signer'])
return cast(Signer, self["signer"])
@signer.setter
def signer(self, value: Signer | str) -> None:
if isinstance(value, Signer):
self['signer'] = value
self["signer"] = value
return
self['signer'] = Signer(value, self.config.keyid)
self["signer"] = Signer(value, self.config.keyid)
@property
def template(self) -> Template:
return cast(Template, self['template'])
return cast(Template, self["template"])
@property
def uptime(self) -> timedelta:
if not self['start_time']:
if not self["start_time"]:
return timedelta(seconds=0)
uptime = datetime.now() - self['start_time']
uptime = datetime.now() - self["start_time"]
return timedelta(seconds=uptime.seconds)
@property
def workers(self) -> PushWorkers:
return cast(PushWorkers, self['workers'])
return cast(PushWorkers, self["workers"])
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self['workers'].push_message(inbox, message, instance)
self["workers"].push_message(inbox, message, instance)
def register_static_routes(self) -> None:
if self['dev']:
static = StaticResource('/static', File.from_resource('relay', 'frontend/static'))
if self["dev"]:
static = StaticResource("/static", File.from_resource("relay", "frontend/static"))
else:
static = CachedStaticResource(
'/static', Path(File.from_resource('relay', 'frontend/static'))
"/static", Path(File.from_resource("relay", "frontend/static"))
)
self.router.register_resource(static)
@ -158,18 +158,18 @@ class Application(web.Application):
host = self.config.listen
port = self.config.port
if port_check(port, '127.0.0.1' if host == '0.0.0.0' else host):
logging.error(f'A server is already running on {host}:{port}')
if port_check(port, "127.0.0.1" if host == "0.0.0.0" else host):
logging.error(f"A server is already running on {host}:{port}")
return
self.register_static_routes()
logging.info(f'Starting webserver at {domain} ({host}:{port})')
logging.info(f"Starting webserver at {domain} ({host}:{port})")
asyncio.run(self.handle_run())
def set_signal_handler(self, startup: bool) -> None:
for sig in ('SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'):
for sig in ("SIGHUP", "SIGINT", "SIGQUIT", "SIGTERM"):
try:
signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL)
@ -179,22 +179,25 @@ class Application(web.Application):
def stop(self, *_: Any) -> None:
self['running'] = False
self["running"] = False
async def handle_run(self) -> None:
self['running'] = True
self["running"] = True
self.set_signal_handler(True)
self['client'].open()
self['database'].connect()
self['cache'].setup()
self['cleanup_thread'] = CacheCleanupThread(self)
self['cleanup_thread'].start()
self['workers'].start()
self["client"].open()
self["database"].connect()
self["cache"].setup()
self["cleanup_thread"] = CacheCleanupThread(self)
self["cleanup_thread"].start()
self["workers"].start()
runner = web.AppRunner(
self, access_log_format = "%{X-Forwarded-For}i \"%r\" %s %b \"%{User-Agent}i\""
)
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup()
site = web.TCPSite(
@ -205,22 +208,22 @@ class Application(web.Application):
)
await site.start()
self['starttime'] = datetime.now()
self["starttime"] = datetime.now()
while self['running']:
while self["running"]:
await asyncio.sleep(0.25)
await site.stop()
self['workers'].stop()
self["workers"].stop()
self.set_signal_handler(False)
self['starttime'] = None
self['running'] = False
self['cleanup_thread'].stop()
self['database'].disconnect()
self['cache'].close()
self["starttime"] = None
self["running"] = False
self["cleanup_thread"].stop()
self["database"].disconnect()
self["cache"].close()
class CachedStaticResource(StaticResource):
@ -229,19 +232,19 @@ class CachedStaticResource(StaticResource):
self.cache: dict[str, bytes] = {}
for filename in path.rglob('*'):
for filename in path.rglob("*"):
if filename.is_dir():
continue
rel_path = str(filename.relative_to(path))
with filename.open('rb') as fd:
logging.debug('Loading static resource "%s"', rel_path)
with filename.open("rb") as fd:
logging.debug("Loading static resource \"%s\"", rel_path)
self.cache[rel_path] = fd.read()
async def _handle(self, request: web.Request) -> web.StreamResponse:
rel_url = request.match_info['filename']
rel_url = request.match_info["filename"]
if Path(rel_url).anchor:
raise web.HTTPForbidden()
@ -281,8 +284,8 @@ class CacheCleanupThread(Thread):
def format_error(request: web.Request, error: HttpError) -> Response:
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.message}, error.status, ctype = 'json')
if request.path.startswith(JSON_PATHS) or "json" in request.headers.get("accept", ""):
return Response.new({"error": error.message}, error.status, ctype = "json")
else:
context = {"e": error}
@ -294,27 +297,27 @@ async def handle_response_headers(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
request['token'] = None
request['user'] = None
request["hash"] = b64encode(get_random_bytes(16)).decode("ascii")
request["token"] = None
request["user"] = None
app: Application = request.app # type: ignore[assignment]
if request.path in {"/", "/docs"} or request.path.startswith(TOKEN_PATHS):
with app.database.session() as conn:
tokens = (
request.headers.get('Authorization', '').replace('Bearer', '').strip(),
request.cookies.get('user-token')
request.headers.get("Authorization", "").replace("Bearer", "").strip(),
request.cookies.get("user-token")
)
for token in tokens:
if not token:
continue
request['token'] = conn.get_app_by_token(token)
request["token"] = conn.get_app_by_token(token)
if request['token'] is not None:
request['user'] = conn.get_user(request['token'].user)
if request["token"] is not None:
request["user"] = conn.get_user(request["token"].user)
break
@ -338,21 +341,21 @@ async def handle_response_headers(
raise
except Exception:
resp = format_error(request, HttpError(500, 'Internal server error'))
resp = format_error(request, HttpError(500, "Internal server error"))
traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay'
resp.headers["Server"] = "ActivityRelay"
# Still have to figure out how csp headers work
if resp.content_type == 'text/html' and not request.path.startswith("/api"):
resp.headers['Content-Security-Policy'] = get_csp(request)
if resp.content_type == "text/html" and not request.path.startswith("/api"):
resp.headers["Content-Security-Policy"] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')):
if not request.app["dev"] and request.path.endswith((".css", ".js", ".woff2")):
# cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'
resp.headers["Cache-Control"] = "public,max-age=1209600,immutable"
else:
resp.headers['Cache-Control'] = 'no-store'
resp.headers["Cache-Control"] = "no-store"
return resp
@ -362,25 +365,25 @@ async def handle_frontend_path(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
if request['user'] is not None and request.path == '/login':
return Response.new_redir('/')
if request["user"] is not None and request.path == "/login":
return Response.new_redir("/")
if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None:
if request.path == '/logout':
return Response.new_redir('/')
if request.path.startswith(TOKEN_PATHS[:2]) and request["user"] is None:
if request.path == "/logout":
return Response.new_redir("/")
response = Response.new_redir(f'/login?redir={request.path}')
response = Response.new_redir(f"/login?redir={request.path}")
if request['token'] is not None:
response.del_cookie('user-token')
if request["token"] is not None:
response.del_cookie("user-token")
return response
response = await handler(request)
if not request.path.startswith('/api'):
if request['user'] is None and request['token'] is not None:
response.del_cookie('user-token')
if not request.path.startswith("/api"):
if request["user"] is None and request["token"] is not None:
response.del_cookie("user-token")
return response

View file

@ -24,11 +24,11 @@ DeserializerCallback = Callable[[str], Any]
BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, convert_to_boolean),
'json': (json.dumps, json.loads),
'message': (lambda x: x.to_json(), Message.parse)
"str": (str, str),
"int": (str, int),
"bool": (str, convert_to_boolean),
"json": (json.dumps, json.loads),
"message": (lambda x: x.to_json(), Message.parse)
}
@ -49,14 +49,14 @@ def register_cache(backend: type[Cache]) -> type[Cache]:
return backend
def serialize_value(value: Any, value_type: str = 'str') -> str:
def serialize_value(value: Any, value_type: str = "str") -> str:
if isinstance(value, str):
return value
return CONVERTERS[value_type][0](value)
def deserialize_value(value: str, value_type: str = 'str') -> Any:
def deserialize_value(value: str, value_type: str = "str") -> Any:
return CONVERTERS[value_type][1](value)
@ -93,7 +93,7 @@ class Item:
class Cache(ABC):
name: str = 'null'
name: str
def __init__(self, app: Application):
@ -116,7 +116,7 @@ class Cache(ABC):
@abstractmethod
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
def set(self, namespace: str, key: str, value: Any, value_type: str = "key") -> Item:
...
@ -160,7 +160,7 @@ class Cache(ABC):
@register_cache
class SqlCache(Cache):
name: str = 'database'
name: str = "database"
def __init__(self, app: Application):
@ -173,16 +173,16 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup")
params = {
'namespace': namespace,
'key': key
"namespace": namespace,
"key": key
}
with self._db.session(False) as conn:
with conn.run('get-cache-item', params) as cur:
with conn.run("get-cache-item", params) as cur:
if not (row := cur.one(Row)):
raise KeyError(f'{namespace}:{key}')
raise KeyError(f"{namespace}:{key}")
row.pop('id', None)
row.pop("id", None)
return Item.from_data(*tuple(row.values()))
@ -191,8 +191,8 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn:
for row in conn.run('get-cache-keys', {'namespace': namespace}):
yield row['key']
for row in conn.run("get-cache-keys", {"namespace": namespace}):
yield row["key"]
def get_namespaces(self) -> Iterator[str]:
@ -200,28 +200,28 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup")
with self._db.session(False) as conn:
for row in conn.run('get-cache-namespaces', None):
yield row['namespace']
for row in conn.run("get-cache-namespaces", None):
yield row["namespace"]
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
def set(self, namespace: str, key: str, value: Any, value_type: str = "str") -> Item:
if self._db is None:
raise RuntimeError("Database has not been setup")
params = {
'namespace': namespace,
'key': key,
'value': serialize_value(value, value_type),
'type': value_type,
'date': Date.new_utc()
"namespace": namespace,
"key": key,
"value": serialize_value(value, value_type),
"type": value_type,
"date": Date.new_utc()
}
with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as cur:
with conn.run("set-cache-item", params) as cur:
if (row := cur.one(Row)) is None:
raise RuntimeError("Cache item not set")
row.pop('id', None)
row.pop("id", None)
return Item.from_data(*tuple(row.values()))
@ -230,12 +230,12 @@ class SqlCache(Cache):
raise RuntimeError("Database has not been setup")
params = {
'namespace': namespace,
'key': key
"namespace": namespace,
"key": key
}
with self._db.session(True) as conn:
with conn.run('del-cache-item', params):
with conn.run("del-cache-item", params):
pass
@ -267,7 +267,7 @@ class SqlCache(Cache):
self._db.connect()
with self._db.session(True) as conn:
with conn.run(f'create-cache-table-{self._db.backend_type.value}', None):
with conn.run(f"create-cache-table-{self._db.backend_type.value}", None):
pass
@ -281,7 +281,7 @@ class SqlCache(Cache):
@register_cache
class RedisCache(Cache):
name: str = 'redis'
name: str = "redis"
def __init__(self, app: Application):
@ -295,7 +295,7 @@ class RedisCache(Cache):
def get_key_name(self, namespace: str, key: str) -> str:
return f'{self.prefix}:{namespace}:{key}'
return f"{self.prefix}:{namespace}:{key}"
def get(self, namespace: str, key: str) -> Item:
@ -305,9 +305,9 @@ class RedisCache(Cache):
key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}')
raise KeyError(f"{namespace}:{key}")
value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr]
value_type, updated, value = raw_value.split(":", 2) # type: ignore[union-attr]
return Item.from_data(
namespace,
@ -322,8 +322,8 @@ class RedisCache(Cache):
if self._rd is None:
raise ConnectionError("Not connected")
for key in self._rd.scan_iter(self.get_key_name(namespace, '*')):
*_, key_name = key.split(':', 2)
for key in self._rd.scan_iter(self.get_key_name(namespace, "*")):
*_, key_name = key.split(":", 2)
yield key_name
@ -333,15 +333,15 @@ class RedisCache(Cache):
namespaces = []
for key in self._rd.scan_iter(f'{self.prefix}:*'):
_, namespace, _ = key.split(':', 2)
for key in self._rd.scan_iter(f"{self.prefix}:*"):
_, namespace, _ = key.split(":", 2)
if namespace not in namespaces:
namespaces.append(namespace)
yield namespace
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
def set(self, namespace: str, key: str, value: Any, value_type: str = "key") -> Item:
if self._rd is None:
raise ConnectionError("Not connected")
@ -350,7 +350,7 @@ class RedisCache(Cache):
self._rd.set(
self.get_key_name(namespace, key),
f'{value_type}:{date}:{value}'
f"{value_type}:{date}:{value}"
)
return self.get(namespace, key)
@ -369,8 +369,8 @@ class RedisCache(Cache):
limit = Date.new_utc() - timedelta(days = days)
for full_key in self._rd.scan_iter(f'{self.prefix}:*'):
_, namespace, key = full_key.split(':', 2)
for full_key in self._rd.scan_iter(f"{self.prefix}:*"):
_, namespace, key = full_key.split(":", 2)
item = self.get(namespace, key)
if item.updated < limit:
@ -389,11 +389,11 @@ class RedisCache(Cache):
return
options: RedisConnectType = {
'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True,
'username': self.app.config.rd_user,
'password': self.app.config.rd_pass,
'db': self.app.config.rd_database
"client_name": f"ActivityRelay_{self.app.config.domain}",
"decode_responses": True,
"username": self.app.config.rd_user,
"password": self.app.config.rd_pass,
"db": self.app.config.rd_database
}
if os.path.exists(self.app.config.rd_host):

View file

@ -14,21 +14,21 @@ class RelayConfig(dict[str, Any]):
dict.__init__(self, {})
if self.is_docker:
path = '/data/config.yaml'
path = "/data/config.yaml"
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in {'blocked_instances', 'blocked_software', 'whitelist'}:
if key in {"blocked_instances", "blocked_software", "whitelist"}:
assert isinstance(value, (list, set, tuple))
elif key in {'port', 'workers', 'json_cache', 'timeout'}:
elif key in {"port", "workers", "json_cache", "timeout"}:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
elif key == "whitelist_enabled":
if not isinstance(value, bool):
value = convert_to_boolean(value)
@ -37,45 +37,45 @@ class RelayConfig(dict[str, Any]):
@property
def db(self) -> Path:
return Path(self['db']).expanduser().resolve()
return Path(self["db"]).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self["host"]}/actor'
return f"https://{self['host']}/actor"
@property
def inbox(self) -> str:
return f'https://{self["host"]}/inbox'
return f"https://{self['host']}/inbox"
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
return f"{self.actor}#main-key"
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
return bool(os.environ.get("DOCKER_RUNNING"))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
"db": str(self._path.parent.joinpath(f"{self._path.stem}.jsonld")),
"listen": "0.0.0.0",
"port": 8080,
"note": "Make a note about your instance here.",
"push_limit": 512,
"json_cache": 1024,
"timeout": 10,
"workers": 0,
"host": "relay.example.com",
"whitelist_enabled": False,
"blocked_software": [],
"blocked_instances": [],
"whitelist": []
})
@ -85,13 +85,13 @@ class RelayConfig(dict[str, Any]):
options = {}
try:
options['Loader'] = yaml.FullLoader
options["Loader"] = yaml.FullLoader
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
with self._path.open("r", encoding = "UTF-8") as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
@ -101,7 +101,7 @@ class RelayConfig(dict[str, Any]):
return
for key, value in config.items():
if key == 'ap':
if key == "ap":
for k, v in value.items():
if k not in self:
continue
@ -119,10 +119,10 @@ class RelayConfig(dict[str, Any]):
class RelayDatabase(dict[str, Any]):
def __init__(self, config: RelayConfig):
dict.__init__(self, {
'relay-list': {},
'private-key': None,
'follow-requests': {},
'version': 1
"relay-list": {},
"private-key": None,
"follow-requests": {},
"version": 1
})
self.config = config
@ -131,12 +131,12 @@ class RelayDatabase(dict[str, Any]):
@property
def hostnames(self) -> tuple[str]:
return tuple(self['relay-list'].keys())
return tuple(self["relay-list"].keys())
@property
def inboxes(self) -> tuple[dict[str, str]]:
return tuple(data['inbox'] for data in self['relay-list'].values())
return tuple(data["inbox"] for data in self["relay-list"].values())
def load(self) -> None:
@ -144,29 +144,29 @@ class RelayDatabase(dict[str, Any]):
with self.config.db.open() as fd:
data = json.load(fd)
self['version'] = data.get('version', None)
self['private-key'] = data.get('private-key')
self["version"] = data.get("version", None)
self["private-key"] = data.get("private-key")
if self['version'] is None:
self['version'] = 1
if self["version"] is None:
self["version"] = 1
if 'actorKeys' in data:
self['private-key'] = data['actorKeys']['privateKey']
if "actorKeys" in data:
self["private-key"] = data["actorKeys"]["privateKey"]
for item in data.get('relay-list', []):
for item in data.get("relay-list", []):
domain = urlparse(item).hostname
self['relay-list'][domain] = {
'domain': domain,
'inbox': item,
'followid': None
self["relay-list"][domain] = {
"domain": domain,
"inbox": item,
"followid": None
}
else:
self['relay-list'] = data.get('relay-list', {})
self["relay-list"] = data.get("relay-list", {})
for domain, instance in self['relay-list'].items():
if not instance.get('domain'):
instance['domain'] = domain
for domain, instance in self["relay-list"].items():
if not instance.get("domain"):
instance["domain"] = domain
except FileNotFoundError:
pass

View file

@ -16,7 +16,7 @@ if TYPE_CHECKING:
from typing import Self
if platform.system() == 'Windows':
if platform.system() == "Windows":
import multiprocessing
CORE_COUNT = multiprocessing.cpu_count()
@ -25,9 +25,9 @@ else:
DOCKER_VALUES = {
'listen': '0.0.0.0',
'port': 8080,
'sq_path': '/data/relay.sqlite3'
"listen": "0.0.0.0",
"port": 8080,
"sq_path": "/data/relay.sqlite3"
}
@ -37,26 +37,26 @@ class NOVALUE:
@dataclass(init = False)
class Config:
listen: str = '0.0.0.0'
listen: str = "0.0.0.0"
port: int = 8080
domain: str = 'relay.example.com'
domain: str = "relay.example.com"
workers: int = CORE_COUNT
db_type: str = 'sqlite'
ca_type: str = 'database'
sq_path: str = 'relay.sqlite3'
db_type: str = "sqlite"
ca_type: str = "database"
sq_path: str = "relay.sqlite3"
pg_host: str = '/var/run/postgresql'
pg_host: str = "/var/run/postgresql"
pg_port: int = 5432
pg_user: str = getpass.getuser()
pg_pass: str | None = None
pg_name: str = 'activityrelay'
pg_name: str = "activityrelay"
rd_host: str = 'localhost'
rd_host: str = "localhost"
rd_port: int = 6470
rd_user: str | None = None
rd_pass: str | None = None
rd_database: int = 0
rd_prefix: str = 'activityrelay'
rd_prefix: str = "activityrelay"
def __init__(self, path: Path | None = None, load: bool = False):
@ -116,17 +116,17 @@ class Config:
@property
def actor(self) -> str:
return f'https://{self.domain}/actor'
return f"https://{self.domain}/actor"
@property
def inbox(self) -> str:
return f'https://{self.domain}/inbox'
return f"https://{self.domain}/inbox"
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
return f"{self.actor}#main-key"
def load(self) -> None:
@ -134,43 +134,43 @@ class Config:
options = {}
try:
options['Loader'] = yaml.FullLoader
options["Loader"] = yaml.FullLoader
except AttributeError:
pass
with self.path.open('r', encoding = 'UTF-8') as fd:
with self.path.open("r", encoding = "UTF-8") as fd:
config = yaml.load(fd, **options)
if not config:
raise ValueError('Config is empty')
raise ValueError("Config is empty")
pgcfg = config.get('postgres', {})
rdcfg = config.get('redis', {})
pgcfg = config.get("postgres", {})
rdcfg = config.get("redis", {})
for key in type(self).KEYS():
if IS_DOCKER and key in {'listen', 'port', 'sq_path'}:
if IS_DOCKER and key in {"listen", "port", "sq_path"}:
self.set(key, DOCKER_VALUES[key])
continue
if key.startswith('pg'):
if key.startswith("pg"):
self.set(key, pgcfg.get(key[3:], NOVALUE))
continue
elif key.startswith('rd'):
elif key.startswith("rd"):
self.set(key, rdcfg.get(key[3:], NOVALUE))
continue
cfgkey = key
if key == 'db_type':
cfgkey = 'database_type'
if key == "db_type":
cfgkey = "database_type"
elif key == 'ca_type':
cfgkey = 'cache_type'
elif key == "ca_type":
cfgkey = "cache_type"
elif key == 'sq_path':
cfgkey = 'sqlite_path'
elif key == "sq_path":
cfgkey = "sqlite_path"
self.set(key, config.get(cfgkey, NOVALUE))
@ -186,32 +186,32 @@ class Config:
data: dict[str, Any] = {}
for key, value in asdict(self).items():
if key.startswith('pg_'):
if 'postgres' not in data:
data['postgres'] = {}
if key.startswith("pg_"):
if "postgres" not in data:
data["postgres"] = {}
data['postgres'][key[3:]] = value
data["postgres"][key[3:]] = value
continue
if key.startswith('rd_'):
if 'redis' not in data:
data['redis'] = {}
if key.startswith("rd_"):
if "redis" not in data:
data["redis"] = {}
data['redis'][key[3:]] = value
data["redis"][key[3:]] = value
continue
if key == 'db_type':
key = 'database_type'
if key == "db_type":
key = "database_type"
elif key == 'ca_type':
key = 'cache_type'
elif key == "ca_type":
key = "cache_type"
elif key == 'sq_path':
key = 'sqlite_path'
elif key == "sq_path":
key = "sqlite_path"
data[key] = value
with self.path.open('w', encoding = 'utf-8') as fd:
with self.path.open("w", encoding = "utf-8") as fd:
yaml.dump(data, fd, sort_keys = False)

View file

@ -16,17 +16,17 @@ sqlite3.register_adapter(Date, Date.timestamp)
def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
options = {
'connection_class': Connection,
'pool_size': 5,
'tables': TABLES
"connection_class": Connection,
"pool_size": 5,
"tables": TABLES
}
db: Database[Connection]
if config.db_type == 'sqlite':
if config.db_type == "sqlite":
db = Database.sqlite(config.sqlite_path, **options)
elif config.db_type == 'postgres':
elif config.db_type == "postgres":
db = Database.postgresql(
config.pg_name,
config.pg_host,
@ -36,30 +36,30 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
**options
)
db.load_prepared_statements(File.from_resource('relay', 'data/statements.sql'))
db.load_prepared_statements(File.from_resource("relay", "data/statements.sql"))
db.connect()
if not migrate:
return db
with db.session(True) as conn:
if 'config' not in conn.get_tables():
if "config" not in conn.get_tables():
logging.info("Creating database tables")
migrate_0(conn)
return db
if (schema_ver := conn.get_config('schema-version')) < ConfigData.DEFAULT('schema-version'):
if (schema_ver := conn.get_config("schema-version")) < ConfigData.DEFAULT("schema-version"):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS.items():
if schema_ver < ver:
func(conn)
conn.put_config('schema-version', ver)
conn.put_config("schema-version", ver)
logging.info("Updated database to %i", ver)
if (privkey := conn.get_config('private-key')):
if (privkey := conn.get_config("private-key")):
conn.app.signer = privkey
logging.set_level(conn.get_config('log-level'))
logging.set_level(conn.get_config("log-level"))
return db

View file

@ -15,76 +15,76 @@ if TYPE_CHECKING:
THEMES = {
'default': {
'text': '#DDD',
'background': '#222',
'primary': '#D85',
'primary-hover': '#DA8',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
"default": {
"text": "#DDD",
"background": "#222",
"primary": "#D85",
"primary-hover": "#DA8",
"section-background": "#333",
"table-background": "#444",
"border": "#444",
"message-text": "#DDD",
"message-background": "#335",
"message-border": "#446",
"error-text": "#DDD",
"error-background": "#533",
"error-border": "#644"
},
'pink': {
'text': '#DDD',
'background': '#222',
'primary': '#D69',
'primary-hover': '#D36',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
"pink": {
"text": "#DDD",
"background": "#222",
"primary": "#D69",
"primary-hover": "#D36",
"section-background": "#333",
"table-background": "#444",
"border": "#444",
"message-text": "#DDD",
"message-background": "#335",
"message-border": "#446",
"error-text": "#DDD",
"error-background": "#533",
"error-border": "#644"
},
'blue': {
'text': '#DDD',
'background': '#222',
'primary': '#69D',
'primary-hover': '#36D',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
"blue": {
"text": "#DDD",
"background": "#222",
"primary": "#69D",
"primary-hover": "#36D",
"section-background": "#333",
"table-background": "#444",
"border": "#444",
"message-text": "#DDD",
"message-background": "#335",
"message-border": "#446",
"error-text": "#DDD",
"error-background": "#533",
"error-border": "#644"
}
}
# serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, convert_to_boolean),
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse)
"str": (str, str),
"int": (str, int),
"bool": (str, convert_to_boolean),
"logging.LogLevel": (lambda x: x.name, logging.LogLevel.parse)
}
@dataclass()
class ConfigData:
schema_version: int = 20240625
private_key: str = ''
private_key: str = ""
approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO
name: str = 'ActivityRelay'
note: str = ''
theme: str = 'default'
name: str = "ActivityRelay"
note: str = ""
theme: str = "default"
whitelist_enabled: bool = False
def __getitem__(self, key: str) -> Any:
if (value := getattr(self, key.replace('-', '_'), None)) is None:
if (value := getattr(self, key.replace("-", "_"), None)) is None:
raise KeyError(key)
return value
@ -101,7 +101,7 @@ class ConfigData:
@staticmethod
def SYSTEM_KEYS() -> Sequence[str]:
return ('schema-version', 'schema_version', 'private-key', 'private_key')
return ("schema-version", "schema_version", "private-key", "private_key")
@classmethod
@ -111,12 +111,12 @@ class ConfigData:
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value]
return cls.FIELD(key.replace("-", "_")).default # type: ignore[return-value]
@classmethod
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
parsed_key = key.replace('-', '_')
parsed_key = key.replace("-", "_")
for field in fields(cls):
if field.name == parsed_key:
@ -131,9 +131,9 @@ class ConfigData:
set_schema_version = False
for row in rows:
data.set(row['key'], row['value'])
data.set(row["key"], row["value"])
if row['key'] == 'schema-version':
if row["key"] == "schema-version":
set_schema_version = True
if not set_schema_version:
@ -161,4 +161,4 @@ class ConfigData:
def to_dict(self) -> dict[str, Any]:
return {key.replace('_', '-'): value for key, value in asdict(self).items()}
return {key.replace("_", "-"): value for key, value in asdict(self).items()}

View file

@ -23,16 +23,17 @@ if TYPE_CHECKING:
from ..application import Application
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'activity-relay', # https://github.com/yukimochi/Activity-Relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
"activityrelay", # https://git.pleroma.social/pleroma/relay
"activity-relay", # https://github.com/yukimochi/Activity-Relay
"aoderelay", # https://git.asonix.dog/asonix/relay
"feditools-relay", # https://git.ptzo.gdn/feditools/relay
"buzzrelay" # https://github.com/astro/buzzrelay
]
class Connection(SqlConnection):
hasher = PasswordHasher(
encoding = 'utf-8'
encoding = "utf-8"
)
@property
@ -63,49 +64,49 @@ class Connection(SqlConnection):
def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
self.update('apps', data, client_id = app.client_id)
for app in self.select("apps").all(schema.App):
data = {"created": app.created.timestamp(), "accessed": app.accessed.timestamp()}
self.update("apps", data, client_id = app.client_id)
for item in self.select('cache'):
data = {'updated': Date.parse(item['updated']).timestamp()}
self.update('cache', data, id = item['id'])
for item in self.select("cache"):
data = {"updated": Date.parse(item["updated"]).timestamp()}
self.update("cache", data, id = item["id"])
for dban in self.select('domain_bans').all(schema.DomainBan):
data = {'created': dban.created.timestamp()}
self.update('domain_bans', data, domain = dban.domain)
for dban in self.select("domain_bans").all(schema.DomainBan):
data = {"created": dban.created.timestamp()}
self.update("domain_bans", data, domain = dban.domain)
for instance in self.select('inboxes').all(schema.Instance):
data = {'created': instance.created.timestamp()}
self.update('inboxes', data, domain = instance.domain)
for instance in self.select("inboxes").all(schema.Instance):
data = {"created": instance.created.timestamp()}
self.update("inboxes", data, domain = instance.domain)
for sban in self.select('software_bans').all(schema.SoftwareBan):
data = {'created': sban.created.timestamp()}
self.update('software_bans', data, name = sban.name)
for sban in self.select("software_bans").all(schema.SoftwareBan):
data = {"created": sban.created.timestamp()}
self.update("software_bans", data, name = sban.name)
for user in self.select('users').all(schema.User):
data = {'created': user.created.timestamp()}
self.update('users', data, username = user.username)
for user in self.select("users").all(schema.User):
data = {"created": user.created.timestamp()}
self.update("users", data, username = user.username)
for wlist in self.select('whitelist').all(schema.Whitelist):
data = {'created': wlist.created.timestamp()}
self.update('whitelist', data, domain = wlist.domain)
for wlist in self.select("whitelist").all(schema.Whitelist):
data = {"created": wlist.created.timestamp()}
self.update("whitelist", data, domain = wlist.domain)
def get_config(self, key: str) -> Any:
key = key.replace('_', '-')
key = key.replace("_", "-")
with self.run('get-config', {'key': key}) as cur:
with self.run("get-config", {"key": key}) as cur:
if (row := cur.one(Row)) is None:
return ConfigData.DEFAULT(key)
data = ConfigData()
data.set(row['key'], row['value'])
data.set(row["key"], row["value"])
return data.get(key)
def get_config_all(self) -> ConfigData:
rows = tuple(self.run('get-config-all', None).all(schema.Row))
rows = tuple(self.run("get-config-all", None).all(schema.Row))
return ConfigData.from_rows(rows)
@ -119,7 +120,7 @@ class Connection(SqlConnection):
case "log_level":
value = logging.LogLevel.parse(value)
logging.set_level(value)
self.app['workers'].set_log_level(value)
self.app["workers"].set_log_level(value)
case "approval_required":
value = convert_to_boolean(value)
@ -129,25 +130,25 @@ class Connection(SqlConnection):
case "theme":
if value not in THEMES:
raise ValueError(f'"{value}" is not a valid theme')
raise ValueError(f"\"{value}\" is not a valid theme")
data = ConfigData()
data.set(key, value)
params = {
'key': key,
'value': data.get(key, serialize = True),
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore
"key": key,
"value": data.get(key, serialize = True),
"type": "LogLevel" if field.type == "logging.LogLevel" else field.type
}
with self.run('put-config', params):
with self.run("put-config", params):
pass
return data.get(key)
def get_inbox(self, value: str) -> schema.Instance | None:
with self.run('get-inbox', {'value': value}) as cur:
with self.run("get-inbox", {"value": value}) as cur:
return cur.one(schema.Instance)
@ -165,21 +166,21 @@ class Connection(SqlConnection):
accepted: bool = True) -> schema.Instance:
params: dict[str, Any] = {
'inbox': inbox,
'actor': actor,
'followid': followid,
'software': software,
'accepted': accepted
"inbox": inbox,
"actor": actor,
"followid": followid,
"software": software,
"accepted": accepted
}
if self.get_inbox(domain) is None:
if not inbox:
raise ValueError("Missing inbox")
params['domain'] = domain
params['created'] = datetime.now(tz = timezone.utc)
params["domain"] = domain
params["created"] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur:
with self.run("put-inbox", params) as cur:
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert instance: {domain}")
@ -189,7 +190,7 @@ class Connection(SqlConnection):
if value is None:
del params[key]
with self.update('inboxes', params, domain = domain) as cur:
with self.update("inboxes", params, domain = domain) as cur:
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to update instance: {domain}")
@ -197,20 +198,20 @@ class Connection(SqlConnection):
def del_inbox(self, value: str) -> bool:
with self.run('del-inbox', {'value': value}) as cur:
with self.run("del-inbox", {"value": value}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
return cur.row_count == 1
def get_request(self, domain: str) -> schema.Instance | None:
with self.run('get-request', {'domain': domain}) as cur:
with self.run("get-request", {"domain": domain}) as cur:
return cur.one(schema.Instance)
def get_requests(self) -> Iterator[schema.Instance]:
return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance)
return self.execute("SELECT * FROM inboxes WHERE accepted = false").all(schema.Instance)
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
@ -219,16 +220,16 @@ class Connection(SqlConnection):
if not accepted:
if not self.del_inbox(domain):
raise RuntimeError(f'Failed to delete request: {domain}')
raise RuntimeError(f"Failed to delete request: {domain}")
return instance
params = {
'domain': domain,
'accepted': accepted
"domain": domain,
"accepted": accepted
}
with self.run('put-inbox-accept', params) as cur:
with self.run("put-inbox-accept", params) as cur:
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert response for domain: {domain}")
@ -236,12 +237,12 @@ class Connection(SqlConnection):
def get_user(self, value: str) -> schema.User | None:
with self.run('get-user', {'value': value}) as cur:
with self.run("get-user", {"value": value}) as cur:
return cur.one(schema.User)
def get_user_by_token(self, token: str) -> schema.User | None:
with self.run('get-user-by-token', {'token': token}) as cur:
with self.run("get-user-by-token", {"token": token}) as cur:
return cur.one(schema.User)
@ -254,10 +255,10 @@ class Connection(SqlConnection):
data: dict[str, str | datetime | None] = {}
if password:
data['hash'] = self.hasher.hash(password)
data["hash"] = self.hasher.hash(password)
if handle:
data['handle'] = handle
data["handle"] = handle
stmt = Update("users", data)
stmt.set_where("username", username)
@ -269,16 +270,16 @@ class Connection(SqlConnection):
return row
if password is None:
raise ValueError('Password cannot be empty')
raise ValueError("Password cannot be empty")
data = {
'username': username,
'hash': self.hasher.hash(password),
'handle': handle,
'created': datetime.now(tz = timezone.utc)
"username": username,
"hash": self.hasher.hash(password),
"handle": handle,
"created": datetime.now(tz = timezone.utc)
}
with self.run('put-user', data) as cur:
with self.run("put-user", data) as cur:
if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to insert user: {username}")
@ -289,10 +290,10 @@ class Connection(SqlConnection):
if (user := self.get_user(username)) is None:
raise KeyError(username)
with self.run('del-token-user', {'username': user.username}):
with self.run("del-token-user", {"username": user.username}):
pass
with self.run('del-user', {'username': user.username}):
with self.run("del-user", {"username": user.username}):
pass
@ -302,61 +303,61 @@ class Connection(SqlConnection):
token: str | None = None) -> schema.App | None:
params = {
'id': client_id,
'secret': client_secret
"id": client_id,
"secret": client_secret
}
if token is not None:
command = 'get-app-with-token'
params['token'] = token
command = "get-app-with-token"
params["token"] = token
else:
command = 'get-app'
command = "get-app"
with self.run(command, params) as cur:
return cur.one(schema.App)
def get_app_by_token(self, token: str) -> schema.App | None:
with self.run('get-app-by-token', {'token': token}) as cur:
with self.run("get-app-by-token", {"token": token}) as cur:
return cur.one(schema.App)
def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
params = {
'name': name,
'redirect_uri': redirect_uri,
'website': website,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'created': Date.new_utc(),
'accessed': Date.new_utc()
"name": name,
"redirect_uri": redirect_uri,
"website": website,
"client_id": secrets.token_hex(20),
"client_secret": secrets.token_hex(20),
"created": Date.new_utc(),
"accessed": Date.new_utc()
}
with self.insert('apps', params) as cur:
with self.insert("apps", params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to insert app: {name}')
raise RuntimeError(f"Failed to insert app: {name}")
return row
def put_app_login(self, user: schema.User) -> schema.App:
params = {
'name': 'Web',
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob',
'website': None,
'user': user.username,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'auth_code': None,
'token': secrets.token_hex(20),
'created': Date.new_utc(),
'accessed': Date.new_utc()
"name": "Web",
"redirect_uri": "urn:ietf:wg:oauth:2.0:oob",
"website": None,
"user": user.username,
"client_id": secrets.token_hex(20),
"client_secret": secrets.token_hex(20),
"auth_code": None,
"token": secrets.token_hex(20),
"created": Date.new_utc(),
"accessed": Date.new_utc()
}
with self.insert('apps', params) as cur:
with self.insert("apps", params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to create app for "{user.username}"')
raise RuntimeError(f"Failed to create app for \"{user.username}\"")
return row
@ -365,52 +366,52 @@ class Connection(SqlConnection):
data: dict[str, str | None] = {}
if user is not None:
data['user'] = user.username
data["user"] = user.username
if set_auth:
data['auth_code'] = secrets.token_hex(20)
data["auth_code"] = secrets.token_hex(20)
else:
data['token'] = secrets.token_hex(20)
data['auth_code'] = None
data["token"] = secrets.token_hex(20)
data["auth_code"] = None
params = {
'client_id': app.client_id,
'client_secret': app.client_secret
"client_id": app.client_id,
"client_secret": app.client_secret
}
with self.update('apps', data, **params) as cur: # type: ignore[arg-type]
with self.update("apps", data, **params) as cur: # type: ignore[arg-type]
if (row := cur.one(schema.App)) is None:
raise RuntimeError('Failed to update row')
raise RuntimeError("Failed to update row")
return row
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
params = {
'id': client_id,
'secret': client_secret
"id": client_id,
"secret": client_secret
}
if token is not None:
command = 'del-app-with-token'
params['token'] = token
command = "del-app-with-token"
params["token"] = token
else:
command = 'del-app'
command = "del-app"
with self.run(command, params) as cur:
if cur.row_count > 1:
raise RuntimeError('More than 1 row was deleted')
raise RuntimeError("More than 1 row was deleted")
return cur.row_count == 0
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
if domain.startswith('http'):
if domain.startswith("http"):
domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur:
with self.run("get-domain-ban", {"domain": domain}) as cur:
return cur.one(schema.DomainBan)
@ -424,13 +425,13 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.DomainBan:
params = {
'domain': domain,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
"domain": domain,
"reason": reason,
"note": note,
"created": datetime.now(tz = timezone.utc)
}
with self.run('put-domain-ban', params) as cur:
with self.run("put-domain-ban", params) as cur:
if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to insert domain ban: {domain}")
@ -443,22 +444,22 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.DomainBan:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
raise ValueError("\"reason\" and/or \"note\" must be specified")
params = {}
if reason is not None:
params['reason'] = reason
params["reason"] = reason
if note is not None:
params['note'] = note
params["note"] = note
statement = Update('domain_bans', params)
statement = Update("domain_bans", params)
statement.set_where("domain", domain)
with self.query(statement) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to update domain ban: {domain}")
@ -467,20 +468,20 @@ class Connection(SqlConnection):
def del_domain_ban(self, domain: str) -> bool:
with self.run('del-domain-ban', {'domain': domain}) as cur:
with self.run("del-domain-ban", {"domain": domain}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
return cur.row_count == 1
def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
with self.run('get-software-ban', {'name': name}) as cur:
with self.run("get-software-ban", {"name": name}) as cur:
return cur.one(schema.SoftwareBan)
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
return self.execute("SELECT * FROM software_bans").all(schema.SoftwareBan)
def put_software_ban(self,
@ -489,15 +490,15 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.SoftwareBan:
params = {
'name': name,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
"name": name,
"reason": reason,
"note": note,
"created": datetime.now(tz = timezone.utc)
}
with self.run('put-software-ban', params) as cur:
with self.run("put-software-ban", params) as cur:
if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to insert software ban: {name}')
raise RuntimeError(f"Failed to insert software ban: {name}")
return row
@ -508,39 +509,39 @@ class Connection(SqlConnection):
note: str | None = None) -> schema.SoftwareBan:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
raise ValueError("\"reason\" and/or \"note\" must be specified")
params = {}
if reason is not None:
params['reason'] = reason
params["reason"] = reason
if note is not None:
params['note'] = note
params["note"] = note
statement = Update('software_bans', params)
statement = Update("software_bans", params)
statement.set_where("name", name)
with self.query(statement) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to update software ban: {name}')
raise RuntimeError(f"Failed to update software ban: {name}")
return row
def del_software_ban(self, name: str) -> bool:
with self.run('del-software-ban', {'name': name}) as cur:
with self.run("del-software-ban", {"name": name}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
with self.run("get-domain-whitelist", {"domain": domain}) as cur:
return cur.one()
@ -550,20 +551,20 @@ class Connection(SqlConnection):
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
params = {
'domain': domain,
'created': datetime.now(tz = timezone.utc)
"domain": domain,
"created": datetime.now(tz = timezone.utc)
}
with self.run('put-domain-whitelist', params) as cur:
with self.run("put-domain-whitelist", params) as cur:
if (row := cur.one(schema.Whitelist)) is None:
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
raise RuntimeError(f"Failed to insert whitelisted domain: {domain}")
return row
def del_domain_whitelist(self, domain: str) -> bool:
with self.run('del-domain-whitelist', {'domain': domain}) as cur:
with self.run("del-domain-whitelist", {"domain": domain}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
raise ValueError("More than one row was modified")
return cur.row_count == 1

View file

@ -32,113 +32,113 @@ def deserialize_timestamp(value: Any) -> Date:
@TABLES.add_row
class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
value: Column[str] = Column('value', 'text')
type: Column[str] = Column('type', 'text', default = 'str')
key: Column[str] = Column("key", "text", primary_key = True, unique = True, nullable = False)
value: Column[str] = Column("value", "text")
type: Column[str] = Column("type", "text", default = "str")
@TABLES.add_row
class Instance(Row):
table_name: str = 'inboxes'
table_name: str = "inboxes"
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean')
"domain", "text", primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column("actor", "text", unique = True)
inbox: Column[str] = Column("inbox", "text", unique = True, nullable = False)
followid: Column[str] = Column("followid", "text")
software: Column[str] = Column("software", "text")
accepted: Column[Date] = Column("accepted", "boolean")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
"domain", "text", primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class DomainBan(Row):
table_name: str = 'domain_bans'
table_name: str = "domain_bans"
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
"domain", "text", primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column("reason", "text")
note: Column[str] = Column("note", "text")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class SoftwareBan(Row):
table_name: str = 'software_bans'
table_name: str = "software_bans"
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
name: Column[str] = Column("name", "text", primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column("reason", "text")
note: Column[str] = Column("note", "text")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class User(Row):
table_name: str = 'users'
table_name: str = "users"
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
"username", "text", primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column("hash", "text", nullable = False)
handle: Column[str] = Column("handle", "text")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class App(Row):
table_name: str = 'apps'
table_name: str = "apps"
client_id: Column[str] = Column(
'client_id', 'text', primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
name: Column[str] = Column('name', 'text')
website: Column[str] = Column('website', 'text')
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
"client_id", "text", primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column("client_secret", "text", nullable = False)
name: Column[str] = Column("name", "text")
website: Column[str] = Column("website", "text")
redirect_uri: Column[str] = Column("redirect_uri", "text", nullable = False)
token: Column[str | None] = Column("token", "text")
auth_code: Column[str | None] = Column("auth_code", "text")
user: Column[str | None] = Column("user", "text")
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"created", "timestamp", nullable = False, deserializer = deserialize_timestamp)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
"accessed", "timestamp", nullable = False, deserializer = deserialize_timestamp)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
data = deepcopy(self)
data.pop('user')
data.pop('auth_code')
data.pop("user")
data.pop("auth_code")
if not include_token:
data.pop('token')
data.pop("token")
return data
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
ver = int(func.__name__.replace('migrate_', ''))
ver = int(func.__name__.replace("migrate_", ""))
VERSIONS[ver] = func
return func
def migrate_0(conn: Connection) -> None:
conn.create_tables()
conn.put_config('schema-version', ConfigData.DEFAULT('schema-version'))
conn.put_config("schema-version", ConfigData.DEFAULT("schema-version"))
@migration
@ -148,11 +148,11 @@ def migrate_20240206(conn: Connection) -> None:
@migration
def migrate_20240310(conn: Connection) -> None:
conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close()
conn.execute('UPDATE "inboxes" SET "accepted" = true').close()
conn.execute("ALTER TABLE \"inboxes\" ADD COLUMN \"accepted\" BOOLEAN").close()
conn.execute("UPDATE \"inboxes\" SET \"accepted\" = true").close()
@migration
def migrate_20240625(conn: Connection) -> None:
conn.create_tables()
conn.execute('DROP TABLE "tokens"').close()
conn.execute("DROP TABLE \"tokens\"").close()

View file

@ -16,15 +16,13 @@
-if config.approval_required
%div.section.message
Follow requests require approval. You will need to wait for an admin to accept or deny
your request.
Follow requests require approval. You will need to wait for an admin to accept or deny your request.
-elif config.whitelist_enabled
%fieldset.section.message
%legend << Whitelist Enabled
The whitelist is enabled on this instance. Ask the admin to add your instance before
joining.
The whitelist is enabled on this instance. Ask the admin to add your instance before joining.
%fieldset.section
%legend << Instances

View file

@ -499,8 +499,7 @@ function page_login() {
async function login(event) {
const values = {
username: fields.username.value.trim(),
password: fields.password.value.trim(),
redir: fields.redir.value.trim()
password: fields.password.value.trim()
}
if (values.username === "" | values.password === "") {
@ -509,14 +508,16 @@ function page_login() {
}
try {
await request("POST", "v1/login", values);
application = await request("POST", "v1/login", values);
} catch (error) {
toast(error);
return;
}
document.location = values.redir;
const max_age = 60 * 60 * 24 * 30;
document.cookie = `user-token=${application.token};Secure;SameSite=Strict;Domain=${document.location.host};MaxAge=${max_age}`;
document.location = fields.redir.value.trim();
}

View file

@ -17,6 +17,13 @@ if TYPE_CHECKING:
from .application import Application
T = TypeVar("T", bound = JsonBase[Any])
HEADERS = {
"Accept": f"{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9",
"User-Agent": f"ActivityRelay/{__version__}"
}
SUPPORTS_HS2019 = {
'friendica',
'gotosocial',
@ -32,12 +39,6 @@ SUPPORTS_HS2019 = {
'sharkey'
}
T = TypeVar('T', bound = JsonBase[Any])
HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}'
}
class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10):
@ -106,25 +107,25 @@ class HttpClient:
old_algo: bool) -> str | None:
if not self._session:
raise RuntimeError('Client not open')
raise RuntimeError("Client not open")
url = url.split("#", 1)[0]
if not force:
try:
if not (item := self.cache.get('request', url)).older_than(48):
if not (item := self.cache.get("request", url)).older_than(48):
return item.value # type: ignore [no-any-return]
except KeyError:
logging.verbose('No cached data for url: %s', url)
logging.verbose("No cached data for url: %s", url)
headers = {}
if sign_headers:
algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
headers = self.signer.sign_headers('GET', url, algorithm = algo)
headers = self.signer.sign_headers("GET", url, algorithm = algo)
logging.debug('Fetching resource: %s', url)
logging.debug("Fetching resource: %s", url)
async with self._session.get(url, headers = headers) as resp:
# Not expecting a response with 202s, so just return
@ -142,7 +143,7 @@ class HttpClient:
raise HttpError(resp.status, error)
self.cache.set('request', url, data, 'str')
self.cache.set("request", url, data, "str")
return data
@ -172,13 +173,13 @@ class HttpClient:
old_algo: bool = True) -> T | str | None:
if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"')
raise TypeError("cls must be a sub-class of \"blib.JsonBase\"")
data = await self._get(url, sign_headers, force, old_algo)
if cls is not None:
if data is None:
# this shouldn't actually get raised, but keeping just in case
# this shouldn"t actually get raised, but keeping just in case
raise EmptyBodyError(f"GET {url}")
return cls.parse(data)
@ -188,7 +189,7 @@ class HttpClient:
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
if not self._session:
raise RuntimeError('Client not open')
raise RuntimeError("Client not open")
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance is not None and instance.software in SUPPORTS_HS2019:
@ -210,14 +211,14 @@ class HttpClient:
mtype = message.type.value if isinstance(message.type, ObjectType) else message.type
headers = self.signer.sign_headers(
'POST',
"POST",
url,
body,
headers = {'Content-Type': 'application/activity+json'},
headers = {"Content-Type": "application/activity+json"},
algorithm = algorithm
)
logging.verbose('Sending "%s" to %s', mtype, url)
logging.verbose("Sending \"%s\" to %s", mtype, url)
async with self._session.post(url, headers = headers, data = body) as resp:
if resp.status not in (200, 202):
@ -231,10 +232,10 @@ class HttpClient:
async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:
nodeinfo_url = None
wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo', False, WellKnownNodeinfo, force
f"https://{domain}/.well-known/nodeinfo", False, WellKnownNodeinfo, force
)
for version in ('20', '21'):
for version in ("20", "21"):
try:
nodeinfo_url = wk_nodeinfo.get_url(version)
@ -242,7 +243,7 @@ class HttpClient:
pass
if nodeinfo_url is None:
raise ValueError(f'Failed to fetch nodeinfo url for {domain}')
raise ValueError(f"Failed to fetch nodeinfo url for {domain}")
return await self.get(nodeinfo_url, False, Nodeinfo, force)

View file

@ -54,7 +54,7 @@ class LogLevel(IntEnum):
except ValueError:
pass
raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}')
raise AttributeError(f"Invalid enum property for {cls.__name__}: {data}")
def get_level() -> LogLevel:
@ -80,7 +80,7 @@ critical: LoggingMethod = logging.critical
try:
env_log_file: Path | None = Path(os.environ['LOG_FILE']).expanduser().resolve()
env_log_file: Path | None = Path(os.environ["LOG_FILE"]).expanduser().resolve()
except KeyError:
env_log_file = None
@ -90,16 +90,16 @@ handlers: list[Any] = [logging.StreamHandler()]
if env_log_file:
handlers.append(logging.FileHandler(env_log_file))
if os.environ.get('IS_SYSTEMD'):
logging_format = '%(levelname)s: %(message)s'
if os.environ.get("IS_SYSTEMD"):
logging_format = "%(levelname)s: %(message)s"
else:
logging_format = '[%(asctime)s] %(levelname)s: %(message)s'
logging_format = "[%(asctime)s] %(levelname)s: %(message)s"
logging.addLevelName(LogLevel.VERBOSE, 'VERBOSE')
logging.addLevelName(LogLevel.VERBOSE, "VERBOSE")
logging.basicConfig(
level = LogLevel.INFO,
format = logging_format,
datefmt = '%Y-%m-%d %H:%M:%S',
datefmt = "%Y-%m-%d %H:%M:%S",
handlers = handlers
)

File diff suppressed because it is too large Load diff

View file

@ -16,63 +16,53 @@ if TYPE_CHECKING:
from .application import Application
T = TypeVar('T')
ResponseType = TypedDict('ResponseType', {
'status': int,
'headers': dict[str, Any] | None,
'content_type': str,
'body': bytes | None,
'text': str | None
T = TypeVar("T")
IS_DOCKER = bool(os.environ.get("DOCKER_RUNNING"))
IS_WINDOWS = platform.system() == "Windows"
ResponseType = TypedDict("ResponseType", {
"status": int,
"headers": dict[str, Any] | None,
"content_type": str,
"body": bytes | None,
"text": str | None
})
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
IS_WINDOWS = platform.system() == 'Windows'
MIMETYPES = {
'activity': 'application/activity+json',
'css': 'text/css',
'html': 'text/html',
'json': 'application/json',
'text': 'text/plain',
'webmanifest': 'application/manifest+json'
"activity": "application/activity+json",
"css": "text/css",
"html": "text/html",
"json": "application/json",
"text": "text/plain",
"webmanifest": "application/manifest+json"
}
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
"mastodon": "https://{domain}/actor",
"akkoma": "https://{domain}/relay",
"pleroma": "https://{domain}/relay"
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
JSON_PATHS: tuple[str, ...] = (
'/api/v1',
'/actor',
'/inbox',
'/outbox',
'/following',
'/followers',
'/.well-known',
'/nodeinfo',
'/oauth/token',
'/oauth/revoke'
"/api/v1",
"/actor",
"/inbox",
"/outbox",
"/following",
"/followers",
"/.well-known",
"/nodeinfo",
"/oauth/token",
"/oauth/revoke"
)
TOKEN_PATHS: tuple[str, ...] = (
'/logout',
'/admin',
'/api',
'/oauth/authorize',
'/oauth/revoke'
"/logout",
"/admin",
"/api",
"/oauth/authorize",
"/oauth/revoke"
)
@ -80,7 +70,7 @@ def get_app() -> Application:
from .application import Application
if not Application.DEFAULT:
raise ValueError('No default application set')
raise ValueError("No default application set")
return Application.DEFAULT
@ -136,23 +126,23 @@ class Message(aputils.Message):
approves: bool = False) -> Self:
return cls.new(aputils.ObjectType.APPLICATION, {
'id': f'https://{host}/actor',
'preferredUsername': 'relay',
'name': 'ActivityRelay',
'summary': description or 'ActivityRelay bot',
'manuallyApprovesFollowers': approves,
'followers': f'https://{host}/followers',
'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox',
'outbox': f'https://{host}/outbox',
'url': f'https://{host}/',
'endpoints': {
'sharedInbox': f'https://{host}/inbox'
"id": f"https://{host}/actor",
"preferredUsername": "relay",
"name": "ActivityRelay",
"summary": description or "ActivityRelay bot",
"manuallyApprovesFollowers": approves,
"followers": f"https://{host}/followers",
"following": f"https://{host}/following",
"inbox": f"https://{host}/inbox",
"outbox": f"https://{host}/outbox",
"url": f"https://{host}/",
"endpoints": {
"sharedInbox": f"https://{host}/inbox"
},
'publicKey': {
'id': f'https://{host}/actor#main-key',
'owner': f'https://{host}/actor',
'publicKeyPem': pubkey
"publicKey": {
"id": f"https://{host}/actor#main-key",
"owner": f"https://{host}/actor",
"publicKeyPem": pubkey
}
})
@ -160,44 +150,44 @@ class Message(aputils.Message):
@classmethod
def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self:
return cls.new(aputils.ObjectType.ANNOUNCE, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [f'https://{host}/followers'],
'actor': f'https://{host}/actor',
'object': obj
"id": f"https://{host}/activities/{uuid4()}",
"to": [f"https://{host}/followers"],
"actor": f"https://{host}/actor",
"object": obj
})
@classmethod
def new_follow(cls: type[Self], host: str, actor: str) -> Self:
return cls.new(aputils.ObjectType.FOLLOW, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [actor],
'object': actor,
'actor': f'https://{host}/actor'
"id": f"https://{host}/activities/{uuid4()}",
"to": [actor],
"object": actor,
"actor": f"https://{host}/actor"
})
@classmethod
def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self:
return cls.new(aputils.ObjectType.UNDO, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [actor],
'actor': f'https://{host}/actor',
'object': follow
"id": f"https://{host}/activities/{uuid4()}",
"to": [actor],
"actor": f"https://{host}/actor",
"object": follow
})
@classmethod
def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self:
return cls.new(aputils.ObjectType.ACCEPT if accept else aputils.ObjectType.REJECT, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [actor],
'actor': f'https://{host}/actor',
'object': {
'id': followid,
'type': 'Follow',
'object': f'https://{host}/actor',
'actor': actor
"id": f"https://{host}/activities/{uuid4()}",
"to": [actor],
"actor": f"https://{host}/actor",
"object": {
"id": followid,
"type": "Follow",
"object": f"https://{host}/actor",
"actor": actor
}
})
@ -210,35 +200,35 @@ class Response(AiohttpResponse):
@classmethod
def new(cls: type[Self],
body: str | bytes | dict[str, Any] | Sequence[Any] = '',
body: str | bytes | dict[str, Any] | Sequence[Any] = "",
status: int = 200,
headers: dict[str, str] | None = None,
ctype: str = 'text') -> Self:
ctype: str = "text") -> Self:
kwargs: ResponseType = {
'status': status,
'headers': headers,
'content_type': MIMETYPES[ctype],
'body': None,
'text': None
"status": status,
"headers": headers,
"content_type": MIMETYPES[ctype],
"body": None,
"text": None
}
if isinstance(body, str):
kwargs['text'] = body
kwargs["text"] = body
elif isinstance(body, bytes):
kwargs['body'] = body
kwargs["body"] = body
elif isinstance(body, (dict, Sequence)):
kwargs['text'] = json.dumps(body, cls = JsonEncoder)
kwargs["text"] = json.dumps(body, cls = JsonEncoder)
return cls(**kwargs)
@classmethod
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, status, {'Location': path}, ctype = 'html')
body = f"Redirect to <a href=\"{path}\">{path}</a>"
return cls.new(body, status, {"Location": path}, ctype = "html")
@classmethod
@ -256,9 +246,9 @@ class Response(AiohttpResponse):
@property
def location(self) -> str:
return self.headers.get('Location', '')
return self.headers.get("Location", "")
@location.setter
def location(self, value: str) -> None:
self.headers['Location'] = value
self.headers["Location"] = value

View file

@ -11,12 +11,12 @@ if typing.TYPE_CHECKING:
from .views.activitypub import InboxData
def actor_type_check(actor: Message, software: str | None) -> bool:
if actor.type == 'Application':
def is_application(actor: Message, software: str | None) -> bool:
if actor.type == "Application":
return True
# akkoma (< 3.6.0) and pleroma use Person for the actor type
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
if software in {"akkoma", "pleroma"} and actor.id == f"https://{actor.domain}/relay":
return True
return False
@ -24,38 +24,38 @@ def actor_type_check(actor: Message, software: str | None) -> bool:
async def handle_relay(app: Application, data: InboxData, conn: Connection) -> None:
try:
app.cache.get('handle-relay', data.message.object_id)
logging.verbose('already relayed %s', data.message.object_id)
app.cache.get("handle-relay", data.message.object_id)
logging.verbose("already relayed %s", data.message.object_id)
return
except KeyError:
pass
message = Message.new_announce(app.config.domain, data.message.object_id)
logging.debug('>> relay: %s', message)
logging.debug(">> relay: %s", message)
for instance in conn.distill_inboxes(data.message):
app.push_message(instance.inbox, message, instance)
app.cache.set('handle-relay', data.message.object_id, message.id, 'str')
app.cache.set("handle-relay", data.message.object_id, message.id, "str")
async def handle_forward(app: Application, data: InboxData, conn: Connection) -> None:
try:
app.cache.get('handle-relay', data.message.id)
logging.verbose('already forwarded %s', data.message.id)
app.cache.get("handle-relay", data.message.id)
logging.verbose("already forwarded %s", data.message.id)
return
except KeyError:
pass
message = Message.new_announce(app.config.domain, data.message)
logging.debug('>> forward: %s', message)
logging.debug(">> forward: %s", message)
for instance in conn.distill_inboxes(data.message):
app.push_message(instance.inbox, data.message, instance)
app.cache.set('handle-relay', data.message.id, message.id, 'str')
app.cache.set("handle-relay", data.message.id, message.id, "str")
async def handle_follow(app: Application, data: InboxData, conn: Connection) -> None:
@ -65,10 +65,8 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
# reject if software used by actor is banned
if software and conn.get_software_ban(software):
logging.verbose('Rejected banned actor: %s', data.actor.id)
app.push_message(
data.actor.shared_inbox,
data.shared_inbox,
Message.new_response(
host = app.config.domain,
actor = data.actor.id,
@ -79,7 +77,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
)
logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s',
"Rejected follow from actor for using specific software: actor=%s, software=%s",
data.actor.id,
software
)
@ -87,11 +85,11 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
return
# reject if the actor is not an instance actor
if actor_type_check(data.actor, software):
logging.verbose('Non-application actor tried to follow: %s', data.actor.id)
if not is_application(data.actor, software):
logging.verbose("Non-application actor tried to follow: %s", data.actor.id)
app.push_message(
data.actor.shared_inbox,
data.shared_inbox,
Message.new_response(
host = app.config.domain,
actor = data.actor.id,
@ -106,12 +104,12 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
if not conn.get_domain_whitelist(data.actor.domain):
# add request if approval-required is enabled
if config.approval_required:
logging.verbose('New follow request fromm actor: %s', data.actor.id)
logging.verbose("New follow request fromm actor: %s", data.actor.id)
with conn.transaction():
data.instance = conn.put_inbox(
domain = data.actor.domain,
inbox = data.actor.shared_inbox,
inbox = data.shared_inbox,
actor = data.actor.id,
followid = data.message.id,
software = software,
@ -120,12 +118,12 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
return
# reject if the actor isn't whitelisted while the whiltelist is enabled
# reject if the actor isn"t whitelisted while the whiltelist is enabled
if config.whitelist_enabled:
logging.verbose('Rejected actor for not being in the whitelist: %s', data.actor.id)
logging.verbose("Rejected actor for not being in the whitelist: %s", data.actor.id)
app.push_message(
data.actor.shared_inbox,
data.shared_inbox,
Message.new_response(
host = app.config.domain,
actor = data.actor.id,
@ -140,7 +138,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
with conn.transaction():
data.instance = conn.put_inbox(
domain = data.actor.domain,
inbox = data.actor.shared_inbox,
inbox = data.shared_inbox,
actor = data.actor.id,
followid = data.message.id,
software = software,
@ -148,7 +146,7 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
)
app.push_message(
data.actor.shared_inbox,
data.shared_inbox,
Message.new_response(
host = app.config.domain,
actor = data.actor.id,
@ -160,9 +158,9 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
if software != "mastodon":
app.push_message(
data.actor.shared_inbox,
data.shared_inbox,
Message.new_follow(
host = app.config.domain,
actor = data.actor.id
@ -172,8 +170,8 @@ async def handle_follow(app: Application, data: InboxData, conn: Connection) ->
async def handle_undo(app: Application, data: InboxData, conn: Connection) -> None:
if data.message.object['type'] != 'Follow':
# forwarding deletes does not work, so don't bother
if data.message.object["type"] != "Follow":
# forwarding deletes does not work, so don"t bother
# await handle_forward(app, data, conn)
return
@ -187,13 +185,13 @@ async def handle_undo(app: Application, data: InboxData, conn: Connection) -> No
with conn.transaction():
if not conn.del_inbox(data.actor.id):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
"Failed to delete \"%s\" with follow ID \"%s\"",
data.actor.id,
data.message.object_id
)
app.push_message(
data.actor.shared_inbox,
data.shared_inbox,
Message.new_unfollow(
host = app.config.domain,
actor = data.actor.id,
@ -204,19 +202,19 @@ async def handle_undo(app: Application, data: InboxData, conn: Connection) -> No
processors = {
'Announce': handle_relay,
'Create': handle_relay,
'Delete': handle_forward,
'Follow': handle_follow,
'Undo': handle_undo,
'Update': handle_forward,
"Announce": handle_relay,
"Create": handle_relay,
"Delete": handle_forward,
"Follow": handle_follow,
"Undo": handle_undo,
"Update": handle_forward,
}
async def run_processor(data: InboxData) -> None:
if data.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
"Message type \"%s\" from actor cannot be handled: %s",
data.message.type,
data.actor.id
)
@ -242,5 +240,5 @@ async def run_processor(data: InboxData) -> None:
actor = data.actor.id
)
logging.verbose('New "%s" from actor: %s', data.message.type, data.actor.id)
logging.verbose("New \"%s\" from actor: %s", data.message.type, data.actor.id)
await processors[data.message.type](app, data, conn)

0
relay/py.typed Normal file
View file

View file

@ -5,7 +5,7 @@ import textwrap
from aiohttp.web import Request
from blib import File
from collections.abc import Callable
from hamlish_jinja import HamlishExtension
from hamlish import HamlishExtension, HamlishSettings
from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension
from jinja2.nodes import CallBlock, Node
@ -21,6 +21,7 @@ if TYPE_CHECKING:
class Template(Environment):
render_markdown: Callable[[str], str]
hamlish: HamlishSettings
def __init__(self, app: Application):
@ -33,14 +34,12 @@ class Template(Environment):
MarkdownExtension
],
loader = FileSystemLoader([
File.from_resource('relay', 'frontend'),
app.config.path.parent.joinpath('template')
File.from_resource("relay", "frontend"),
app.config.path.parent.joinpath("template")
])
)
self.app = app
self.hamlish_enable_div_shortcut = True
self.hamlish_mode = 'indented'
def render(self, path: str, request: Request, **context: Any) -> str:
@ -48,10 +47,10 @@ class Template(Environment):
config = conn.get_config_all()
new_context = {
'request': request,
'domain': self.app.config.domain,
'version': __version__,
'config': config,
"request": request,
"domain": self.app.config.domain,
"version": __version__,
"config": config,
**(context or {})
}
@ -59,11 +58,11 @@ class Template(Environment):
class MarkdownExtension(Extension):
tags = {'markdown'}
tags = {"markdown"}
extensions = (
'attr_list',
'smarty',
'tables'
"attr_list",
"smarty",
"tables"
)
@ -78,14 +77,14 @@ class MarkdownExtension(Extension):
def parse(self, parser: Parser) -> Node | list[Node]:
lineno = next(parser.stream).lineno
body = parser.parse_statements(
('name:endmarkdown',),
("name:endmarkdown",),
drop_needle = True
)
output = CallBlock(self.call_method('_render_markdown'), [], [], body)
output = CallBlock(self.call_method("_render_markdown"), [], [], body)
return output.set_lineno(lineno)
def _render_markdown(self, caller: Callable[[], str] | str) -> str:
text = caller if isinstance(caller, str) else caller()
return self._markdown.convert(textwrap.dedent(text.strip('\n')))
return self._markdown.convert(textwrap.dedent(text.strip("\n")))

View file

@ -44,66 +44,86 @@ class InboxData:
signer: Signer | None = None
try:
signature = Signature.parse(request.headers['signature'])
signature = Signature.parse(request.headers["signature"])
except KeyError:
logging.verbose('Missing signature header')
raise HttpError(400, 'missing signature header')
logging.verbose("Missing signature header")
raise HttpError(400, "missing signature header")
try:
message = await request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse message from actor: %s', signature.keyid)
raise HttpError(400, 'failed to parse message')
logging.verbose("Failed to parse message from actor: %s", signature.keyid)
raise HttpError(400, "failed to parse message")
if message is None:
logging.verbose('empty message')
raise HttpError(400, 'missing message')
logging.verbose("empty message")
raise HttpError(400, "missing message")
if 'actor' not in message:
logging.verbose('actor not in message')
raise HttpError(400, 'no actor in message')
if "actor" not in message:
logging.verbose("actor not in message")
raise HttpError(400, "no actor in message")
actor_id: str
try:
actor = await app.client.get(signature.keyid, True, Message)
actor_id = message.actor_id
except AttributeError:
actor_id = signature.keyid
try:
actor = await app.client.get(actor_id, True, Message)
except HttpError as e:
# ld signatures aren't handled atm, so just ignore it
if message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '')
# ld signatures aren"t handled atm, so just ignore it
if message.type == "Delete":
logging.verbose("Instance sent a delete which cannot be handled")
raise HttpError(202, "")
logging.verbose('Failed to fetch actor: %s', signature.keyid)
logging.debug('HTTP Status %i: %s', e.status, e.message)
raise HttpError(400, 'failed to fetch actor')
logging.verbose("Failed to fetch actor: %s", signature.keyid)
logging.debug("HTTP Status %i: %s", e.status, e.message)
raise HttpError(400, "failed to fetch actor")
except ClientConnectorError as e:
logging.warning('Error when trying to fetch actor: %s, %s', signature.keyid, str(e))
raise HttpError(400, 'failed to fetch actor')
logging.warning("Error when trying to fetch actor: %s, %s", signature.keyid, str(e))
raise HttpError(400, "failed to fetch actor")
except Exception:
traceback.print_exc()
raise HttpError(500, 'unexpected error when fetching actor')
raise HttpError(500, "unexpected error when fetching actor")
try:
signer = actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', signature.keyid)
raise HttpError(400, 'actor missing public key')
logging.verbose("Actor missing public key: %s", signature.keyid)
raise HttpError(400, "actor missing public key")
try:
await signer.validate_request_async(request)
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', actor.id, e)
logging.verbose("signature validation failed for \"%s\": %s", actor.id, e)
raise HttpError(401, str(e))
return cls(signature, message, actor, signer, None)
@property
def shared_inbox(self) -> str:
if self.actor is None:
raise AttributeError("Actor not set yet")
try:
return self.actor.shared_inbox
except KeyError:
return self.actor.inbox # type: ignore[no-any-return]
@register_route(HttpMethod.GET, "/actor", "/inbox")
async def handle_actor(app: Application, request: Request) -> Response:
with app.database.session(False) as conn:
@ -124,34 +144,34 @@ async def handle_inbox(app: Application, request: Request) -> Response:
data = await InboxData.parse(app, request)
with app.database.session() as conn:
data.instance = conn.get_inbox(data.actor.shared_inbox)
data.instance = conn.get_inbox(data.shared_inbox)
# reject if actor is banned
if conn.get_domain_ban(data.actor.domain):
logging.verbose('Ignored request from banned actor: %s', data.actor.id)
raise HttpError(403, 'access denied')
logging.verbose("Ignored request from banned actor: %s", data.actor.id)
raise HttpError(403, "access denied")
# reject if activity type isn't 'Follow' and the actor isn't following
if data.message.type != 'Follow' and not data.instance:
# reject if activity type isn"t "Follow" and the actor isn"t following
if data.message.type != "Follow" and not data.instance:
logging.verbose(
'Rejected actor for trying to post while not following: %s',
"Rejected actor for trying to post while not following: %s",
data.actor.id
)
raise HttpError(401, 'access denied')
raise HttpError(401, "access denied")
logging.debug('>> payload %s', data.message.to_json(4))
logging.debug(">> payload %s", data.message.to_json(4))
await run_processor(data)
return Response.new(status = 202)
@register_route(HttpMethod.GET, '/outbox')
@register_route(HttpMethod.GET, "/outbox")
async def handle_outbox(app: Application, request: Request) -> Response:
msg = aputils.Message.new(
aputils.ObjectType.ORDERED_COLLECTION,
{
"id": f'https://{app.config.domain}/outbox',
"id": f"https://{app.config.domain}/outbox",
"totalItems": 0,
"orderedItems": []
}
@ -160,15 +180,15 @@ async def handle_outbox(app: Application, request: Request) -> Response:
return Response.new(msg, ctype = "activity")
@register_route(HttpMethod.GET, '/following', '/followers')
@register_route(HttpMethod.GET, "/following", "/followers")
async def handle_follow(app: Application, request: Request) -> Response:
with app.database.session(False) as s:
inboxes = [row['actor'] for row in s.get_inboxes()]
inboxes = [row["actor"] for row in s.get_inboxes()]
msg = aputils.Message.new(
aputils.ObjectType.COLLECTION,
{
"id": f'https://{app.config.domain}{request.path}',
"id": f"https://{app.config.domain}{request.path}",
"totalItems": len(inboxes),
"items": inboxes
}
@ -177,21 +197,21 @@ async def handle_follow(app: Application, request: Request) -> Response:
return Response.new(msg, ctype = "activity")
@register_route(HttpMethod.GET, '/.well-known/webfinger')
@register_route(HttpMethod.GET, "/.well-known/webfinger")
async def get(app: Application, request: Request) -> Response:
try:
subject = request.query['resource']
subject = request.query["resource"]
except KeyError:
raise HttpError(400, 'missing "resource" query key')
raise HttpError(400, "missing \"resource\" query key")
if subject != f'acct:relay@{app.config.domain}':
raise HttpError(404, 'user not found')
if subject != f"acct:relay@{app.config.domain}":
raise HttpError(404, "user not found")
data = aputils.Webfinger.new(
handle = 'relay',
handle = "relay",
domain = app.config.domain,
actor = app.config.actor
)
return Response.new(data, ctype = 'json')
return Response.new(data, ctype = "json")

View file

@ -181,7 +181,16 @@ async def handle_login(
application = s.put_app_login(user)
return objects.Application.from_row(application)
return objects.Application(
application.client_id,
application.client_secret,
application.name,
application.website,
application.redirect_uri,
application.token,
application.created,
application.accessed
)
@Route(HttpMethod.GET, "/api/v1/app", "Application", True)
@ -343,7 +352,7 @@ async def handle_instance_add(
with app.database.session(False) as s:
if s.get_inbox(domain) is not None:
raise HttpError(404, 'Instance already in database')
raise HttpError(404, "Instance already in database")
if inbox is None:
try:
@ -396,7 +405,7 @@ async def handle_instance_update(
with app.database.session(False) as s:
if (instance := s.get_inbox(domain)) is None:
raise HttpError(404, 'Instance with domain not found')
raise HttpError(404, "Instance with domain not found")
row = s.put_inbox(
instance.domain,

View file

@ -31,11 +31,11 @@ if TYPE_CHECKING:
METHODS: dict[str, Method] = {}
ROUTES: list[tuple[str, str, HandlerCallback]] = []
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
DEFAULT_REDIRECT: str = "urn:ietf:wg:oauth:2.0:oob"
ALLOWED_HEADERS: set[str] = {
'accept',
'authorization',
'content-type'
"accept",
"authorization",
"content-type"
}
@ -100,14 +100,14 @@ class Method:
return_type = get_origin(return_type)
if not issubclass(return_type, (Response, ApiObject, list)):
raise ValueError(f"Invalid return type '{return_type.__name__}' for {func.__name__}")
raise ValueError(f"Invalid return type \"{return_type.__name__}\" for {func.__name__}")
args = {key: value for key, value in inspect.signature(func).parameters.items()}
docstring, paramdocs = parse_docstring(func.__doc__ or "")
params = []
if func.__doc__ is None:
logging.warning(f"Missing docstring for '{func.__name__}'")
logging.warning(f"Missing docstring for \"{func.__name__}\"")
for key, value in args.items():
types: list[type[Any]] = []
@ -134,7 +134,7 @@ class Method:
))
if not paramdocs.get(key):
logging.warning(f"Missing docs for '{key}' parameter in '{func.__name__}'")
logging.warning(f"Missing docs for \"{key}\" parameter in \"{func.__name__}\"")
rtype = annotations.get("return") or type(None)
return cls(func.__name__, category, docstring, method, path, rtype, tuple(params))
@ -222,7 +222,7 @@ class Route:
if request.method != "OPTIONS" and self.require_token:
if (auth := request.headers.getone("Authorization", None)) is None:
raise HttpError(401, 'Missing token')
raise HttpError(401, "Missing token")
try:
authtype, code = auth.split(" ", 1)
@ -245,15 +245,15 @@ class Route:
request["application"] = application
if request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
if request.content_type in {"application/x-www-form-urlencoded", "multipart/form-data"}:
post_data = {key: value for key, value in (await request.post()).items()}
elif request.content_type == 'application/json':
elif request.content_type == "application/json":
try:
post_data = await request.json()
except JSONDecodeError:
raise HttpError(400, 'Invalid JSON data')
raise HttpError(400, "Invalid JSON data")
else:
post_data = {key: str(value) for key, value in request.query.items()}
@ -262,7 +262,7 @@ class Route:
response = await self.handler(get_app(), request, **post_data)
except HttpError as error:
return Response.new({'error': error.message}, error.status, ctype = "json")
return Response.new({"error": error.message}, error.status, ctype = "json")
headers = {
"Access-Control-Allow-Origin": "*",

View file

@ -19,7 +19,7 @@ if TYPE_CHECKING:
async def handle_home(app: Application, request: Request) -> Response:
with app.database.session() as conn:
context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes())
"instances": tuple(conn.get_inboxes())
}
return Response.new_template(200, "page/home.haml", request, context)
@ -34,28 +34,28 @@ async def handle_api_doc(app: Application, request: Request) -> Response:
return Response.new_template(200, "page/docs.haml", request, context)
@register_route(HttpMethod.GET, '/login')
@register_route(HttpMethod.GET, "/login")
async def handle_login(app: Application, request: Request) -> Response:
context = {"redir": unquote(request.query.get("redir", "/"))}
return Response.new_template(200, "page/login.haml", request, context)
@register_route(HttpMethod.GET, '/logout')
@register_route(HttpMethod.GET, "/logout")
async def handle_logout(app: Application, request: Request) -> Response:
with app.database.session(True) as conn:
conn.del_app(request['token'].client_id, request['token'].client_secret)
conn.del_app(request["token"].client_id, request["token"].client_secret)
resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = app.config.domain, path = '/')
resp = Response.new_redir("/")
resp.del_cookie("user-token", domain = app.config.domain, path = "/")
return resp
@register_route(HttpMethod.GET, '/admin')
@register_route(HttpMethod.GET, "/admin")
async def handle_admin(app: Application, request: Request) -> Response:
return Response.new_redir(f'/login?redir={request.path}', 301)
return Response.new_redir(f"/login?redir={request.path}", 301)
@register_route(HttpMethod.GET, '/admin/instances')
@register_route(HttpMethod.GET, "/admin/instances")
async def handle_admin_instances(
app: Application,
request: Request,
@ -64,20 +64,20 @@ async def handle_admin_instances(
with app.database.session() as conn:
context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes()),
'requests': tuple(conn.get_requests())
"instances": tuple(conn.get_inboxes()),
"requests": tuple(conn.get_requests())
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-instances.haml", request, context)
@register_route(HttpMethod.GET, '/admin/whitelist')
@register_route(HttpMethod.GET, "/admin/whitelist")
async def handle_admin_whitelist(
app: Application,
request: Request,
@ -86,19 +86,19 @@ async def handle_admin_whitelist(
with app.database.session() as conn:
context: dict[str, Any] = {
'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC'))
"whitelist": tuple(conn.execute("SELECT * FROM whitelist ORDER BY domain ASC"))
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-whitelist.haml", request, context)
@register_route(HttpMethod.GET, '/admin/domain_bans')
@register_route(HttpMethod.GET, "/admin/domain_bans")
async def handle_admin_instance_bans(
app: Application,
request: Request,
@ -107,19 +107,19 @@ async def handle_admin_instance_bans(
with app.database.session() as conn:
context: dict[str, Any] = {
'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC'))
"bans": tuple(conn.execute("SELECT * FROM domain_bans ORDER BY domain ASC"))
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-domain_bans.haml", request, context)
@register_route(HttpMethod.GET, '/admin/software_bans')
@register_route(HttpMethod.GET, "/admin/software_bans")
async def handle_admin_software_bans(
app: Application,
request: Request,
@ -128,19 +128,19 @@ async def handle_admin_software_bans(
with app.database.session() as conn:
context: dict[str, Any] = {
'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC'))
"bans": tuple(conn.execute("SELECT * FROM software_bans ORDER BY name ASC"))
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-software_bans.haml", request, context)
@register_route(HttpMethod.GET, '/admin/users')
@register_route(HttpMethod.GET, "/admin/users")
async def handle_admin_users(
app: Application,
request: Request,
@ -149,29 +149,29 @@ async def handle_admin_users(
with app.database.session() as conn:
context: dict[str, Any] = {
'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC'))
"users": tuple(conn.execute("SELECT * FROM users ORDER BY username ASC"))
}
if error:
context['error'] = error
context["error"] = error
if message:
context['message'] = message
context["message"] = message
return Response.new_template(200, "page/admin-users.haml", request, context)
@register_route(HttpMethod.GET, '/admin/config')
@register_route(HttpMethod.GET, "/admin/config")
async def handle_admin_config(
app: Application,
request: Request,
message: str | None = None) -> Response:
context: dict[str, Any] = {
'themes': tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel),
'message': message,
'desc': {
"themes": tuple(THEMES.keys()),
"levels": tuple(level.name for level in LogLevel),
"message": message,
"desc": {
"name": "Name of the relay to be displayed in the header of the pages and in " +
"the actor endpoint.", # noqa: E131
"note": "Description of the relay to be displayed on the front page and as the " +
@ -187,36 +187,36 @@ async def handle_admin_config(
return Response.new_template(200, "page/admin-config.haml", request, context)
@register_route(HttpMethod.GET, '/manifest.json')
@register_route(HttpMethod.GET, "/manifest.json")
async def handle_manifest(app: Application, request: Request) -> Response:
with app.database.session(False) as conn:
config = conn.get_config_all()
theme = THEMES[config.theme]
data = {
'background_color': theme['background'],
'categories': ['activitypub'],
'description': 'Message relay for the ActivityPub network',
'display': 'standalone',
'name': config['name'],
'orientation': 'portrait',
'scope': f"https://{app.config.domain}/",
'short_name': 'ActivityRelay',
'start_url': f"https://{app.config.domain}/",
'theme_color': theme['primary']
"background_color": theme["background"],
"categories": ["activitypub"],
"description": "Message relay for the ActivityPub network",
"display": "standalone",
"name": config["name"],
"orientation": "portrait",
"scope": f"https://{app.config.domain}/",
"short_name": "ActivityRelay",
"start_url": f"https://{app.config.domain}/",
"theme_color": theme["primary"]
}
return Response.new(data, ctype = 'webmanifest')
return Response.new(data, ctype = "webmanifest")
@register_route(HttpMethod.GET, '/theme/{theme}.css') # type: ignore[arg-type]
@register_route(HttpMethod.GET, "/theme/{theme}.css") # type: ignore[arg-type]
async def handle_theme(app: Application, request: Request, theme: str) -> Response:
try:
context: dict[str, Any] = {
'theme': THEMES[theme]
"theme": THEMES[theme]
}
except KeyError:
return Response.new('Invalid theme', 404)
return Response.new("Invalid theme", 404)
return Response.new_template(200, "variables.css", request, context, ctype = "css")

View file

@ -21,16 +21,18 @@ VERSION = __version__
if File(__file__).join("../../../.git").resolve().exists:
try:
commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii')
VERSION = f'{__version__} {commit_label}'
commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("ascii")
VERSION = f"{__version__} {commit_label}"
del commit_label
except Exception:
pass
NODEINFO_PATHS = [
'/nodeinfo/{niversion:\\d.\\d}.json',
'/nodeinfo/{niversion:\\d.\\d}'
"/nodeinfo/{niversion:\\d.\\d}.json",
"/nodeinfo/{niversion:\\d.\\d}"
]
@ -40,23 +42,23 @@ async def handle_nodeinfo(app: Application, request: Request, niversion: str) ->
inboxes = conn.get_inboxes()
nodeinfo = aputils.Nodeinfo.new(
name = 'activityrelay',
name = "activityrelay",
version = VERSION,
protocols = ['activitypub'],
open_regs = not conn.get_config('whitelist-enabled'),
protocols = ["activitypub"],
open_regs = not conn.get_config("whitelist-enabled"),
users = 1,
repo = 'https://git.pleroma.social/pleroma/relay' if niversion == '2.1' else None,
repo = "https://codeberg.org/barkshark/activityrelay" if niversion == "2.1" else None,
metadata = {
'approval_required': conn.get_config('approval-required'),
'peers': [inbox['domain'] for inbox in inboxes]
"approval_required": conn.get_config("approval-required"),
"peers": [inbox["domain"] for inbox in inboxes]
}
)
return Response.new(nodeinfo, ctype = 'json')
return Response.new(nodeinfo, ctype = "json")
@register_route(HttpMethod.GET, '/.well-known/nodeinfo')
@register_route(HttpMethod.GET, "/.well-known/nodeinfo")
async def handle_wk_nodeinfo(app: Application, request: Request) -> Response:
data = aputils.WellKnownNodeinfo.new_template(app.config.domain)
return Response.new(data, ctype = 'json')
return Response.new(data, ctype = "json")

View file

@ -96,16 +96,16 @@ class PushWorker(Process):
await self.client.post(item.inbox, item.message, item.instance)
except HttpError as e:
logging.error('HTTP Error when pushing to %s: %i %s', item.inbox, e.status, e.message)
logging.error("HTTP Error when pushing to %s: %i %s", item.inbox, e.status, e.message)
except AsyncTimeoutError:
logging.error('Timeout when pushing to %s', item.domain)
logging.error("Timeout when pushing to %s", item.domain)
except ClientConnectionError as e:
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
logging.error("Failed to connect to %s for message push: %s", item.domain, str(e))
except ClientSSLError as e:
logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
logging.error("SSL error when pushing to %s: %s", item.domain, str(e))
class PushWorkers(list[PushWorker]):