Merge branch 'dev' into 'master'

0.3.0

See merge request pleroma/relay!57
This commit is contained in:
Izalia Mae 2024-03-10 23:02:59 +00:00
commit a4158c8d7e
35 changed files with 1724 additions and 346 deletions

View file

@ -1,2 +0,0 @@
include data/statements.sql
include data/swagger.yaml

View file

@ -1,3 +1,4 @@
flake8 == 7.0.0 flake8 == 7.0.0
pyinstaller == 6.3.0 pyinstaller == 6.3.0
pylint == 3.0 pylint == 3.0
watchdog == 4.0.0

View file

@ -9,7 +9,7 @@ use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if i
## Run ## Run
Run the relay. Optionally add `-d` or `--dev` to enable auto-reloading on code changes. Run the relay.
activityrelay run activityrelay run

View file

@ -19,10 +19,10 @@ proxy is on the same host.
port: 8080 port: 8080
### Web Workers ### Push Workers
The number of processes to spawn for handling web requests. Leave it at 0 to automatically detect The number of processes to spawn for pushing messages to subscribed instances. Leave it at 0 to
how many processes should be spawned. automatically detect how many processes should be spawned.
workers: 0 workers: 0

View file

@ -15,7 +15,7 @@ the [official pipx docs](https://pypa.github.io/pipx/installation/) for more in-
Now simply install ActivityRelay directly from git Now simply install ActivityRelay directly from git
pipx install git+https://git.pleroma.social/pleroma/relay@0.2.5 pipx install git+https://git.pleroma.social/pleroma/relay@0.3.0
Or from a cloned git repo. Or from a cloned git repo.
@ -39,7 +39,7 @@ be installed via [pyenv](https://github.com/pyenv/pyenv).
The instructions for installation via pip are very similar to pipx. Installation can be done from The instructions for installation via pip are very similar to pipx. Installation can be done from
git git
python3 -m pip install git+https://git.pleroma.social/pleroma/relay@0.2.5 python3 -m pip install git+https://git.pleroma.social/pleroma/relay@0.3.0
or a cloned git repo. or a cloned git repo.

View file

@ -1,55 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
import importlib
from pathlib import Path
block_cipher = None
aiohttp_swagger_path = Path(importlib.import_module('aiohttp_swagger').__file__).parent
a = Analysis(
['relay/__main__.py'],
pathex=[],
binaries=[],
datas=[
('relay/data', 'relay/data'),
(aiohttp_swagger_path, 'aiohttp_swagger')
],
hiddenimports=[
'gunicorn',
'gunicorn.glogging'
],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='activityrelay',
icon=None,
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View file

@ -1 +1 @@
__version__ = '0.2.6' __version__ = '0.3.0'

View file

@ -1,5 +1,8 @@
import multiprocessing
from relay.manage import main from relay.manage import main
if __name__ == '__main__': if __name__ == '__main__':
multiprocessing.freeze_support()
main() main()

View file

@ -1,17 +1,17 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import os import multiprocessing
import signal import signal
import subprocess
import sys
import time import time
import traceback
import typing import typing
from aiohttp import web from aiohttp import web
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from datetime import datetime, timedelta from datetime import datetime, timedelta
from queue import Empty
from threading import Event, Thread from threading import Event, Thread
from . import logger as logging from . import logger as logging
@ -19,20 +19,17 @@ from .cache import get_cache
from .config import Config from .config import Config
from .database import get_database from .database import get_database
from .http_client import HttpClient from .http_client import HttpClient
from .misc import check_open_port from .misc import check_open_port, get_resource
from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
from .views.frontend import handle_frontend_path
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Coroutine
from tinysql import Database, Row from tinysql import Database, Row
from .cache import Cache from .cache import Cache
from .misc import Message from .misc import Message, Response
# pylint: disable=unsubscriptable-object # pylint: disable=unsubscriptable-object
@ -40,29 +37,35 @@ if typing.TYPE_CHECKING:
class Application(web.Application): class Application(web.Application):
DEFAULT: Application = None DEFAULT: Application = None
def __init__(self, cfgpath: str, gunicorn: bool = False): def __init__(self, cfgpath: str | None, dev: bool = False):
web.Application.__init__(self, web.Application.__init__(self,
middlewares = [ middlewares = [
handle_api_path handle_api_path,
handle_frontend_path,
handle_response_headers
] ]
) )
Application.DEFAULT = self Application.DEFAULT = self
self['proc'] = None self['running'] = None
self['signer'] = None self['signer'] = None
self['start_time'] = None self['start_time'] = None
self['cleanup_thread'] = None self['cleanup_thread'] = None
self['dev'] = dev
self['config'] = Config(cfgpath, load = True) self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config) self['database'] = get_database(self.config)
self['client'] = HttpClient() self['client'] = HttpClient()
self['cache'] = get_cache(self) self['cache'] = get_cache(self)
self['cache'].setup()
self['template'] = Template(self)
self['push_queue'] = multiprocessing.Queue()
self['workers'] = []
if not gunicorn: self.cache.setup()
return
self.on_response_prepare.append(handle_access_log) # self.on_response_prepare.append(handle_access_log)
self.on_cleanup.append(handle_cleanup) self.on_cleanup.append(handle_cleanup)
for path, view in VIEWS: for path, view in VIEWS:
@ -70,7 +73,7 @@ class Application(web.Application):
setup_swagger(self, setup_swagger(self,
ui_version = 3, ui_version = 3,
swagger_from_file = pkgfiles('relay').joinpath('data', 'swagger.yaml') swagger_from_file = get_resource('data/swagger.yaml')
) )
@ -119,16 +122,23 @@ class Application(web.Application):
def push_message(self, inbox: str, message: Message, instance: Row) -> None: def push_message(self, inbox: str, message: Message, instance: Row) -> None:
asyncio.ensure_future(self.client.post(inbox, message, instance)) self['push_queue'].put((inbox, message, instance))
def run(self, dev: bool = False) -> None: def run(self) -> None:
self.start(dev) if self["running"]:
return
while self['proc'] and self['proc'].poll() is None: domain = self.config.domain
time.sleep(0.1) host = self.config.listen
port = self.config.port
self.stop() if not check_open_port(host, port):
logging.error(f'A server is already running on {host}:{port}')
return
logging.info(f'Starting webserver at {domain} ({host}:{port})')
asyncio.run(self.handle_run())
def set_signal_handler(self, startup: bool) -> None: def set_signal_handler(self, startup: bool) -> None:
@ -141,56 +151,54 @@ class Application(web.Application):
pass pass
def stop(self, *_):
self['running'] = False
def start(self, dev: bool = False) -> None:
if self['proc']:
return
if not check_open_port(self.config.listen, self.config.port): async def handle_run(self):
logging.error('Server already running on %s:%s', self.config.listen, self.config.port) self['running'] = True
return
cmd = [
sys.executable, '-m', 'gunicorn',
'relay.application:main_gunicorn',
'--bind', f'{self.config.listen}:{self.config.port}',
'--worker-class', 'aiohttp.GunicornWebWorker',
'--workers', str(self.config.workers),
'--env', f'CONFIG_FILE={self.config.path}',
'--reload-extra-file', pkgfiles('relay').joinpath('data', 'swagger.yaml'),
'--reload-extra-file', pkgfiles('relay').joinpath('data', 'statements.sql')
]
if dev:
cmd.append('--reload')
self.set_signal_handler(True) self.set_signal_handler(True)
self['proc'] = subprocess.Popen(cmd) # pylint: disable=consider-using-with
self['database'].connect()
self['cache'].setup()
self['cleanup_thread'] = CacheCleanupThread(self) self['cleanup_thread'] = CacheCleanupThread(self)
self['cleanup_thread'].start() self['cleanup_thread'].start()
for _ in range(self.config.workers):
worker = PushWorker(self['push_queue'])
worker.start()
def stop(self, *_) -> None: self['workers'].append(worker)
if not self['proc']:
return
self['cleanup_thread'].stop() runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
self['proc'].terminate() await runner.setup()
time_wait = 0.0
while self['proc'].poll() is None: site = web.TCPSite(runner,
time.sleep(0.1) host = self.config.listen,
time_wait += 0.1 port = self.config.port,
reuse_address = True
)
if time_wait >= 5.0: await site.start()
self['proc'].kill() self['starttime'] = datetime.now()
break
while self['running']:
await asyncio.sleep(0.25)
await site.stop()
for worker in self['workers']: # pylint: disable=not-an-iterable
worker.stop()
self.set_signal_handler(False) self.set_signal_handler(False)
self['proc'] = None
self.cache.close() self['starttime'] = None
self.database.disconnect() self['running'] = False
self['cleanup_thread'].stop()
self['workers'].clear()
self['database'].disconnect()
self['cache'].close()
class CacheCleanupThread(Thread): class CacheCleanupThread(Thread):
@ -217,38 +225,55 @@ class CacheCleanupThread(Thread):
self.running.clear() self.running.clear()
async def handle_access_log(request: web.Request, response: web.Response) -> None: class PushWorker(multiprocessing.Process):
address = request.headers.get( def __init__(self, queue: multiprocessing.Queue):
'X-Forwarded-For', multiprocessing.Process.__init__(self)
request.headers.get( self.queue = queue
'X-Real-Ip', self.shutdown = multiprocessing.Event()
request.remote
)
)
logging.info(
'%s "%s %s" %i %i "%s"', def stop(self) -> None:
address, self.shutdown.set()
request.method,
request.path,
response.status, def run(self) -> None:
response.content_length or 0, asyncio.run(self.handle_queue())
request.headers.get('User-Agent', 'n/a')
)
async def handle_queue(self) -> None:
client = HttpClient()
while not self.shutdown.is_set():
try:
inbox, message, instance = self.queue.get(block=True, timeout=0.25)
await client.post(inbox, message, instance)
except Empty:
pass
## make sure an exception doesn't bring down the worker
except Exception:
traceback.print_exc()
await client.close()
@web.middleware
async def handle_response_headers(request: web.Request, handler: Coroutine) -> Response:
resp = await handler(request)
resp.headers['Server'] = 'ActivityRelay'
if not request.app['dev'] and request.path.endswith(('.css', '.js')):
# cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'
else:
resp.headers['Cache-Control'] = 'no-store'
return resp
async def handle_cleanup(app: Application) -> None: async def handle_cleanup(app: Application) -> None:
await app.client.close() await app.client.close()
app.cache.close() app.cache.close()
app.database.disconnect() app.database.disconnect()
async def main_gunicorn():
try:
app = Application(os.environ['CONFIG_FILE'], gunicorn = True)
except KeyError:
logging.error('Failed to set "CONFIG_FILE" environment. Trying to run without gunicorn?')
raise RuntimeError from None
return app

View file

@ -91,7 +91,6 @@ class Cache(ABC):
def __init__(self, app: Application): def __init__(self, app: Application):
self.app = app self.app = app
self.setup()
@abstractmethod @abstractmethod
@ -158,8 +157,8 @@ class SqlCache(Cache):
def __init__(self, app: Application): def __init__(self, app: Application):
self._db = get_database(app.config)
Cache.__init__(self, app) Cache.__init__(self, app)
self._db = None
def get(self, namespace: str, key: str) -> Item: def get(self, namespace: str, key: str) -> Item:
@ -232,6 +231,10 @@ class SqlCache(Cache):
def setup(self) -> None: def setup(self) -> None:
if self._db and self._db.connected:
return
self._db = get_database(self.app.config)
self._db.connect() self._db.connect()
with self._db.session(True) as conn: with self._db.session(True) as conn:
@ -240,6 +243,9 @@ class SqlCache(Cache):
def close(self) -> None: def close(self) -> None:
if not self._db:
return
self._db.disconnect() self._db.disconnect()
self._db = None self._db = None
@ -247,7 +253,11 @@ class SqlCache(Cache):
@register_cache @register_cache
class RedisCache(Cache): class RedisCache(Cache):
name: str = 'redis' name: str = 'redis'
_rd: Redis
def __init__(self, app: Application):
Cache.__init__(self, app)
self._rd = None
@property @property
@ -322,6 +332,9 @@ class RedisCache(Cache):
def setup(self) -> None: def setup(self) -> None:
if self._rd:
return
options = { options = {
'client_name': f'ActivityRelay_{self.app.config.domain}', 'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True, 'decode_responses': True,
@ -341,5 +354,8 @@ class RedisCache(Cache):
def close(self) -> None: def close(self) -> None:
if not self._rd:
return
self._rd.close() self._rd.close()
self._rd = None self._rd = None

View file

@ -2,10 +2,12 @@ from __future__ import annotations
import getpass import getpass
import os import os
import platform
import typing import typing
import yaml import yaml
from pathlib import Path from pathlib import Path
from platformdirs import user_config_dir
from .misc import IS_DOCKER from .misc import IS_DOCKER
@ -13,11 +15,19 @@ if typing.TYPE_CHECKING:
from typing import Any from typing import Any
if platform.system() == 'Windows':
import multiprocessing
CORE_COUNT = multiprocessing.cpu_count()
else:
CORE_COUNT = len(os.sched_getaffinity(0))
DEFAULTS: dict[str, Any] = { DEFAULTS: dict[str, Any] = {
'listen': '0.0.0.0', 'listen': '0.0.0.0',
'port': 8080, 'port': 8080,
'domain': 'relay.example.com', 'domain': 'relay.example.com',
'workers': len(os.sched_getaffinity(0)), 'workers': CORE_COUNT,
'db_type': 'sqlite', 'db_type': 'sqlite',
'ca_type': 'database', 'ca_type': 'database',
'sq_path': 'relay.sqlite3', 'sq_path': 'relay.sqlite3',
@ -42,7 +52,11 @@ if IS_DOCKER:
class Config: class Config:
def __init__(self, path: str, load: bool = False): def __init__(self, path: str, load: bool = False):
self.path = Path(path).expanduser().resolve() if path:
self.path = Path(path).expanduser().resolve()
else:
self.path = Config.get_config_dir()
self.listen = None self.listen = None
self.port = None self.port = None
@ -73,6 +87,24 @@ class Config:
self.save() self.save()
@staticmethod
def get_config_dir(path: str | None = None) -> Path:
if path:
return Path(path).expanduser().resolve()
dirs = (
Path("relay.yaml").resolve(),
Path(user_config_dir("activityrelay"), "relay.yaml"),
Path("/etc/activityrelay/relay.yaml")
)
for directory in dirs:
if directory.exists():
return directory
return dirs[0]
@property @property
def sqlite_path(self) -> Path: def sqlite_path(self) -> Path:
if not os.path.isabs(self.sq_path): if not os.path.isabs(self.sq_path):
@ -154,7 +186,30 @@ class Config:
def save(self) -> None: def save(self) -> None:
self.path.parent.mkdir(exist_ok = True, parents = True) self.path.parent.mkdir(exist_ok = True, parents = True)
config = { with self.path.open('w', encoding = 'utf-8') as fd:
yaml.dump(self.to_dict(), fd, sort_keys = False)
def set(self, key: str, value: Any) -> None:
if key not in DEFAULTS:
raise KeyError(key)
if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
if (value := int(value)) < 1:
if key == 'port':
value = 8080
elif key == 'pg_port':
value = 5432
elif key == 'workers':
value = len(os.sched_getaffinity(0))
setattr(self, key, value)
def to_dict(self) -> dict[str, Any]:
return {
'listen': self.listen, 'listen': self.listen,
'port': self.port, 'port': self.port,
'domain': self.domain, 'domain': self.domain,
@ -178,24 +233,3 @@ class Config:
'refix': self.rd_prefix 'refix': self.rd_prefix
} }
} }
with self.path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys = False)
def set(self, key: str, value: Any) -> None:
if key not in DEFAULTS:
raise KeyError(key)
if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
if (value := int(value)) < 1:
if key == 'port':
value = 8080
elif key == 'pg_port':
value = 5432
elif key == 'workers':
value = len(os.sched_getaffinity(0))
setattr(self, key, value)

View file

@ -3,17 +3,12 @@ from __future__ import annotations
import bsql import bsql
import typing import typing
from .config import get_default_value from .config import CONFIG_DEFAULTS, THEMES, get_default_value
from .connection import RELAY_SOFTWARE, Connection from .connection import RELAY_SOFTWARE, Connection
from .schema import TABLES, VERSIONS, migrate_0 from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
from ..misc import get_resource
try:
from importlib.resources import files as pkgfiles
except ImportError: # pylint: disable=duplicate-code
from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .config import Config from .config import Config
@ -21,15 +16,15 @@ if typing.TYPE_CHECKING:
def get_database(config: Config, migrate: bool = True) -> bsql.Database: def get_database(config: Config, migrate: bool = True) -> bsql.Database:
options = { options = {
"connection_class": Connection, 'connection_class': Connection,
"pool_size": 5, 'pool_size': 5,
"tables": TABLES 'tables': TABLES
} }
if config.db_type == "sqlite": if config.db_type == 'sqlite':
db = bsql.Database.sqlite(config.sqlite_path, **options) db = bsql.Database.sqlite(config.sqlite_path, **options)
elif config.db_type == "postgres": elif config.db_type == 'postgres':
db = bsql.Database.postgresql( db = bsql.Database.postgresql(
config.pg_name, config.pg_name,
config.pg_host, config.pg_host,
@ -39,7 +34,7 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database:
**options **options
) )
db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql")) db.load_prepared_statements(get_resource('data/statements.sql'))
db.connect() db.connect()
if not migrate: if not migrate:

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json
import typing import typing
from .. import logger as logging from .. import logger as logging
@ -10,12 +11,61 @@ if typing.TYPE_CHECKING:
from typing import Any from typing import Any
THEMES = {
'default': {
'text': '#DDD',
'background': '#222',
'primary': '#D85',
'primary-hover': '#DA8',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
},
'pink': {
'text': '#DDD',
'background': '#222',
'primary': '#D69',
'primary-hover': '#D36',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
},
'blue': {
'text': '#DDD',
'background': '#222',
'primary': '#69D',
'primary-hover': '#36D',
'section-background': '#333',
'table-background': '#444',
'border': '#444',
'message-text': '#DDD',
'message-background': '#335',
'message-border': '#446',
'error-text': '#DDD',
'error-background': '#533',
'error-border': '#644'
}
}
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = { CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
'schema-version': ('int', 20240206), 'schema-version': ('int', 20240206),
'private-key': ('str', None),
'log-level': ('loglevel', logging.LogLevel.INFO), 'log-level': ('loglevel', logging.LogLevel.INFO),
'name': ('str', 'ActivityRelay'), 'name': ('str', 'ActivityRelay'),
'note': ('str', 'Make a note about your instance here.'), 'note': ('str', 'Make a note about your instance here.'),
'private-key': ('str', None), 'theme': ('str', 'default'),
'whitelist-enabled': ('bool', False) 'whitelist-enabled': ('bool', False)
} }
@ -24,6 +74,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
'bool': (str, boolean), 'bool': (str, boolean),
'json': (json.dumps, json.loads),
'loglevel': (lambda x: x.name, logging.LogLevel.parse) 'loglevel': (lambda x: x.name, logging.LogLevel.parse)
} }

