diff --git a/relay/application.py b/relay/application.py index 80efed9..20695d7 100644 --- a/relay/application.py +++ b/relay/application.py @@ -20,6 +20,7 @@ from .database import get_database from .http_client import HttpClient from .misc import check_open_port from .views import VIEWS +from .views.api import handle_api_path if typing.TYPE_CHECKING: from collections.abc import Awaitable @@ -35,7 +36,11 @@ class Application(web.Application): DEFAULT: Application = None def __init__(self, cfgpath: str, gunicorn: bool = False): - web.Application.__init__(self) + web.Application.__init__(self, + middlewares = [ + handle_api_path + ] + ) Application.DEFAULT = self @@ -219,6 +224,6 @@ async def main_gunicorn(): except KeyError: logging.error('Failed to set "CONFIG_FILE" environment. Trying to run without gunicorn?') - raise + raise RuntimeError from None return app diff --git a/relay/data/statements.sql b/relay/data/statements.sql index 9bcea41..2ddef35 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -1,82 +1,127 @@ -- name: get-config -SELECT * FROM config WHERE key = :key +SELECT * FROM config WHERE key = :key; -- name: get-config-all -SELECT * FROM config +SELECT * FROM config; -- name: put-config INSERT INTO config (key, value, type) VALUES (:key, :value, :type) ON CONFLICT (key) DO UPDATE SET value = :value -RETURNING * +RETURNING *; -- name: del-config DELETE FROM config -WHERE key = :key +WHERE key = :key; -- name: get-inbox -SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value +SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value; -- name: put-inbox INSERT INTO inboxes (domain, actor, inbox, followid, software, created) VALUES (:domain, :actor, :inbox, :followid, :software, :created) ON CONFLICT (domain) DO UPDATE SET followid = :followid -RETURNING * +RETURNING *; -- name: del-inbox DELETE FROM inboxes -WHERE domain = :value or inbox = :value or actor = :value +WHERE domain = :value or inbox = :value or actor = :value; + + +-- name: get-user +SELECT * FROM users +WHERE username = :value or handle = :value; + + +-- name: get-user-by-token +SELECT * FROM users +WHERE username = ( + SELECT user FROM tokens + WHERE code = :code +); + + +-- name: put-user +INSERT INTO users (username, hash, handle, created) +VALUES (:username, :hash, :handle, :created) +RETURNING *; + + +-- name: del-user +DELETE FROM users +WHERE username = :value or handle = :value; + + +-- name: get-token +SELECT * FROM tokens +WHERE code = :code; + + +-- name: put-token +INSERT INTO tokens (code, user, created) +VALUES (:code, :user, :created) +RETURNING *; + + +-- name: del-token +DELETE FROM tokens +WHERE code = :code; + + +-- name: del-token-user +DELETE FROM tokens +WHERE user = :username; -- name: get-software-ban -SELECT * FROM software_bans WHERE name = :name +SELECT * FROM software_bans WHERE name = :name; -- name: put-software-ban INSERT INTO software_bans (name, reason, note, created) VALUES (:name, :reason, :note, :created) -RETURNING * +RETURNING *; -- name: del-software-ban DELETE FROM software_bans -WHERE name = :name +WHERE name = :name; -- name: get-domain-ban -SELECT * FROM domain_bans WHERE domain = :domain +SELECT * FROM domain_bans WHERE domain = :domain; -- name: put-domain-ban INSERT INTO domain_bans (domain, reason, note, created) VALUES (:domain, :reason, :note, :created) -RETURNING * +RETURNING *; -- name: del-domain-ban DELETE FROM domain_bans -WHERE domain = :domain +WHERE domain = :domain; -- name: get-domain-whitelist -SELECT * FROM whitelist WHERE domain = :domain +SELECT * FROM whitelist WHERE domain = :domain; -- name: put-domain-whitelist INSERT INTO whitelist (domain, created) VALUES (:domain, :created) -RETURNING * +RETURNING *; -- name: del-domain-whitelist DELETE FROM whitelist -WHERE domain = :domain +WHERE domain = :domain; -- cache functions -- @@ -90,7 +135,7 @@ CREATE TABLE IF NOT EXISTS cache ( type TEXT DEFAULT 'str', updated TIMESTAMP NOT NULL, UNIQUE(namespace, key) -) +); -- name: create-cache-table-postgres CREATE TABLE IF NOT EXISTS cache ( @@ -101,21 +146,21 @@ CREATE TABLE IF NOT EXISTS cache ( type TEXT DEFAULT 'str', updated TIMESTAMP NOT NULL, UNIQUE(namespace, key) -) +); -- name: get-cache-item SELECT * FROM cache -WHERE namespace = :namespace and key = :key +WHERE namespace = :namespace and key = :key; -- name: get-cache-keys SELECT key FROM cache -WHERE namespace = :namespace +WHERE namespace = :namespace; -- name: get-cache-namespaces -SELECT DISTINCT namespace FROM cache +SELECT DISTINCT namespace FROM cache; -- name: set-cache-item @@ -123,18 +168,18 @@ INSERT INTO cache (namespace, key, value, type, updated) VALUES (:namespace, :key, :value, :type, :date) ON CONFLICT (namespace, key) DO UPDATE SET value = :value, type = :type, updated = :date -RETURNING * +RETURNING *; -- name: del-cache-item DELETE FROM cache -WHERE namespace = :namespace and key = :key +WHERE namespace = :namespace and key = :key; -- name: del-cache-namespace DELETE FROM cache -WHERE namespace = :namespace +WHERE namespace = :namespace; -- name: del-cache-all -DELETE FROM cache +DELETE FROM cache; diff --git a/relay/database/__init__.py b/relay/database/__init__.py index c403092..facea97 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -52,14 +52,10 @@ def get_database(config: Config, migrate: bool = True) -> tinysql.Database: if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'): logging.info("Migrating database from version '%i'", schema_ver) - for ver, func in VERSIONS: + for ver, func in VERSIONS.items(): if schema_ver < ver: - conn.begin() - func(conn) - conn.put_config('schema-version', ver) - conn.commit() if (privkey := conn.get_config('private-key')): conn.app.signer = privkey diff --git a/relay/database/config.py b/relay/database/config.py index 8fe3b4c..b69f13e 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -11,8 +11,9 @@ if typing.TYPE_CHECKING: CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = { - 'schema-version': ('int', 20240119), + 'schema-version': ('int', 20240206), 'log-level': ('loglevel', logging.LogLevel.INFO), + 'name': ('str', 'ActivityRelay'), 'note': ('str', 'Make a note about your instance here.'), 'private-key': ('str', None), 'whitelist-enabled': ('bool', False) diff --git a/relay/database/connection.py b/relay/database/connection.py index ac5a364..718861d 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -3,8 +3,10 @@ from __future__ import annotations import tinysql import typing +from argon2 import PasswordHasher from datetime import datetime, timezone from urllib.parse import urlparse +from uuid import uuid4 from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize @@ -28,6 +30,10 @@ RELAY_SOFTWARE = [ class Connection(tinysql.Connection): + hasher = PasswordHasher( + encoding = 'utf-8' + ) + @property def app(self) -> Application: return get_app() @@ -162,6 +168,59 @@ class Connection(tinysql.Connection): return cur.modified_row_count == 1 + def get_user(self, value: str) -> Row: + with self.exec_statement('get-user', {'value': value}) as cur: + return cur.one() + + + def get_user_by_token(self, code: str) -> Row: + with self.exec_statement('get-user-by-token', {'code': code}) as cur: + return cur.one() + + + def put_user(self, username: str, password: str, handle: str | None = None) -> Row: + data = { + 'username': username, + 'hash': self.hasher.hash(password), + 'handle': handle, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-user', data) as cur: + return cur.one() + + + def del_user(self, username: str) -> None: + user = self.get_user(username) + + with self.exec_statement('del-user', {'value': user['username']}): + pass + + with self.exec_statement('del-token-user', {'username': user['username']}): + pass + + + def get_token(self, code: str) -> Row: + with self.exec_statement('get-token', {'code': code}) as cur: + return cur.one() + + + def put_token(self, username: str) -> Row: + data = { + 'code': uuid4().hex, + 'user': username, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-token', data) as cur: + return cur.one() + + + def del_token(self, code: str) -> None: + with self.exec_statement('del-token', {'code': code}): + pass + + def get_domain_ban(self, domain: str) -> Row: if domain.startswith('http'): domain = urlparse(domain).netloc diff --git a/relay/database/schema.py b/relay/database/schema.py index 7b6d927..d4c51a4 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -10,7 +10,7 @@ if typing.TYPE_CHECKING: from collections.abc import Callable -VERSIONS: list[Callable] = [] +VERSIONS: dict[int, Callable] = {} TABLES: list[Table] = [ Table( 'config', @@ -45,12 +45,25 @@ TABLES: list[Table] = [ Column('reason', 'text'), Column('note', 'text'), Column('created', 'timestamp', nullable = False) + ), + Table( + 'users', + Column('username', 'text', primary_key = True, unique = True, nullable = False), + Column('hash', 'text', nullable = False), + Column('handle', 'text'), + Column('created', 'timestamp', nullable = False) + ), + Table( + 'tokens', + Column('code', 'text', primary_key = True, unique = True, nullable = False), + Column('user', 'text', nullable = False), + Column('created', 'timestmap', nullable = False) ) ] -def version(func: Callable) -> Callable: - ver = int(func.replace('migrate_', '')) +def migration(func: Callable) -> Callable: + ver = int(func.__name__.replace('migrate_', '')) VERSIONS[ver] = func return func @@ -58,3 +71,8 @@ def version(func: Callable) -> Callable: def migrate_0(conn: Connection) -> None: conn.create_tables(TABLES) conn.put_config('schema-version', get_default_value('schema-version')) + + +@migration +def migrate_20240206(conn: Connection) -> None: + conn.create_tables(TABLES) diff --git a/relay/manage.py b/relay/manage.py index 2a78b6f..e8aab1b 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -370,6 +370,112 @@ def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: print(f'{key}: {repr(new_value)}') +@cli.group('user') +def cli_user() -> None: + 'Manage local users' + + +@cli_user.command('list') +@click.pass_context +def cli_user_list(ctx: click.Context) -> None: + 'List all local users' + + click.echo('Users:') + + with ctx.obj.database.connection() as conn: + for user in conn.execute('SELECT * FROM users'): + click.echo(f'- {user["username"]}') + + +@cli_user.command('create') +@click.argument('username') +@click.argument('handle', required = False) +@click.pass_context +def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: + 'Create a new local user' + + with ctx.obj.database.connection() as conn: + if conn.get_user(username): + click.echo(f'User already exists: {username}') + return + + while True: + if not (password := click.prompt('New password', hide_input = True)): + click.echo('No password provided') + continue + + if password != click.prompt('New password again', hide_input = True): + click.echo('Passwords do not match') + continue + + break + + conn.put_user(username, password, handle) + + click.echo(f'Created user "{username}"') + + +@cli_user.command('delete') +@click.argument('username') +@click.pass_context +def cli_user_delete(ctx: click.Context, username: str) -> None: + 'Delete a local user' + + with ctx.obj.database.connection() as conn: + if not conn.get_user(username): + click.echo(f'User does not exist: {username}') + return + + conn.del_user(username) + + click.echo(f'Deleted user "{username}"') + + +@cli_user.command('list-tokens') +@click.argument('username') +@click.pass_context +def cli_user_list_tokens(ctx: click.Context, username: str) -> None: + 'List all API tokens for a user' + + click.echo(f'Tokens for "{username}":') + + with ctx.obj.database.connection() as conn: + for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): + click.echo(f'- {token["code"]}') + + +@cli_user.command('create-token') +@click.argument('username') +@click.pass_context +def cli_user_create_token(ctx: click.Context, username: str) -> None: + 'Create a new API token for a user' + + with ctx.obj.database.connection() as conn: + if not (user := conn.get_user(username)): + click.echo(f'User does not exist: {username}') + return + + token = conn.put_token(user['username']) + + click.echo(f'New token for "{username}": {token["code"]}') + + +@cli_user.command('delete-token') +@click.argument('code') +@click.pass_context +def cli_user_delete_token(ctx: click.Context, code: str) -> None: + 'Delete an API token' + + with ctx.obj.database.connection() as conn: + if not conn.get_token(code): + click.echo('Token does not exist') + return + + conn.del_token(code) + + click.echo('Deleted token') + + @cli.group('inbox') def cli_inbox() -> None: 'Manage the inboxes in the database' diff --git a/relay/misc.py b/relay/misc.py index e71845d..bb78957 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -5,12 +5,8 @@ import os import socket import typing -from aiohttp.abc import AbstractView -from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.web import Response as AiohttpResponse -from aiohttp.web_exceptions import HTTPMethodNotAllowed from aputils.message import Message as ApMessage -from functools import cached_property from uuid import uuid4 if typing.TYPE_CHECKING: @@ -203,7 +199,7 @@ class Response(AiohttpResponse): if isinstance(body, bytes): kwargs['body'] = body - elif isinstance(body, dict) and ctype in {'json', 'activity'}: + elif isinstance(body, (dict, list, tuple, set)) and ctype in {'json', 'activity'}: kwargs['text'] = json.dumps(body) else: @@ -232,67 +228,3 @@ class Response(AiohttpResponse): @location.setter def location(self, value: str) -> None: self.headers['Location'] = value - - -class View(AbstractView): - def __await__(self) -> Generator[Response]: - if self.request.method not in METHODS: - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) - - if not (handler := self.handlers.get(self.request.method)): - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None - - return self._run_handler(handler).__await__() - - - async def _run_handler(self, handler: Awaitable) -> Response: - with self.database.config.connection_class(self.database) as conn: - # todo: remove on next tinysql release - conn.open() - - return await handler(self.request, conn, **self.request.match_info) - - - @cached_property - def allowed_methods(self) -> tuple[str]: - return tuple(self.handlers.keys()) - - - @cached_property - def handlers(self) -> dict[str, Coroutine]: - data = {} - - for method in METHODS: - try: - data[method] = getattr(self, method.lower()) - - except AttributeError: - continue - - return data - - - # app components - @property - def app(self) -> Application: - return self.request.app - - - @property - def cache(self) -> Cache: - return self.app.cache - - - @property - def client(self) -> HttpClient: - return self.app.client - - - @property - def config(self) -> Config: - return self.app.config - - - @property - def database(self) -> Database: - return self.app.database diff --git a/relay/processors.py b/relay/processors.py index d9780d1..7fc3423 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -170,7 +170,7 @@ processors = { } -async def run_processor(view: ActorView, conn: Connection) -> None: +async def run_processor(view: ActorView) -> None: if view.message.type not in processors: logging.verbose( 'Message type "%s" from actor cannot be handled: %s', @@ -180,8 +180,8 @@ async def run_processor(view: ActorView, conn: Connection) -> None: return - if view.instance: - with conn.transaction(): + with view.database.connection(False) as conn: + if view.instance: if not view.instance['software']: if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): view.instance = conn.update_inbox( @@ -195,5 +195,5 @@ async def run_processor(view: ActorView, conn: Connection) -> None: actor = view.actor.id ) - logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) - await processors[view.message.type](view, conn) + logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) + await processors[view.message.type](view, conn) diff --git a/relay/views/__init__.py b/relay/views/__init__.py new file mode 100644 index 0000000..6366592 --- /dev/null +++ b/relay/views/__init__.py @@ -0,0 +1,4 @@ +from __future__ import annotations + +from . import activitypub, api, frontend, misc +from .base import VIEWS diff --git a/relay/views.py b/relay/views/activitypub.py similarity index 51% rename from relay/views.py rename to relay/views/activitypub.py index cb648a2..70f759a 100644 --- a/relay/views.py +++ b/relay/views/activitypub.py @@ -1,95 +1,27 @@ from __future__ import annotations -import subprocess import traceback import typing from aputils.errors import SignatureFailureError from aputils.misc import Digest, HttpDate, Signature -from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo -from pathlib import Path +from aputils.objects import Webfinger -from . import __version__ -from . import logger as logging -from .database.connection import Connection -from .misc import Message, Response, View -from .processors import run_processor +from .base import View, register_route + +from .. import logger as logging +from ..misc import Message, Response +from ..processors import run_processor if typing.TYPE_CHECKING: from aiohttp.web import Request from aputils.signer import Signer - from collections.abc import Callable from tinysql import Row - - -VIEWS = [] -VERSION = __version__ -HOME_TEMPLATE = """ - - ActivityPub Relay at {host} - - - -

This is an Activity Relay for fediverse instances.

-

{note}

-

- You may subscribe to this relay with the address: - https://{host}/actor -

-

- To host your own relay, you may download the code at this address: - - https://git.pleroma.social/pleroma/relay - -

-

List of {count} registered instances:
{targets}

- -""" - - -if Path(__file__).parent.parent.joinpath('.git').exists(): - try: - commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii') - VERSION = f'{__version__} {commit_label}' - - except Exception: - pass - - -def register_route(*paths: str) -> Callable: - def wrapper(view: View) -> View: - for path in paths: - VIEWS.append([path, view]) - - return View - return wrapper + from ..database.connection import Connection # pylint: disable=unused-argument -@register_route('/') -class HomeView(View): - async def get(self, request: Request, conn: Connection) -> Response: - config = conn.get_config_all() - inboxes = conn.execute('SELECT * FROM inboxes').all() - - text = HOME_TEMPLATE.format( - host = self.config.domain, - note = config['note'], - count = len(inboxes), - targets = '
'.join(inbox['domain'] for inbox in inboxes) - ) - - return Response.new(text, ctype='html') - - - @register_route('/actor', '/inbox') class ActorView(View): def __init__(self, request: Request): @@ -102,7 +34,7 @@ class ActorView(View): self.signer: Signer = None - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: data = Message.new_actor( host = self.config.domain, pubkey = self.app.signer.pubkey @@ -111,35 +43,36 @@ class ActorView(View): return Response.new(data, ctype='activity') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: if response := await self.get_post_data(): return response - self.instance = conn.get_inbox(self.actor.shared_inbox) - config = conn.get_config_all() + with self.database.connection(False) as conn: + self.instance = conn.get_inbox(self.actor.shared_inbox) + config = conn.get_config_all() - ## reject if the actor isn't whitelisted while the whiltelist is enabled - if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): - logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if the actor isn't whitelisted while the whiltelist is enabled + if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): + logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if actor is banned - if conn.get_domain_ban(self.actor.domain): - logging.verbose('Ignored request from banned actor: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if actor is banned + if conn.get_domain_ban(self.actor.domain): + logging.verbose('Ignored request from banned actor: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if activity type isn't 'Follow' and the actor isn't following - if self.message.type != 'Follow' and not self.instance: - logging.verbose( - 'Rejected actor for trying to post while not following: %s', - self.actor.id - ) + ## reject if activity type isn't 'Follow' and the actor isn't following + if self.message.type != 'Follow' and not self.instance: + logging.verbose( + 'Rejected actor for trying to post while not following: %s', + self.actor.id + ) - return Response.new_error(401, 'access denied', 'json') + return Response.new_error(401, 'access denied', 'json') logging.debug('>> payload %s', self.message.to_json(4)) - await run_processor(self, conn) + await run_processor(self) return Response.new(status = 202) @@ -230,7 +163,7 @@ class ActorView(View): @register_route('/.well-known/webfinger') class WebfingerView(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: try: subject = request.query['resource'] @@ -247,31 +180,3 @@ class WebfingerView(View): ) return Response.new(data, ctype = 'json') - - -@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') -class NodeinfoView(View): - # pylint: disable=no-self-use - async def get(self, request: Request, conn: Connection, niversion: str) -> Response: - inboxes = conn.execute('SELECT * FROM inboxes').all() - - data = { - 'name': 'activityrelay', - 'version': VERSION, - 'protocols': ['activitypub'], - 'open_regs': not conn.get_config('whitelist-enabled'), - 'users': 1, - 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} - } - - if niversion == '2.1': - data['repo'] = 'https://git.pleroma.social/pleroma/relay' - - return Response.new(Nodeinfo.new(**data), ctype = 'json') - - -@register_route('/.well-known/nodeinfo') -class WellknownNodeinfoView(View): - async def get(self, request: Request, conn: Connection) -> Response: - data = WellKnownNodeinfo.new_template(self.config.domain) - return Response.new(data, ctype = 'json') diff --git a/relay/views/api.py b/relay/views/api.py new file mode 100644 index 0000000..ade1d3b --- /dev/null +++ b/relay/views/api.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +import typing + +from aiohttp import web +from argon2.exceptions import VerifyMismatchError +from datetime import datetime, timezone +from urllib.parse import urlparse + +from .base import View, register_route + +from .. import __version__ +from .. import logger as logging +from ..database.config import CONFIG_DEFAULTS +from ..misc import Message, Response + +if typing.TYPE_CHECKING: + from aiohttp.web import Request + from collections.abc import Coroutine + from ..database.connection import Connection + + +CONFIG_IGNORE = ( + 'schema-version', + 'private-key' +) + +CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE} + +PUBLIC_API_PATHS: tuple[tuple[str, str]] = ( + ('GET', '/api/v1/relay'), + ('GET', '/api/v1/instance'), + ('POST', '/api/v1/token') +) + + +def check_api_path(method: str, path: str) -> bool: + if (method, path) in PUBLIC_API_PATHS: + return False + + return path.startswith('/api') + + +@web.middleware +async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Response: + try: + request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() + + with request.app.database.connection() as conn: + request['user'] = conn.get_user_by_token(request['token']) + + except (KeyError, ValueError): + request['token'] = None + request['user'] = None + + if check_api_path(request.method, request.path): + if not request['token']: + return Response.new_error(401, 'Missing token', 'json') + + if not request['user']: + return Response.new_error(401, 'Invalid token', 'json') + + return await handler(request) + + +# pylint: disable=no-self-use,unused-argument + +@register_route('/api/v1/token') +class Login(View): + async def get(self, request: Request) -> Response: + return Response.new({'message': 'Token valid :3'}) + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['username', 'password'], []) + + if isinstance(data, Response): + return data + + with self.database.connction(True) as conn: + if not (user := conn.get_user(data['username'])): + return Response.new_error(401, 'User not found', 'json') + + try: + conn.hasher.verify(user['hash'], data['password']) + + except VerifyMismatchError: + return Response.new_error(401, 'Invalid password', 'json') + + token = conn.put_token(data['username']) + + return Response.new({'token': token['code']}, ctype = 'json') + + + async def delete(self, request: Request) -> Response: + with self.database.connection(True) as conn: + conn.del_token(request['token']) + + return Response.new({'message': 'Token revoked'}, ctype = 'json') + + +@register_route('/api/v1/relay') +class RelayInfo(View): + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + config = conn.get_config_all() + inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] + + data = { + 'domain': self.config.domain, + 'name': config['name'], + 'description': config['note'], + 'version': __version__, + 'whitelist_enabled': config['whitelist-enabled'], + 'email': None, + 'admin': None, + 'icon': None, + 'instances': inboxes + } + + return Response.new(data, ctype = 'json') + + +@register_route('/api/v1/config') +class Config(View): + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + data = conn.get_config_all() + data['log-level'] = data['log-level'].name + + for key in CONFIG_IGNORE: + del data[key] + + return Response.new(data, ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['key', 'value'], []) + + if isinstance(data, Response): + return data + + if data['key'] not in CONFIG_VALID: + return Response.new_error(400, 'Invalid key', 'json') + + with self.database.connection(True) as conn: + conn.put_config(data['key'], data['value']) + + return Response.new({'message': 'Updated config'}, ctype = 'json') + + + async def delete(self, request: Request) -> Response: + data = await self.get_api_data(['key'], []) + + if isinstance(data, Response): + return data + + if data['key'] not in CONFIG_VALID: + return Response.new_error(400, 'Invalid key', 'json') + + with self.database.connection(True) as conn: + conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) + + return Response.new({'message': 'Updated config'}, ctype = 'json') + + +@register_route('/api/v1/instance') +class Inbox(View): + async def get(self, request: Request) -> Response: + data = [] + + with self.database.connection(False) as conn: + for inbox in conn.execute('SELECT * FROM inboxes'): + try: + created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) + + except TypeError: + created = datetime.fromisoformat(inbox['created']) + + inbox['created'] = created.isoformat() + data.append(inbox) + + return Response.new(data, ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) + + if isinstance(data, Response): + return data + + data['domain'] = urlparse(data["actor"]).netloc + + with self.database.connection(True) as conn: + if conn.get_inbox(data['domain']): + return Response.new_error(404, 'Instance already in database', 'json') + + if not data.get('inbox'): + try: + actor_data = await self.client.get( + data['actor'], + sign_headers = True, + loads = Message.parse + ) + + data['inbox'] = actor_data.shared_inbox + + except Exception as e: + logging.error('Failed to fetch actor: %s', str(e)) + return Response.new_error(500, 'Failed to fetch actor', 'json') + + row = conn.put_inbox(**data) + + return Response.new(row, ctype = 'json') + + +@register_route('/api/v1/instance/{domain}') +class InboxSingle(View): + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (row := conn.get_inbox(domain)): + return Response.new_error(404, 'Instance with domain not found', 'json') + + row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat() + return Response.new(row, ctype = 'json') + + + async def patch(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Instance with domain not found', 'json') + + data = await self.get_api_data([], ['actor', 'software', 'followid']) + + if isinstance(data, Response): + return data + + if not (instance := conn.get_inbox(domain)): + return Response.new_error(404, 'Instance with domain not found', 'json') + + instance = conn.update_inbox(instance['inbox'], **data) + + return Response.new(instance, ctype = 'json') + + + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Instance with domain not found', 'json') + + conn.del_inbox(domain) + + return Response.new({'message': 'Deleted instance'}, ctype = 'json') + + +@register_route('/api/v1/domain_ban') +class DomainBan(View): + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + bans = conn.execute('SELECT * FROM domain_bans').all() + + return Response.new(bans, ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['domain'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + with self.database.connection(True) as conn: + if conn.get_domain_ban(data['domain']): + return Response.new_error(400, 'Domain already banned', 'json') + + ban = conn.put_domain_ban(**data) + + return Response.new(ban, ctype = 'json') + + +@register_route('/api/v1/domain_ban/{domain}') +class DomainBanSingle(View): + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (ban := conn.get_domain_ban(domain)): + return Response.new_error(404, 'Domain ban not found', 'json') + + return Response.new(ban, ctype = 'json') + + + async def patch(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') + + data = await self.get_api_data([], ['note', 'reason']) + + if isinstance(data, Response): + return data + + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + + ban = conn.update_domain_ban(domain, **data) + + return Response.new(ban, ctype = 'json') + + + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') + + conn.del_domain_ban(domain) + + return Response.new({'message': 'Unbanned domain'}, ctype = 'json') + + +@register_route('/api/v1/software_ban') +class SoftwareBan(View): + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + bans = conn.execute('SELECT * FROM software_bans').all() + + return Response.new(bans, ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['name'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + with self.database.connection(True) as conn: + if conn.get_software_ban(data['name']): + return Response.new_error(400, 'Domain already banned', 'json') + + ban = conn.put_software_ban(**data) + + return Response.new(ban, ctype = 'json') + + +@register_route('/api/v1/software_ban/{name}') +class SoftwareBanSingle(View): + async def get(self, request: Request, name: str) -> Response: + with self.database.connection(False) as conn: + if not (ban := conn.get_software_ban(name)): + return Response.new_error(404, 'Software ban not found', 'json') + + return Response.new(ban, ctype = 'json') + + + async def patch(self, request: Request, name: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_software_ban(name): + return Response.new_error(404, 'Software not banned', 'json') + + data = await self.get_api_data([], ['note', 'reason']) + + if isinstance(data, Response): + return data + + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + + ban = conn.update_software_ban(name, **data) + + return Response.new(ban, ctype = 'json') + + + async def delete(self, request: Request, name: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_software_ban(name): + return Response.new_error(404, 'Software not banned', 'json') + + conn.del_software_ban(name) + + return Response.new({'message': 'Unbanned software'}, ctype = 'json') + + +@register_route('/api/v1/whitelist') +class Whitelist(View): + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + items = conn.execute('SELECT * FROM whitelist').all() + + return Response.new(items, ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['domain'], []) + + if isinstance(data, Response): + return data + + with self.database.connection(True) as conn: + if conn.get_domain_whitelist(data['domain']): + return Response.new_error(400, 'Domain already added to whitelist', 'json') + + item = conn.put_domain_whitelist(**data) + + return Response.new(item, ctype = 'json') + + +@register_route('/api/v1/domain/{domain}') +class WhitelistSingle(View): + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (item := conn.get_domain_whitelist(domain)): + return Response.new_error(404, 'Domain not in whitelist', 'json') + + return Response.new(item, ctype = 'json') + + + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not conn.get_domain_whitelist(domain): + return Response.new_error(404, 'Domain not in whitelist', 'json') + + conn.del_domain_whitelist(domain) + + return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py new file mode 100644 index 0000000..ce72e4b --- /dev/null +++ b/relay/views/base.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import typing + +from aiohttp.abc import AbstractView +from aiohttp.hdrs import METH_ALL as METHODS +from aiohttp.web import HTTPMethodNotAllowed +from functools import cached_property +from json.decoder import JSONDecodeError + +from ..misc import Response + +if typing.TYPE_CHECKING: + from collections.abc import Callable, Coroutine, Generator + from tinysql import Database + from ..application import Application + from ..cache import Cache + from ..config import Config + from ..http_client import HttpClient + + +VIEWS = [] + + +def register_route(*paths: str) -> Callable: + def wrapper(view: View) -> View: + for path in paths: + VIEWS.append([path, view]) + + return View + return wrapper + + +class View(AbstractView): + def __await__(self) -> Generator[Response]: + if self.request.method not in METHODS: + raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + + if not (handler := self.handlers.get(self.request.method)): + raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + + return self._run_handler(handler).__await__() + + + async def _run_handler(self, handler: Coroutine) -> Response: + return await handler(self.request, **self.request.match_info) + + + @cached_property + def allowed_methods(self) -> tuple[str]: + return tuple(self.handlers.keys()) + + + @cached_property + def handlers(self) -> dict[str, Coroutine]: + data = {} + + for method in METHODS: + try: + data[method] = getattr(self, method.lower()) + + except AttributeError: + continue + + return data + + + # app components + @property + def app(self) -> Application: + return self.request.app + + + @property + def cache(self) -> Cache: + return self.app.cache + + + @property + def client(self) -> HttpClient: + return self.app.client + + + @property + def config(self) -> Config: + return self.app.config + + + @property + def database(self) -> Database: + return self.app.database + + + async def get_api_data(self, + required: list[str], + optional: list[str]) -> dict[str, str] | Response: + + if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: + post_data = await self.request.post() + + elif self.request.content_type == 'application/json': + try: + post_data = await self.request.json() + + except JSONDecodeError: + return Response.new_error(400, 'Invalid JSON data', 'json') + + else: + post_data = self.request.query + + data = {} + + try: + for key in required: + data[key] = post_data[key] + + except KeyError as e: + return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') + + for key in optional: + data[key] = post_data.get(key) + + return data diff --git a/relay/views/frontend.py b/relay/views/frontend.py new file mode 100644 index 0000000..663edd4 --- /dev/null +++ b/relay/views/frontend.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import typing + +from .base import View, register_route + +from .. import __version__ +from ..misc import Response + +if typing.TYPE_CHECKING: + from aiohttp.web import Request + from aputils.signer import Signer + from collections.abc import Callable + from tinysql import Row + from ..database.connection import Connection + + +HOME_TEMPLATE = """ + + ActivityPub Relay at {host} + + + +

This is an Activity Relay for fediverse instances.

+

{note}

+

+ You may subscribe to this relay with the address: + https://{host}/actor +

+

+ To host your own relay, you may download the code at this address: + + https://git.pleroma.social/pleroma/relay + +

+

List of {count} registered instances:
{targets}

+ +""" + + +# pylint: disable=unused-argument + +@register_route('/') +class HomeView(View): + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + config = conn.get_config_all() + inboxes = conn.execute('SELECT * FROM inboxes').all() + + text = HOME_TEMPLATE.format( + host = self.config.domain, + note = config['note'], + count = len(inboxes), + targets = '
'.join(inbox['domain'] for inbox in inboxes) + ) + + return Response.new(text, ctype='html') diff --git a/relay/views/misc.py b/relay/views/misc.py new file mode 100644 index 0000000..7c2e65c --- /dev/null +++ b/relay/views/misc.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import subprocess +import typing + +from aputils.objects import Nodeinfo, WellKnownNodeinfo +from pathlib import Path + +from .base import View, register_route + +from .. import __version__ +from ..misc import Response + +if typing.TYPE_CHECKING: + from aiohttp.web import Request + from ..database.connection import Connection + + +VERSION = __version__ + + +if Path(__file__).parent.parent.joinpath('.git').exists(): + try: + commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii') + VERSION = f'{__version__} {commit_label}' + + except Exception: + pass + + +# pylint: disable=unused-argument + +@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') +class NodeinfoView(View): + # pylint: disable=no-self-use + async def get(self, request: Request, niversion: str) -> Response: + with self.database.connection(False) as conn: + inboxes = conn.execute('SELECT * FROM inboxes').all() + + data = { + 'name': 'activityrelay', + 'version': VERSION, + 'protocols': ['activitypub'], + 'open_regs': not conn.get_config('whitelist-enabled'), + 'users': 1, + 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} + } + + if niversion == '2.1': + data['repo'] = 'https://git.pleroma.social/pleroma/relay' + + return Response.new(Nodeinfo.new(**data), ctype = 'json') + + +@register_route('/.well-known/nodeinfo') +class WellknownNodeinfoView(View): + async def get(self, request: Request) -> Response: + data = WellKnownNodeinfo.new_template(self.config.domain) + + return Response.new(data, ctype = 'json') diff --git a/requirements.txt b/requirements.txt index 5239cb8..16bd9dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ aiohttp>=3.9.1 aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz +argon2-cffi==23.1.0 click>=8.1.2 gunicorn==21.1.0 hiredis==2.3.2 pyyaml>=6.0 redis==5.0.1 -tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz +tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/f8db814084dded0a46bd3a9576e09fca860f2166.tar.gz importlib_resources==6.1.1;python_version<'3.9'