From d10c864a00a46ee4e0ad626ae6730da285e9e619 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Fri, 9 Feb 2024 16:25:34 -0500 Subject: [PATCH] add public api views --- relay/application.py | 9 +- relay/data/statements.sql | 63 ++++++- relay/database/__init__.py | 6 +- relay/database/config.py | 3 +- relay/database/connection.py | 59 ++++++ relay/database/schema.py | 24 ++- relay/manage.py | 110 ++++++++++++ relay/misc.py | 2 +- relay/views/__init__.py | 2 +- relay/views/api.py | 338 +++++++++++++++++++++++++++++++++++ relay/views/base.py | 33 +++- requirements.txt | 1 + 12 files changed, 627 insertions(+), 23 deletions(-) create mode 100644 relay/views/api.py diff --git a/relay/application.py b/relay/application.py index 80efed9..20695d7 100644 --- a/relay/application.py +++ b/relay/application.py @@ -20,6 +20,7 @@ from .database import get_database from .http_client import HttpClient from .misc import check_open_port from .views import VIEWS +from .views.api import handle_api_path if typing.TYPE_CHECKING: from collections.abc import Awaitable @@ -35,7 +36,11 @@ class Application(web.Application): DEFAULT: Application = None def __init__(self, cfgpath: str, gunicorn: bool = False): - web.Application.__init__(self) + web.Application.__init__(self, + middlewares = [ + handle_api_path + ] + ) Application.DEFAULT = self @@ -219,6 +224,6 @@ async def main_gunicorn(): except KeyError: logging.error('Failed to set "CONFIG_FILE" environment. Trying to run without gunicorn?') - raise + raise RuntimeError from None return app diff --git a/relay/data/statements.sql b/relay/data/statements.sql index 9bcea41..0f98f64 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -1,47 +1,92 @@ -- name: get-config -SELECT * FROM config WHERE key = :key +SELECT * FROM config WHERE key = :key; -- name: get-config-all -SELECT * FROM config +SELECT * FROM config; -- name: put-config INSERT INTO config (key, value, type) VALUES (:key, :value, :type) ON CONFLICT (key) DO UPDATE SET value = :value -RETURNING * +RETURNING *; -- name: del-config DELETE FROM config -WHERE key = :key +WHERE key = :key; -- name: get-inbox -SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value +SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value; -- name: put-inbox INSERT INTO inboxes (domain, actor, inbox, followid, software, created) VALUES (:domain, :actor, :inbox, :followid, :software, :created) ON CONFLICT (domain) DO UPDATE SET followid = :followid -RETURNING * +RETURNING *; -- name: del-inbox DELETE FROM inboxes -WHERE domain = :value or inbox = :value or actor = :value +WHERE domain = :value or inbox = :value or actor = :value; + + +-- name: get-user +SELECT * FROM users +WHERE username = :value or handle = :value; + + +-- name: get-user-by-token +SELECT * FROM users +WHERE username = ( + SELECT user FROM tokens + WHERE code = :code +); + + +-- name: put-user +INSERT INTO users (username, hash, handle, created) +VALUES (:username, :hash, :handle, :created) +RETURNING *; + + +-- name: del-user +DELETE FROM users +WHERE username = :value or handle = :value; + + +-- name: get-token +SELECT * FROM tokens +WHERE code = :code; + + +-- name: put-token +INSERT INTO tokens (code, user, created) +VALUES (:code, :user, :created) +RETURNING *; + + +-- name: del-token +DELETE FROM tokens +WHERE code = :code; + + +-- name: del-token-user +DELETE FROM tokens +WHERE user = :username -- name: get-software-ban -SELECT * FROM software_bans WHERE name = :name +SELECT * FROM software_bans WHERE name = :name; -- name: put-software-ban INSERT INTO software_bans (name, reason, note, created) VALUES (:name, :reason, :note, :created) -RETURNING * +RETURNING *; -- name: del-software-ban diff --git a/relay/database/__init__.py b/relay/database/__init__.py index c403092..facea97 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -52,14 +52,10 @@ def get_database(config: Config, migrate: bool = True) -> tinysql.Database: if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'): logging.info("Migrating database from version '%i'", schema_ver) - for ver, func in VERSIONS: + for ver, func in VERSIONS.items(): if schema_ver < ver: - conn.begin() - func(conn) - conn.put_config('schema-version', ver) - conn.commit() if (privkey := conn.get_config('private-key')): conn.app.signer = privkey diff --git a/relay/database/config.py b/relay/database/config.py index 8fe3b4c..b69f13e 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -11,8 +11,9 @@ if typing.TYPE_CHECKING: CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = { - 'schema-version': ('int', 20240119), + 'schema-version': ('int', 20240206), 'log-level': ('loglevel', logging.LogLevel.INFO), + 'name': ('str', 'ActivityRelay'), 'note': ('str', 'Make a note about your instance here.'), 'private-key': ('str', None), 'whitelist-enabled': ('bool', False) diff --git a/relay/database/connection.py b/relay/database/connection.py index ac5a364..718861d 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -3,8 +3,10 @@ from __future__ import annotations import tinysql import typing +from argon2 import PasswordHasher from datetime import datetime, timezone from urllib.parse import urlparse +from uuid import uuid4 from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize @@ -28,6 +30,10 @@ RELAY_SOFTWARE = [ class Connection(tinysql.Connection): + hasher = PasswordHasher( + encoding = 'utf-8' + ) + @property def app(self) -> Application: return get_app() @@ -162,6 +168,59 @@ class Connection(tinysql.Connection): return cur.modified_row_count == 1 + def get_user(self, value: str) -> Row: + with self.exec_statement('get-user', {'value': value}) as cur: + return cur.one() + + + def get_user_by_token(self, code: str) -> Row: + with self.exec_statement('get-user-by-token', {'code': code}) as cur: + return cur.one() + + + def put_user(self, username: str, password: str, handle: str | None = None) -> Row: + data = { + 'username': username, + 'hash': self.hasher.hash(password), + 'handle': handle, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-user', data) as cur: + return cur.one() + + + def del_user(self, username: str) -> None: + user = self.get_user(username) + + with self.exec_statement('del-user', {'value': user['username']}): + pass + + with self.exec_statement('del-token-user', {'username': user['username']}): + pass + + + def get_token(self, code: str) -> Row: + with self.exec_statement('get-token', {'code': code}) as cur: + return cur.one() + + + def put_token(self, username: str) -> Row: + data = { + 'code': uuid4().hex, + 'user': username, + 'created': datetime.now(tz = timezone.utc) + } + + with self.exec_statement('put-token', data) as cur: + return cur.one() + + + def del_token(self, code: str) -> None: + with self.exec_statement('del-token', {'code': code}): + pass + + def get_domain_ban(self, domain: str) -> Row: if domain.startswith('http'): domain = urlparse(domain).netloc diff --git a/relay/database/schema.py b/relay/database/schema.py index 7b6d927..d4c51a4 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -10,7 +10,7 @@ if typing.TYPE_CHECKING: from collections.abc import Callable -VERSIONS: list[Callable] = [] +VERSIONS: dict[int, Callable] = {} TABLES: list[Table] = [ Table( 'config', @@ -45,12 +45,25 @@ TABLES: list[Table] = [ Column('reason', 'text'), Column('note', 'text'), Column('created', 'timestamp', nullable = False) + ), + Table( + 'users', + Column('username', 'text', primary_key = True, unique = True, nullable = False), + Column('hash', 'text', nullable = False), + Column('handle', 'text'), + Column('created', 'timestamp', nullable = False) + ), + Table( + 'tokens', + Column('code', 'text', primary_key = True, unique = True, nullable = False), + Column('user', 'text', nullable = False), + Column('created', 'timestmap', nullable = False) ) ] -def version(func: Callable) -> Callable: - ver = int(func.replace('migrate_', '')) +def migration(func: Callable) -> Callable: + ver = int(func.__name__.replace('migrate_', '')) VERSIONS[ver] = func return func @@ -58,3 +71,8 @@ def version(func: Callable) -> Callable: def migrate_0(conn: Connection) -> None: conn.create_tables(TABLES) conn.put_config('schema-version', get_default_value('schema-version')) + + +@migration +def migrate_20240206(conn: Connection) -> None: + conn.create_tables(TABLES) diff --git a/relay/manage.py b/relay/manage.py index 2a78b6f..039a6f5 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -370,6 +370,116 @@ def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: print(f'{key}: {repr(new_value)}') +@cli.group('user') +def cli_user() -> None: + 'Manage local users' + + +@cli_user.command('list') +@click.pass_context +def cli_user_list(ctx: click.Context) -> None: + 'List all local users' + + click.echo('Users:') + + with ctx.obj.database.connection() as conn: + for user in conn.execute('SELECT * FROM users'): + click.echo(f'- {user["username"]}') + + +@cli_user.command('create') +@click.argument('username') +@click.argument('handle', required = False) +@click.pass_context +def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: + 'Create a new local user' + + with ctx.obj.database.connection() as conn: + if conn.get_user(username): + click.echo(f'User already exists: {username}') + return + + while True: + password = click.prompt('New password', hide_input = True) + + if not password: + click.echo('No password provided') + continue + + password2 = click.prompt('New password again', hide_input = True) + + if password != password2: + click.echo('Passwords do not match') + continue + + break + + conn.put_user(username, password, handle) + + click.echo(f'Created user "{username}"') + + +@cli_user.command('delete') +@click.argument('username') +@click.pass_context +def cli_user_delete(ctx: click.Context, username: str) -> None: + 'Delete a local user' + + with ctx.obj.database.connection() as conn: + if not conn.get_user(username): + click.echo(f'User does not exist: {username}') + return + + conn.del_user(username) + + click.echo(f'Deleted user "{username}"') + + +@cli_user.command('list-tokens') +@click.argument('username') +@click.pass_context +def cli_user_list_tokens(ctx: click.Context, username: str) -> None: + 'List all API tokens for a user' + + click.echo(f'Tokens for "{username}":') + + with ctx.obj.database.connection() as conn: + for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): + click.echo(f'- {token["code"]}') + + +@cli_user.command('create-token') +@click.argument('username') +@click.pass_context +def cli_user_create_token(ctx: click.Context, username: str) -> None: + 'Create a new API token for a user' + + with ctx.obj.database.connection() as conn: + if not (user := conn.get_user(username)): + click.echo(f'User does not exist: {username}') + return + + token = conn.put_token(user['username']) + + click.echo(f'New token for "{username}": {token["code"]}') + + +@cli_user.command('delete-token') +@click.argument('code') +@click.pass_context +def cli_user_delete_token(ctx: click.Context, code: str) -> None: + 'Delete an API token' + + with ctx.obj.database.connection() as conn: + if not (conn.get_token(code)): + click.echo('Token does not exist') + return + + conn.del_token(code) + + click.echo('Deleted token') + + @cli.group('inbox') def cli_inbox() -> None: 'Manage the inboxes in the database' diff --git a/relay/misc.py b/relay/misc.py index 296082b..bb78957 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -199,7 +199,7 @@ class Response(AiohttpResponse): if isinstance(body, bytes): kwargs['body'] = body - elif isinstance(body, dict) and ctype in {'json', 'activity'}: + elif isinstance(body, (dict, list, tuple, set)) and ctype in {'json', 'activity'}: kwargs['text'] = json.dumps(body) else: diff --git a/relay/views/__init__.py b/relay/views/__init__.py index 85c2bd3..6366592 100644 --- a/relay/views/__init__.py +++ b/relay/views/__init__.py @@ -1,4 +1,4 @@ from __future__ import annotations -from . import activitypub, frontend, misc +from . import activitypub, api, frontend, misc from .base import VIEWS diff --git a/relay/views/api.py b/relay/views/api.py new file mode 100644 index 0000000..f75caf4 --- /dev/null +++ b/relay/views/api.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +import typing + +from aiohttp import web +from argon2.exceptions import VerifyMismatchError +from datetime import datetime, timezone + +from .base import View, register_route + +from .. import __version__ +from ..database.config import CONFIG_DEFAULTS +from ..misc import Response + +if typing.TYPE_CHECKING: + from aiohttp.web import Request + from collections.abc import Coroutine + from ..database.connection import Connection + + +CONFIG_IGNORE = ( + 'schema-version', + 'private-key' +) + +CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE} + +PUBLIC_API_PATHS: tuple[tuple[str, str]] = ( + ('GET', '/api/v1/relay'), + ('POST', '/api/v1/token') +) + + +def check_api_path(method: str, path: str) -> bool: + for m, p in PUBLIC_API_PATHS: + if m == method and p == path: + return False + + return path.startswith('/api') + + +@web.middleware +async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Response: + try: + request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() + + with request.app.database.connection() as conn: + request['user'] = conn.get_user_by_token(request['token']) + + except (KeyError, ValueError): + request['token'] = None + request['user'] = None + + if check_api_path(request.method, request.path): + if not request['token']: + return Response.new_error(401, 'Missing token', 'json') + + if not request['user']: + return Response.new_error(401, 'Invalid token', 'json') + + return await handler(request) + + +# pylint: disable=no-self-use,unused-argument + +@register_route('/api/v1/token') +class Login(View): + async def get(self, request: Request, conn: Connection) -> Response: + return Response.new({'message': 'Token valid :3'}) + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['username', 'password'], []) + + if isinstance(data, Response): + return data + + if not (user := conn.get_user(data['username'])): + return Response.new_error(401, 'User not found', 'json') + + try: + conn.hasher.verify(user['hash'], data['password']) + + except VerifyMismatchError: + return Response.new_error(401, 'Invalid password', 'json') + + token = conn.put_token(data['username']) + return Response.new({'token': token['code']}, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection) -> Response: + conn.del_token(request['token']) + return Response.new({'message': 'Token revoked'}, ctype = 'json') + + +@register_route('/api/v1/relay') +class RelayInfo(View): + async def get(self, request: Request, conn: Connection) -> Response: + config = conn.get_config_all() + inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] + + data = { + 'domain': self.config.domain, + 'name': config['name'], + 'description': config['note'], + 'version': __version__, + 'whitelist_enabled': config['whitelist-enabled'], + 'email': None, + 'admin': None, + 'icon': None, + 'instances': inboxes + } + + return Response.new(data, ctype = 'json') + + +@register_route('/api/v1/config') +class Config(View): + async def get(self, request: Request, conn: Connection) -> Response: + data = conn.get_config_all() + data['log-level'] = data['log-level'].name + + for key in CONFIG_IGNORE: + del data[key] + + return Response.new(data, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['key', 'value'], []) + + if isinstance(data, Response): + return data + + if data['key'] not in CONFIG_VALID: + return Response.new_error(400, 'Invalid key', 'json') + + conn.put_config(data['key'], data['value']) + return Response.new({'message': 'Updated config'}, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['key'], []) + + if isinstance(data, Response): + return data + + if data['key'] not in CONFIG_VALID: + return Response.new_error(400, 'Invalid key', 'json') + + conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) + return Response.new({'message': 'Updated config'}, ctype = 'json') + + +@register_route('/api/v1/inbox') +class Inbox(View): + async def get(self, request: Request, conn: Connection) -> Response: + data = [] + + for inbox in conn.execute('SELECT * FROM inboxes'): + try: + created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) + + except TypeError: + created = datetime.fromisoformat(inbox['created']) + + inbox['created'] = created.isoformat() + data.append(inbox) + + return Response.new(data, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['domain', 'inbox', 'actor'], ['software']) + + if isinstance(data, Response): + return data + + if conn.get_inbox(data['domain']): + return Response.new_error(404, 'Inbox already in database', 'json') + + row = conn.put_inbox(**data) + return Response.new(row.to_json(), ctype = 'json') + + +@register_route('/api/v1/inbox/{domain}') +class InboxSingle(View): + async def get(self, request: Request, conn: Connection, domain: str) -> Response: + if not (row := conn.get_inbox(domain)): + return Response.new_error(404, 'Inbox with domain not found', 'json') + + row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat() + return Response.new(row, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Inbox with domain not found', 'json') + + conn.del_inbox(domain) + return Response.new({'message': 'Deleted inbox'}, ctype = 'json') + + +@register_route('/api/v1/domain_ban') +class DomainBan(View): + async def get(self, request: Request, conn: Connection) -> Response: + bans = conn.execute('SELECT * FROM domain_bans').all() + return Response.new(bans, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['domain'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + if conn.get_domain_ban(data['domain']): + return Response.new_error(400, 'Domain already banned', 'json') + + ban = conn.put_domain_ban(**data) + return Response.new(ban, ctype = 'json') + + +@register_route('/api/v1/domain_ban/{domain}') +class DomainBanSingle(View): + async def get(self, request: Request, conn: Connection, domain: str) -> Response: + if not (ban := conn.get_domain_ban(domain)): + return Response.new_error(404, 'Domain ban not found', 'json') + + return Response.new(ban, ctype = 'json') + + + async def patch(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') + + data = await self.get_data(['domain'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + ban = conn.update_domain_ban(**data) + return Response.new(ban, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') + + conn.del_domain_ban(domain) + return Response.new({'message': 'Unbanned domain'}, ctype = 'json') + + +@register_route('/api/v1/software_ban') +class SoftwareBan(View): + async def get(self, request: Request, conn: Connection) -> Response: + bans = conn.execute('SELECT * FROM software_bans').all() + return Response.new(bans, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['name'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + if conn.get_software_ban(data['name']): + return Response.new_error(400, 'Domain already banned', 'json') + + ban = conn.put_software_ban(**data) + return Response.new(ban, ctype = 'json') + + +@register_route('/api/v1/software_ban/{name}') +class SoftwareBanSingle(View): + async def get(self, request: Request, conn: Connection, name: str) -> Response: + if not (ban := conn.get_software_ban(name)): + return Response.new_error(404, 'Software ban not found', 'json') + + return Response.new(ban, ctype = 'json') + + + async def post(self, request: Request, conn: Connection, name: str) -> Response: + if not conn.get_software_ban(name): + return Response.new_error(404, 'Software not banned', 'json') + + data = await self.get_data(['name'], ['note', 'reason']) + + if isinstance(data, Response): + return data + + ban = conn.update_software_ban(**data) + return Response.new(ban, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_software_ban(domain): + return Response.new_error(404, 'Software not banned', 'json') + + conn.del_software_ban(domain) + return Response.new({'message': 'Unbanned software'}, ctype = 'json') + + +@register_route('/api/v1/whitelist') +class Whitelist(View): + async def get(self, request: Request, conn: Connection) -> Response: + items = conn.execute('SELECT * FROM whitelist').all() + return Response.new(items, ctype = 'json') + + + async def post(self, request: Request, conn: Connection) -> Response: + data = await self.get_data(['domain']) + + if isinstance(data, Response): + return data + + if conn.get_domain_whitelist(data['domain']): + return Response.new_error(400, 'Domain already added to whitelist', 'json') + + item = conn.put_domain_whitelist(**data) + return Response.new(item, ctype = 'json') + + +@register_route('/api/v1/domain/{domain}') +class WhitelistSingle(View): + async def get(self, request: Request, conn: Connection, domain: str) -> Response: + if not (item := conn.get_domain_whitelist(domain)): + return Response.new_error(404, 'Domain not in whitelist', 'json') + + return Response.new(item, ctype = 'json') + + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_domain_whitelist(domain): + return Response.new_error(404, 'Domain not in whitelist', 'json') + + conn.del_domain_whitelist(domain) + return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py index 95b6562..53d6fd5 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -6,6 +6,9 @@ from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.web import HTTPMethodNotAllowed from functools import cached_property +from json.decoder import JSONDecodeError + +from ..misc import Response if typing.TYPE_CHECKING: from collections.abc import Callable, Coroutine, Generator @@ -14,7 +17,6 @@ if typing.TYPE_CHECKING: from ..cache import Cache from ..config import Config from ..http_client import HttpClient - from ..misc import Response VIEWS = [] @@ -91,3 +93,32 @@ class View(AbstractView): @property def database(self) -> Database: return self.app.database + + + async def get_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: + if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: + post_data = await self.request.post() + + elif self.request.content_type == 'application/json': + try: + post_data = await self.request.json() + + except JSONDecodeError: + return Response.new_error(400, 'Invalid JSON data', 'json') + + else: + post_data = self.request.query + + data = {} + + try: + for key in required: + data[key] = post_data[key] + + except KeyError as e: + return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') + + for key in optional: + data[key] = post_data.get(key) + + return data diff --git a/requirements.txt b/requirements.txt index 5239cb8..e7c73fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiohttp>=3.9.1 aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz +argon2-cffi==23.1.0 click>=8.1.2 gunicorn==21.1.0 hiredis==2.3.2