View file

@ -8,10 +8,17 @@ from datetime import datetime, timezone
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize from .config import (
CONFIG_DEFAULTS,
THEMES,
get_default_type,
get_default_value,
serialize,
deserialize
)
from .. import logger as logging from .. import logger as logging
from ..misc import get_app from ..misc import boolean, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator from collections.abc import Iterator
@ -95,6 +102,13 @@ class Connection(SqlConnection):
value = logging.LogLevel.parse(value) value = logging.LogLevel.parse(value)
logging.set_level(value) logging.set_level(value)
elif key == 'whitelist-enabled':
value = boolean(value)
elif key == 'theme':
if value not in THEMES:
raise ValueError(f'"{value}" is not a valid theme')
params = { params = {
'key': key, 'key': key,
'value': serialize(key, value) if value is not None else None, 'value': serialize(key, value) if value is not None else None,
@ -252,10 +266,10 @@ class Connection(SqlConnection):
params = {} params = {}
if reason: if reason is not None:
params['reason'] = reason params['reason'] = reason
if note: if note is not None:
params['note'] = note params['note'] = note
statement = Update('domain_bans', params) statement = Update('domain_bans', params)
@ -307,10 +321,10 @@ class Connection(SqlConnection):
params = {} params = {}
if reason: if reason is not None:
params['reason'] = reason params['reason'] = reason
if note: if note is not None:
params['note'] = note params['note'] = note
statement = Update('software_bans', params) statement = Update('software_bans', params)

164
relay/dev.py Normal file
View file

@ -0,0 +1,164 @@
import click
import platform
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
from . import __version__
try:
from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler
except ImportError:
class PatternMatchingEventHandler:
pass
SCRIPT = Path(__file__).parent
REPO = SCRIPT.parent
IGNORE_EXT = {
'.py',
'.pyc'
}
@click.group('cli')
def cli():
'Useful commands for development'
@cli.command('install')
def cli_install():
cmd = [
sys.executable, '-m', 'pip', 'install',
'-r', 'requirements.txt',
'-r', 'dev-requirements.txt'
]
subprocess.run(cmd, check = False)
@cli.command('lint')
@click.argument('path', required = False, default = 'relay')
def cli_lint(path):
subprocess.run([sys.executable, '-m', 'flake8', path], check = False)
subprocess.run([sys.executable, '-m', 'pylint', path], check = False)
@cli.command('build')
def cli_build():
with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [
sys.executable, '-m', 'PyInstaller',
'--collect-data', 'relay',
'--collect-data', 'aiohttp_swagger',
'--hidden-import', 'pg8000',
'--hidden-import', 'sqlite3',
'--name', f'activityrelay-{__version__}-{platform.system().lower()}-{arch}',
'--workpath', tmp,
'--onefile', 'relay/__main__.py',
]
if platform.system() == 'Windows':
cmd.append('--console')
# putting the spec path on a different drive than the source dir breaks
if str(SCRIPT)[0] == tmp[0]:
cmd.extend(['--specpath', tmp])
else:
cmd.append('--strip')
cmd.extend(['--specpath', tmp])
subprocess.run(cmd, check = False)
@cli.command('run')
def cli_run():
print('Starting process watcher')
handler = WatchHandler()
handler.run_proc()
watcher = Observer()
watcher.schedule(handler, str(SCRIPT), recursive=True)
watcher.start()
try:
while True:
handler.proc.stdin.write(sys.stdin.read().encode('UTF-8'))
handler.proc.stdin.flush()
except KeyboardInterrupt:
pass
handler.kill_proc()
watcher.stop()
watcher.join()
class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py']
cmd = [sys.executable, '-m', 'relay', 'run', '-d']
def __init__(self):
PatternMatchingEventHandler.__init__(self)
self.proc = None
self.last_restart = None
def kill_proc(self):
if self.proc.poll() is not None:
return
print(f'Terminating process {self.proc.pid}')
self.proc.terminate()
sec = 0.0
while self.proc.poll() is None:
time.sleep(0.1)
sec += 0.1
if sec >= 5:
print('Failed to terminate. Killing process...')
self.proc.kill()
break
print('Process terminated')
def run_proc(self, restart=False):
timestamp = datetime.timestamp(datetime.now())
self.last_restart = timestamp if not self.last_restart else 0
if restart and self.proc.pid != '':
if timestamp - 3 < self.last_restart:
return
self.kill_proc()
# pylint: disable=consider-using-with
self.proc = subprocess.Popen(self.cmd, stdin = subprocess.PIPE)
self.last_restart = timestamp
print(f'Started process with PID {self.proc.pid}')
def on_any_event(self, event):
if event.event_type not in ['modified', 'created', 'deleted']:
return
self.run_proc(restart = True)
if __name__ == '__main__':
cli()

94
relay/frontend/base.haml Normal file
View file

@ -0,0 +1,94 @@
-macro menu_item(name, path)
-if view.request.path == path or (path != "/" and view.request.path.startswith(path))
%a.button(href="{{path}}" active="true") -> =name
-else
%a.button(href="{{path}}") -> =name
!!!
%html
%head
%title << {{config.name}}: {{page}}
%meta(charset="UTF-8")
%meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{theme_name}}.css")
%link(rel="stylesheet" type="text/css" href="/style.css")
-block head
%body
#menu.section(visible="false")
.menu-head
%span.menu-title << Menu
%span#menu-close.button << &#10006;
{{menu_item("Home", "/")}}
-if view.request["user"]
{{menu_item("Instances", "/admin/instances")}}
{{menu_item("Whitelist", "/admin/whitelist")}}
{{menu_item("Domain Bans", "/admin/domain_bans")}}
{{menu_item("Software Bans", "/admin/software_bans")}}
{{menu_item("Users", "/admin/users")}}
{{menu_item("Config", "/admin/config")}}
{{menu_item("Logout", "/logout")}}
-else
{{menu_item("Login", "/login")}}
#container
#header.section
%span#menu-open << &#8286;
%span.title-container
%a.title(href="/") -> =config.name
-if view.request.path not in ["/", "/login"]
.page -> =page
.empty
-if error
.error.section -> =error
-if message
.message.section -> =message
#content(class="page-{{page.lower().replace(' ', '_')}}")
-block content
#footer.section
.col1
-if not view.request["user"]
%a(href="/login") << Login
-else
=view.request["user"]["username"]
(
%a(href="/logout") << Logout
)
.version
%a(href="https://git.pleroma.social/pleroma/relay")
ActivityRelay/{{version}}
%script(type="application/javascript")
const body = document.getElementById("container")
const menu = document.getElementById("menu");
const menu_open = document.getElementById("menu-open");
const menu_close = document.getElementById("menu-close");
menu_open.addEventListener("click", (event) => {
var new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
menu.attributes.visible.nodeValue = new_value;
});
menu_close.addEventListener("click", (event) => {
menu.attributes.visible.nodeValue = "false"
});
body.addEventListener("click", (event) => {
if (event.target === menu_open) {
return;
}
menu.attributes.visible.nodeValue = "false";
});

View file

@ -0,0 +1,37 @@
-extends "base.haml"
-set page="Config"
-block content
%form.section(action="/admin/config" method="POST")
.grid-2col
%label(for="name") << Name
%input(id = "name" name="name" placeholder="Relay Name" value="{{config.name or ''}}")
%label(for="description") << Description
%textarea(id="description" name="note" value="{{config.note}}") << {{config.note}}
%label(for="theme") << Color Theme
%select(id="theme" name="theme")
-for theme in themes
-if theme == config.theme
%option(value="{{theme}}" selected) -> =theme.title()
-else
%option(value="{{theme}}") -> =theme.title()
%label(for="log-level") << Log Level
%select(id="log-level" name="log-level")
-for level in LogLevel
-if level == config["log-level"]
%option(value="{{level.name}}" selected) -> =level.name.title()
-else
%option(value="{{level.name}}") -> =level.name.title()
%label(for="whitelist-enabled") << Whitelist
-if config["whitelist-enabled"]
%input(id="whitelist-enabled" name="whitelist-enabled" type="checkbox" checked)
-else
%input(id="whitelist-enabled" name="whitelist-enabled" type="checkbox")
%input(type="submit" value="Save")

View file

@ -0,0 +1,48 @@
-extends "base.haml"
-set page="Domain Bans"
-block content
%details.section
%summary << Ban Domain
%form(action="/admin/domain_bans" method="POST")
#add-item
%label(for="domain") << Domain
%input(type="domain" id="domain" name="domain" placeholder="Domain")
%label(for="reason") << Ban Reason
%textarea(id="reason" name="reason") << {{""}}
%label(for="note") << Admin Note
%textarea(id="note" name="note") << {{""}}
%input(type="submit" value="Ban Domain")
#data-table.section
%table
%thead
%tr
%td.domain << Instance
%td << Date
%td.remove
%tbody
-for ban in bans
%tr
%td.domain
%details
%summary -> =ban.domain
%form(action="/admin/domain_bans" method="POST")
.grid-2col
.reason << Reason
%textarea.reason(id="reason" name="reason") << {{ban.reason or ""}}
.note << Note
%textarea.note(id="note" name="note") << {{ban.note or ""}}
%input(type="hidden" name="domain" value="{{ban.domain}}")
%input(type="submit" value="Update")
%td.date
=ban.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/domain_bans/delete/{{ban.domain}}" title="Unban domain") << &#10006;

View file

@ -0,0 +1,44 @@
-extends "base.haml"
-set page="Instances"
-block content
%details.section
%summary << Add Instance
%form(action="/admin/instances" method="POST")
#add-item
%label(for="domain") << Domain
%input(type="domain" id="domain" name="domain" placeholder="Domain")
%label(for="actor") << Actor URL
%input(type="url" id="actor" name="actor" placeholder="Actor URL")
%label(for="inbox") << Inbox URL
%input(type="url" id="inbox" name="inbox" placeholder="Inbox URL")
%label(for="software") << Software
%input(name="software" id="software" placeholder="software")
%input(type="submit" value="Add Instance")
#data-table.section
%table
%thead
%tr
%td.instance << Instance
%td.software << Software
%td.date << Joined
%td.remove
%tbody
-for instance in instances
%tr
%td.instance
%a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain
%td.software
=instance.software or "n/a"
%td.date
=instance.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/instances/delete/{{instance.domain}}" title="Remove Instance") << &#10006;

View file

@ -0,0 +1,48 @@
-extends "base.haml"
-set page="Software Bans"
-block content
%details.section
%summary << Ban Software
%form(action="/admin/software_bans" method="POST")
#add-item
%label(for="name") << Name
%input(id="name" name="name" placeholder="Name")
%label(for="reason") << Ban Reason
%textarea(id="reason" name="reason") << {{""}}
%label(for="note") << Admin Note
%textarea(id="note" name="note") << {{""}}
%input(type="submit" value="Ban Software")
#data-table.section
%table
%thead
%tr
%td.name << Instance
%td << Date
%td.remove
%tbody
-for ban in bans
%tr
%td.name
%details
%summary -> =ban.name
%form(action="/admin/software_bans" method="POST")
.grid-2col
.reason << Reason
%textarea.reason(id="reason" name="reason") << {{ban.reason or ""}}
.note << Note
%textarea.note(id="note" name="note") << {{ban.note or ""}}
%input(type="hidden" name="name" value="{{ban.name}}")
%input(type="submit" value="Update")
%td.date
=ban.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/software_bans/delete/{{ban.name}}" title="Unban software") << &#10006;

View file

@ -0,0 +1,44 @@
-extends "base.haml"
-set page="Users"
-block content
%details.section
%summary << Add User
%form(action="/admin/users", method="POST")
#add-item
%label(for="username") << Username
%input(id="username" name="username" placeholder="Username")
%label(for="password") << Password
%input(type="password" id="password" name="password" placeholder="Password")
%label(for="password2") << Password Again
%input(type="password" id="password2" name="password2" placeholder="Password Again")
%label(for="handle") << Handle
%input(type="email" name="handle" id="handle" placeholder="handle")
%input(type="submit" value="Add User")
#data-table.section
%table
%thead
%tr
%td.username << Username
%td.handle << Handle
%td.date << Joined
%td.remove
%tbody
-for user in users
%tr
%td.username
=user.username
%td.handle
=user.handle or "n/a"
%td.date
=user.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/users/delete/{{user.username}}" title="Remove User") << &#10006;

View file

@ -0,0 +1,31 @@
-extends "base.haml"
-set page="Whitelist"
-block content
%details.section
%summary << Add Domain
%form(action="/admin/whitelist" method="POST")
#add-item
%label(for="domain") << Domain
%input(type="domain" id="domain" name="domain" placeholder="Domain")
%input(type="submit" value="Add Domain")
#data-table.section
%table
%thead
%tr
%td.domain << Domain
%td.date << Added
%td.remove
%tbody
-for item in whitelist
%tr
%td.domain
=item.domain
%td.date
=item.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/whitelist/delete/{{item.domain}}" title="Remove whitlisted domain") << &#10006;

View file

@ -0,0 +1,36 @@
-extends "base.haml"
-set page = "Home"
-block content
.section
-for line in config.note.splitlines()
-if line
%p -> =line
.section
%p
This is an Activity Relay for fediverse instances.
%p
You may subscribe to this relay with the address:
%a(href="https://{{domain}}/actor") << https://{{domain}}/actor</a>
-if config["whitelist-enabled"]
%p.section.message
Note: The whitelist is enabled on this instance. Ask the admin to add your instance
before joining.
#data-table.section
%table
%thead
%tr
%td.instance << Instance
%td.date << Joined
%tbody
-for instance in instances
%tr
%td.instance -> %a(href="https://{{instance.domain}}/" target="_new")
=instance.domain
%td.date
=instance.created.strftime("%Y-%m-%d")

View file

@ -0,0 +1,12 @@
-extends "base.haml"
-set page="Login"
-block content
%form.section(action="/login" method="POST")
.grid-2col
%label(for="username") << Username
%input(id="username" name="username" placeholder="Username" value="{{username or ''}}")
%label(for="password") << Password
%input(id="password" name="password" placeholder="Password" type="password")
%input(type="submit" value="Login")

301
relay/frontend/style.css Normal file
View file

@ -0,0 +1,301 @@
a {
color: var(--primary);
text-decoration: none;
}
a:hover {
color: var(--primary-hover);
text-decoration: underline;
}
body {
color: var(--text);
background-color: #222;
margin: var(--spacing);
font-family: sans serif;
}
details *:nth-child(2) {
margin-top: 5px;
}
details summary {
cursor: pointer;
}
form input[type="submit"] {
display: block;
margin: 0 auto;
}
p {
line-height: 1em;
margin: 0px;
}
p:not(:first-child) {
margin-top: 1em;
}
p:not(:last-child) {
margin-bottom: 1em;
}
table {
border: 1px solid var(--primary);
border-radius: 5px;
border-spacing: 0px;
width: 100%;
}
table tbody tr:nth-child(even) td {
background-color: var(--section-background);
}
thead td:first-child {
border-top-left-radius: 3px;
}
thead td:last-child {
border-top-right-radius: 3px;
}
tbody tr:last-child td:first-child {
border-bottom-left-radius: 5px;
}
tbody tr:last-child td:last-child {
border-bottom-right-radius: 5px;
}
table td {
padding: 5px;
}
table thead td {
background-color: var(--primary);
color: var(--background);
text-align: center;
}
table tbody td {
background-color: var(--table-background);
}
textarea {
height: calc(5em);
}
#container {
width: 1024px;
margin: 0px auto;
}
#header {
display: grid;
grid-template-columns: 50px auto 50px;
justify-items: center;
align-items: center;
}
#header > * {
font-size: 2em;
}
#header .title-container {
text-align: center;
}
#header .title {
font-weight: bold;
line-height: 0.75em;
}
#header .page {
font-size: 0.5em;
}
#menu {
padding: 0px;
position: fixed;
top: 0px;
left: 0px;
margin: 0px;
height: 100%;
width: max-content;
z-index: 1;
font-size: 1.5em;
min-width: 300px;
}
#menu[visible="false"] {
visibility: hidden;
}
#menu > a {
margin: var(--spacing);
display: block;
border-radius: 5px;
padding: 5px;
}
#menu > a[active="true"]:not(:hover) {
background-color: var(--primary-hover);
border-color: transparent;
}
#menu .menu-head {
display: grid;
grid-template-columns: auto max-content;
margin: var(--spacing);
}
#menu .menu-head > * {
font-weight: bold;
}
#menu .menu-title {
color: var(--primary);
}
#menu-open {
color: var(--primary);
}
#menu-open:hover {
color: var(--primary-hover);
}
#menu-open, #menu-close {
cursor: pointer;
}
#menu-open, #menu-close {
min-width: 35px;
text-align: center;
}
#footer {
display: grid;
grid-template-columns: auto auto;
}
#footer .version {
text-align: right
}
#add-item {
display: grid;
grid-template-columns: max-content auto;
grid-gap: var(--spacing);
margin-top: var(--spacing);
margin-bottom: var(--spacing);
align-items: center;
}
#data-table td:first-child {
width: 100%;
}
#data-table .date {
width: max-content;
text-align: right;
}
.button {
background-color: var(--primary);
border: 1px solid var(--primary);
color: var(--background);
border-radius: 5px;
}
.button:hover {
background-color: var(--background);
color: var(--primary);
text-decoration: None;
}
.container {
display: grid;
grid-template-columns: max-content auto;
}
.error, .message {
text-align: center;
}
.error{
color: var(--error-text) !important;
background-color: var(--error-background) !important;
border: 1px solid var(--error-border) !important;
}
.grid-2col {
display: grid;
grid-template-columns: max-content auto;
grid-gap: var(--spacing);
margin-bottom: var(--spacing);
align-items: center;
}
.message {
color: var(--message-text) !important;
background-color: var(--message-background) !important;
border: 1px solid var(--message-border) !important;
}
.section {
background-color: var(--section-background);
padding: var(--spacing);
border: 1px solid var(--border);
border-radius: 5px;
}
.section:not(:first-child) {
margin-top: var(--spacing);
}
.section:not(:last-child) {
margin-bottom: var(--spacing);
}
/* config */
#content.page-config input[type="checkbox"] {
justify-self: left;
}
@media (max-width: 1026px) {
body {
margin: 0px;
}
#menu {
width: 100%;
}
#menu > a {
text-align: center;
}
#container {
width: unset;
margin: unset;
}
.container {
grid-template-columns: auto;
}
.content {
grid-template-columns: auto;
}
.section {
border-left-width: 0px;
border-right-width: 0px;
border-radius: 0px;
}
}

View file

@ -1,16 +1,13 @@
from __future__ import annotations from __future__ import annotations
import Crypto import Crypto
import aputils
import asyncio import asyncio
import click import click
import os import os
import platform import platform
import subprocess
import sys
import typing import typing
from aputils.signer import Signer
from gunicorn.app.wsgiapp import WSGIApplication
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from urllib.parse import urlparse from urllib.parse import urlparse
@ -21,7 +18,7 @@ from . import logger as logging
from .application import Application from .application import Application
from .compat import RelayConfig, RelayDatabase from .compat import RelayConfig, RelayDatabase
from .database import RELAY_SOFTWARE, get_database from .database import RELAY_SOFTWARE, get_database
from .misc import IS_DOCKER, Message from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Row from tinysql import Row
@ -36,23 +33,6 @@ CONFIG_IGNORE = (
'private-key' 'private-key'
) )
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
def check_alphanumeric(text: str) -> str: def check_alphanumeric(text: str) -> str:
if not text.isalnum(): if not text.isalnum():
@ -62,10 +42,10 @@ def check_alphanumeric(text: str) -> str:
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True) @click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config') @click.option('--config', '-c', help='path to the relay\'s config')
@click.version_option(version=__version__, prog_name='ActivityRelay') @click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context @click.pass_context
def cli(ctx: click.Context, config: str) -> None: def cli(ctx: click.Context, config: str | None) -> None:
ctx.obj = Application(config) ctx.obj = Application(config)
if not ctx.invoked_subcommand: if not ctx.invoked_subcommand:
@ -196,7 +176,7 @@ def cli_setup(ctx: click.Context) -> None:
ctx.obj.config.save() ctx.obj.config.save()
config = { config = {
'private-key': Signer.new('n/a').export() 'private-key': aputils.Signer.new('n/a').export()
} }
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
@ -208,7 +188,7 @@ def cli_setup(ctx: click.Context) -> None:
@cli.command('run') @cli.command('run')
@click.option('--dev', '-d', is_flag = True, help = 'Enable worker reloading on code change') @click.option('--dev', '-d', is_flag=True, help='Enable developer mode')
@click.pass_context @click.pass_context
def cli_run(ctx: click.Context, dev: bool = False) -> None: def cli_run(ctx: click.Context, dev: bool = False) -> None:
'Run the relay' 'Run the relay'
@ -237,23 +217,13 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
click.echo(pip_command) click.echo(pip_command)
return return
if getattr(sys, 'frozen', False): ctx.obj['dev'] = dev
subprocess.run([sys.executable, 'run-gunicorn'], check = False) ctx.obj.run()
else:
ctx.obj.run(dev)
# todo: figure out why the relay doesn't quit properly without this # todo: figure out why the relay doesn't quit properly without this
os._exit(0) os._exit(0)
@cli.command('run-gunicorn')
@click.pass_context
def cli_run_gunicorn(ctx: click.Context) -> None:
runner = GunicornRunner(ctx.obj)
runner.run()
@cli.command('convert') @cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from') @click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@ -921,30 +891,6 @@ def cli_whitelist_import(ctx: click.Context) -> None:
click.echo('Imported whitelist from inboxes') click.echo('Imported whitelist from inboxes')
class GunicornRunner(WSGIApplication):
def __init__(self, app: Application):
self.app = app
self.app_uri = 'relay.application:main_gunicorn'
self.options = {
'bind': f'{app.config.listen}:{app.config.port}',
'worker_class': 'aiohttp.GunicornWebWorker',
'workers': app.config.workers,
'raw_env': f'CONFIG_FILE={app.config.path}'
}
WSGIApplication.__init__(self)
def load_config(self):
for key, value in self.options.items():
self.cfg.set(key, value)
def run(self):
logging.info('Starting webserver for %s', self.app.config.domain)
WSGIApplication.run(self)
def main() -> None: def main() -> None:
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter

View file

@ -1,16 +1,23 @@
from __future__ import annotations from __future__ import annotations
import aputils
import json import json
import os import os
import socket import socket
import typing import typing
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from aputils.message import Message as ApMessage
from datetime import datetime from datetime import datetime
from uuid import uuid4 from uuid import uuid4
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from pathlib import Path
from typing import Any from typing import Any
from .application import Application from .application import Application
@ -18,6 +25,7 @@ if typing.TYPE_CHECKING:
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING')) IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = { MIMETYPES = {
'activity': 'application/activity+json', 'activity': 'application/activity+json',
'css': 'text/css',
'html': 'text/html', 'html': 'text/html',
'json': 'application/json', 'json': 'application/json',
'text': 'text/plain' 'text': 'text/plain'
@ -28,6 +36,23 @@ NODEINFO_NS = {
'21': 'http://nodeinfo.diaspora.software/ns/schema/2.1' '21': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
} }
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
def boolean(value: Any) -> bool: def boolean(value: Any) -> bool:
if isinstance(value, str): if isinstance(value, str):
@ -75,15 +100,19 @@ def get_app() -> Application:
return Application.DEFAULT return Application.DEFAULT
def get_resource(path: str) -> Path:
return pkgfiles('relay').joinpath(path)
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, obj: Any) -> str: def default(self, o: Any) -> str:
if isinstance(obj, datetime): if isinstance(o, datetime):
return obj.isoformat() return o.isoformat()
return JSONEncoder.default(self, obj) return json.JSONEncoder.default(self, o)
class Message(ApMessage): class Message(aputils.Message):
@classmethod @classmethod
def new_actor(cls: type[Message], # pylint: disable=arguments-differ def new_actor(cls: type[Message], # pylint: disable=arguments-differ
host: str, host: str,
@ -170,16 +199,6 @@ class Message(ApMessage):
}) })
# todo: remove when fixed in aputils
@property
def object_id(self) -> str:
try:
return self["object"]["id"]
except (KeyError, TypeError):
return self["object"]
class Response(AiohttpResponse): class Response(AiohttpResponse):
# AiohttpResponse.__len__ method returns 0, so bool(response) always returns False # AiohttpResponse.__len__ method returns 0, so bool(response) always returns False
def __bool__(self) -> bool: def __bool__(self) -> bool:
@ -223,6 +242,12 @@ class Response(AiohttpResponse):
return cls.new(body=body, status=status, ctype=ctype) return cls.new(body=body, status=status, ctype=ctype)
@classmethod
def new_redir(cls: type[Response], path: str) -> Response:
body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, 302, {'Location': path})
@property @property
def location(self) -> str: def location(self) -> str:
return self.headers.get('Location') return self.headers.get('Location')

51
relay/template.py Normal file
View file

@ -0,0 +1,51 @@
from __future__ import annotations
import typing
from hamlish_jinja.extension import HamlishExtension
from jinja2 import Environment, FileSystemLoader
from . import __version__
from .database.config import THEMES
from .misc import get_resource
if typing.TYPE_CHECKING:
from typing import Any
from .application import Application
from .views.base import View
class Template(Environment):
def __init__(self, app: Application):
Environment.__init__(self,
autoescape = True,
trim_blocks = True,
lstrip_blocks = True,
extensions = [
HamlishExtension
],
loader = FileSystemLoader([
get_resource('frontend'),
app.config.path.parent.joinpath('template')
])
)
self.app = app
self.hamlish_enable_div_shortcut = True
self.hamlish_mode = 'indented'
def render(self, path: str, view: View | None = None, **context: Any) -> str:
with self.app.database.session(False) as s:
config = s.get_config_all()
new_context = {
'view': view,
'domain': self.app.config.domain,
'version': __version__,
'config': config,
'theme_name': config['theme'] or 'Default',
**(context or {})
}
return self.get_template(path).render(new_context)

View file

@ -1,12 +1,9 @@
from __future__ import annotations from __future__ import annotations
import aputils
import traceback import traceback
import typing import typing
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.objects import Webfinger
from .base import View, register_route from .base import View, register_route
from .. import logger as logging from .. import logger as logging
@ -15,7 +12,6 @@ from ..processors import run_processor
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from aputils.signer import Signer
from tinysql import Row from tinysql import Row
@ -26,11 +22,11 @@ class ActorView(View):
def __init__(self, request: Request): def __init__(self, request: Request):
View.__init__(self, request) View.__init__(self, request)
self.signature: Signature = None self.signature: aputils.Signature = None
self.message: Message = None self.message: Message = None
self.actor: Message = None self.actor: Message = None
self.instance: Row = None self.instance: Row = None
self.signer: Signer = None self.signer: aputils.Signer = None
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
@ -77,7 +73,7 @@ class ActorView(View):
async def get_post_data(self) -> Response | None: async def get_post_data(self) -> Response | None:
try: try:
self.signature = Signature.new_from_signature(self.request.headers['signature']) self.signature = aputils.Signature.new_from_signature(self.request.headers['signature'])
except KeyError: except KeyError:
logging.verbose('Missing signature header') logging.verbose('Missing signature header')
@ -124,7 +120,7 @@ class ActorView(View):
try: try:
self.validate_signature(await self.request.read()) self.validate_signature(await self.request.read())
except SignatureFailureError as e: except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json') return Response.new_error(401, str(e), 'json')
@ -133,31 +129,31 @@ class ActorView(View):
headers = {key.lower(): value for key, value in self.request.headers.items()} headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path]) headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))): if (digest := aputils.Digest.new_from_digest(headers.get("digest"))):
if not body: if not body:
raise SignatureFailureError("Missing body for digest verification") raise aputils.SignatureFailureError("Missing body for digest verification")
if not digest.validate(body): if not digest.validate(body):
raise SignatureFailureError("Body digest does not match") raise aputils.SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019": if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers: if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used") raise aputils.SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp() current_timestamp = aputils.HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp: if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date") raise aputils.SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires: if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date") raise aputils.SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access # pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature): if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match") raise aputils.SignatureFailureError("Signature does not match")
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
@ -172,7 +168,7 @@ class WebfingerView(View):
if subject != f'acct:relay@{self.config.domain}': if subject != f'acct:relay@{self.config.domain}':
return Response.new_error(404, 'user not found', 'json') return Response.new_error(404, 'user not found', 'json')
data = Webfinger.new( data = aputils.Webfinger.new(
handle = 'relay', handle = 'relay',
domain = self.config.domain, domain = self.config.domain,
actor = self.config.actor actor = self.config.actor

View file

@ -4,7 +4,6 @@ import typing
from aiohttp import web from aiohttp import web
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from datetime import datetime, timezone
from urllib.parse import urlparse from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
@ -76,7 +75,7 @@ class Login(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with self.database.connction(True) as conn: with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])): if not (user := conn.get_user(data['username'])):
return Response.new_error(401, 'User not found', 'json') return Response.new_error(401, 'User not found', 'json')

View file

@ -11,12 +11,15 @@ from json.decoder import JSONDecodeError
from ..misc import Response from ..misc import Response
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable, Coroutine, Generator from collections.abc import Callable, Coroutine, Generator
from tinysql import Database from bsql import Database
from typing import Any, Self
from ..application import Application from ..application import Application
from ..cache import Cache from ..cache import Cache
from ..config import Config from ..config import Config
from ..http_client import HttpClient from ..http_client import HttpClient
from ..template import Template
VIEWS = [] VIEWS = []
@ -27,7 +30,7 @@ def register_route(*paths: str) -> Callable:
for path in paths: for path in paths:
VIEWS.append([path, view]) VIEWS.append([path, view])
return View return view
return wrapper return wrapper
@ -42,8 +45,14 @@ class View(AbstractView):
return self._run_handler(handler).__await__() return self._run_handler(handler).__await__()
async def _run_handler(self, handler: Coroutine) -> Response: @classmethod
return await handler(self.request, **self.request.match_info) async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Self:
view = cls(request)
return await view.handlers[method](request, **kwargs)
async def _run_handler(self, handler: Coroutine, **kwargs: Any) -> Response:
return await handler(self.request, **self.request.match_info, **kwargs)
@cached_property @cached_property
@ -91,6 +100,11 @@ class View(AbstractView):
return self.app.database return self.app.database
@property
def template(self) -> Template:
return self.app['template']
async def get_api_data(self, async def get_api_data(self,
required: list[str], required: list[str],
optional: list[str]) -> dict[str, str] | Response: optional: list[str]) -> dict[str, str] | Response:

View file

@ -2,41 +2,51 @@ from __future__ import annotations
import typing import typing
from aiohttp import web
from argon2.exceptions import VerifyMismatchError
from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
from ..misc import Response from ..database import CONFIG_DEFAULTS, THEMES
from ..logger import LogLevel
from ..misc import ACTOR_FORMATS, Message, Response
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from collections.abc import Coroutine
HOME_TEMPLATE = """ # pylint: disable=no-self-use
<html><head>
<title>ActivityPub Relay at {host}</title> UNAUTH_ROUTES = {
<style> '/',
p {{ color: #FFFFFF; font-family: monospace, arial; font-size: 100%; }} '/login'
body {{ background-color: #000000; }} }
a {{ color: #26F; }}
a:visited {{ color: #46C; }} CONFIG_IGNORE = (
a:hover {{ color: #8AF; }} 'schema-version',
</style> 'private-key'
</head> )
<body>
<p>This is an Activity Relay for fediverse instances.</p>
<p>{note}</p> @web.middleware
<p> async def handle_frontend_path(request: web.Request, handler: Coroutine) -> Response:
You may subscribe to this relay with the address: if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
<a href="https://{host}/actor">https://{host}/actor</a> request['token'] = request.cookies.get('user-token')
</p> request['user'] = None
<p>
To host your own relay, you may download the code at this address: if request['token']:
<a href="https://git.pleroma.social/pleroma/relay"> with request.app.database.session(False) as conn:
https://git.pleroma.social/pleroma/relay request['user'] = conn.get_user_by_token(request['token'])
</a>
</p> if request['user'] and request.path == '/login':
<br><p>List of {count} registered instances:<br>{targets}</p> return Response.new('', 302, {'Location': '/'})
</body></html>
""" if not request['user'] and request.path.startswith('/admin'):
return Response.new('', 302, {'Location': f'/login?redir={request.path}'})
return await handler(request)
# pylint: disable=unused-argument # pylint: disable=unused-argument
@ -45,14 +55,408 @@ HOME_TEMPLATE = """
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
config = conn.get_config_all() context = {
inboxes = tuple(conn.execute('SELECT * FROM inboxes').all()) 'instances': tuple(conn.execute('SELECT * FROM inboxes').all())
}
text = HOME_TEMPLATE.format( data = self.template.render('page/home.haml', self, **context)
host = self.config.domain, return Response.new(data, ctype='html')
note = config['note'],
count = len(inboxes),
targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
)
return Response.new(text, ctype='html')
@register_route('/login')
class Login(View):
async def get(self, request: Request) -> Response:
data = self.template.render('page/login.haml', self)
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
form = await request.post()
params = {}
with self.database.session(True) as conn:
if not (user := conn.get_user(form['username'])):
params = {
'username': form['username'],
'error': 'User not found'
}
else:
try:
conn.hasher.verify(user['hash'], form['password'])
except VerifyMismatchError:
params = {
'username': form['username'],
'error': 'Invalid password'
}
if params:
data = self.template.render('page/login.haml', self, **params)
return Response.new(data, ctype = 'html')
token = conn.put_token(user['username'])
resp = Response.new_redir(request.query.getone('redir', '/'))
resp.set_cookie(
'user-token',
token['code'],
max_age = 60 * 60 * 24 * 365,
domain = self.config.domain,
path = '/',
secure = True,
httponly = True,
samesite = 'Strict'
)
return resp
@register_route('/logout')
class Logout(View):
async def get(self, request: Request) -> Response:
with self.database.session(True) as conn:
conn.del_token(request['token'])
resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = self.config.domain, path = '/')
return resp
@register_route('/admin')
class Admin(View):
async def get(self, request: Request) -> Response:
return Response.new('', 302, {'Location': '/admin/instances'})
@register_route('/admin/instances')
class AdminInstances(View):
async def get(self,
request: Request,
error: str | None = None,
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'instances': tuple(conn.execute('SELECT * FROM inboxes').all())
}
if error:
context['error'] = error
if message:
context['message'] = message
data = self.template.render('page/admin-instances.haml', self, **context)
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
if not data.get('actor') and not data.get('domain'):
return await self.get(request, error = 'Missing actor and/or domain')
if not data.get('domain'):
data['domain'] = urlparse(data['actor']).netloc
if not data.get('software'):
nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
data['software'] = nodeinfo.sw_name
if not data.get('actor') and data['software'] in ACTOR_FORMATS:
data['actor'] = ACTOR_FORMATS[data['software']].format(domain = data['domain'])
if not data.get('inbox') and data['actor']:
actor = await self.client.get(data['actor'], sign_headers = True, loads = Message.parse)
data['inbox'] = actor.shared_inbox
with self.database.session(True) as conn:
conn.put_inbox(**data)
return await self.get(request, message = "Added new inbox")
@register_route('/admin/instances/delete/{domain}')
class AdminInstancesDelete(View):
async def get(self, request: Request, domain: str) -> Response:
with self.database.session() as conn:
if not conn.get_inbox(domain):
return await AdminInstances(request).get(request, message = 'Instance not found')
conn.del_inbox(domain)
return await AdminInstances(request).get(request, message = 'Removed instance')
@register_route('/admin/whitelist')
class AdminWhitelist(View):
async def get(self,
request: Request,
error: str | None = None,
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'whitelist': tuple(conn.execute('SELECT * FROM whitelist').all())
}
if error:
context['error'] = error
if message:
context['message'] = message
data = self.template.render('page/admin-whitelist.haml', self, **context)
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
if not data['domain']:
return await self.get(request, error = 'Missing domain')
with self.database.session(True) as conn:
if conn.get_domain_whitelist(data['domain']):
return await self.get(request, message = "Domain already in whitelist")
conn.put_domain_whitelist(data['domain'])
return await self.get(request, message = "Added/updated domain ban")
@register_route('/admin/whitelist/delete/{domain}')
class AdminWhitlistDelete(View):
async def get(self, request: Request, domain: str) -> Response:
with self.database.session() as conn:
if not conn.get_domain_whitelist(domain):
msg = 'Whitelisted domain not found'
return await AdminWhitelist.run("GET", request, message = msg)
conn.del_domain_whitelist(domain)
return await AdminWhitelist.run("GET", request, message = 'Removed domain from whitelist')
@register_route('/admin/domain_bans')
class AdminDomainBans(View):
async def get(self,
request: Request,
error: str | None = None,
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC').all())
}
if error:
context['error'] = error
if message:
context['message'] = message
data = self.template.render('page/admin-domain_bans.haml', self, **context)
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
if not data['domain']:
return await self.get(request, error = 'Missing domain')
with self.database.session(True) as conn:
if conn.get_domain_ban(data['domain']):
conn.update_domain_ban(
data['domain'],
data.get('reason'),
data.get('note')
)
else:
conn.put_domain_ban(
data['domain'],
data.get('reason'),
data.get('note')
)
return await self.get(request, message = "Added/updated domain ban")
@register_route('/admin/domain_bans/delete/{domain}')
class AdminDomainBansDelete(View):
async def get(self, request: Request, domain: str) -> Response:
with self.database.session() as conn:
if not conn.get_domain_ban(domain):
return await AdminDomainBans.run("GET", request, message = 'Domain ban not found')
conn.del_domain_ban(domain)
return await AdminDomainBans.run("GET", request, message = 'Unbanned domain')
@register_route('/admin/software_bans')
class AdminSoftwareBans(View):
async def get(self,
request: Request,
error: str | None = None,
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC').all())
}
if error:
context['error'] = error
if message:
context['message'] = message
data = self.template.render('page/admin-software_bans.haml', self, **context)
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
if not data['name']:
return await self.get(request, error = 'Missing name')
with self.database.session(True) as conn:
if conn.get_software_ban(data['name']):
conn.update_software_ban(
data['name'],
data.get('reason'),
data.get('note')
)
else:
conn.put_software_ban(
data['name'],
data.get('reason'),
data.get('note')
)
return await self.get(request, message = "Added/updated software ban")
@register_route('/admin/software_bans/delete/{name}')
class AdminSoftwareBansDelete(View):
async def get(self, request: Request, name: str) -> Response:
with self.database.session() as conn:
if not conn.get_software_ban(name):
return await AdminSoftwareBans.run("GET", request, message = 'Software ban not found')
conn.del_software_ban(name)
return await AdminSoftwareBans.run("GET", request, message = 'Unbanned software')
@register_route('/admin/users')
class AdminUsers(View):
async def get(self,
request: Request,
error: str | None = None,
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'users': tuple(conn.execute('SELECT * FROM users').all())
}
if error:
context['error'] = error
if message:
context['message'] = message
data = self.template.render('page/admin-users.haml', self, **context)
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
required_fields = {'username', 'password', 'password2'}
if not all(data.get(field) for field in required_fields):
return await self.get(request, error = 'Missing username and/or password')
if data['password'] != data['password2']:
return await self.get(request, error = 'Passwords do not match')
with self.database.session(True) as conn:
if conn.get_user(data['username']):
return await self.get(request, message = "User already exists")
conn.put_user(data['username'], data['password'], data['handle'])
return await self.get(request, message = "Added user")
@register_route('/admin/users/delete/{name}')
class AdminUsersDelete(View):
async def get(self, request: Request, name: str) -> Response:
with self.database.session() as conn:
if not conn.get_user(name):
return await AdminUsers.run("GET", request, message = 'User not found')
conn.del_user(name)
return await AdminUsers.run("GET", request, message = 'User deleted')
@register_route('/admin/config')
class AdminConfig(View):
async def get(self, request: Request, message: str | None = None) -> Response:
context = {
'themes': tuple(THEMES.keys()),
'LogLevel': LogLevel,
'message': message
}
data = self.template.render('page/admin-config.haml', self, **context)
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
form = dict(await request.post())
with self.database.session(True) as conn:
for key in CONFIG_DEFAULTS:
value = form.get(key)
if key == 'whitelist-enabled':
value = bool(value)
elif key.lower() in CONFIG_IGNORE:
continue
if value is None:
continue
conn.put_config(key, value)
return await self.get(request, message = 'Updated config')
@register_route('/style.css')
class StyleCss(View):
async def get(self, request: Request) -> Response:
data = self.template.render('style.css', self)
return Response.new(data, ctype = 'css')
@register_route('/theme/{theme}.css')
class ThemeCss(View):
async def get(self, request: Request, theme: str) -> Response:
try:
context = {
'theme': THEMES[theme]
}
except KeyError:
return Response.new('Invalid theme', 404)
data = self.template.render('variables.css', self, **context)
return Response.new(data, ctype = 'css')

View file

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import aputils
import subprocess import subprocess
import typing import typing
from aputils.objects import Nodeinfo, WellKnownNodeinfo
from pathlib import Path from pathlib import Path
from .base import View, register_route from .base import View, register_route
@ -48,12 +48,12 @@ class NodeinfoView(View):
if niversion == '2.1': if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay' data['repo'] = 'https://git.pleroma.social/pleroma/relay'
return Response.new(Nodeinfo.new(**data), ctype = 'json') return Response.new(aputils.Nodeinfo.new(**data), ctype = 'json')
@register_route('/.well-known/nodeinfo') @register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View): class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain) data = aputils.WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')

View file

@ -1,12 +1,13 @@
aiohttp>=3.9.1 aiohttp>=3.9.1
aiohttp-swagger[performance]==1.0.16 aiohttp-swagger[performance]==1.0.16
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.7.tar.gz
argon2-cffi==23.1.0 argon2-cffi==23.1.0
barkshark-sql@https://git.barkshark.xyz/barkshark/bsql/archive/0.1.2.tar.gz
click>=8.1.2 click>=8.1.2
gunicorn==21.1.0 hamlish-jinja@https://git.barkshark.xyz/barkshark/hamlish-jinja/archive/0.3.5.tar.gz
hiredis==2.3.2 hiredis==2.3.2
platformdirs==4.2.0
pyyaml>=6.0 pyyaml>=6.0
redis==5.0.1 redis==5.0.1
barkshark-sql@https://git.barkshark.xyz/barkshark/bsql/archive/0.1.1.tar.gz
importlib_resources==6.1.1;python_version<'3.9' importlib_resources==6.1.1;python_version<'3.9'

View file

@ -34,8 +34,9 @@ dev = file: dev-requirements.txt
[options.package_data] [options.package_data]
relay = relay =
data/statements.sql data/*
data/swagger.yaml frontend/*
frontend/page/*
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =