diff --git a/relay/api_objects.py b/relay/api_objects.py new file mode 100644 index 0000000..a0605a2 --- /dev/null +++ b/relay/api_objects.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from blib import Date, JsonBase +from bsql import Row +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any + +from . import logger as logging +from .database import ConfigData +from .misc import utf_to_idna + +if TYPE_CHECKING: + try: + from typing import Self + + except ImportError: + from typing_extensions import Self + + +class ApiObject: + def __str__(self) -> str: + return self.to_json() + + + @classmethod + def from_row(cls: type[Self], row: Row, *exclude: str) -> Self: + return cls(**{k: v for k, v in row.items() if k not in exclude}) + + + def to_dict(self, *exclude: str) -> dict[str, Any]: + return {k: v for k, v in asdict(self).items() if k not in exclude} # type: ignore[call-overload] + + + def to_json(self, *exclude: str, indent: int | str | None = None) -> str: + data = self.to_dict(*exclude) + return JsonBase(data).to_json(indent = indent) + + +@dataclass(slots = True) +class Message(ApiObject): + msg: str + + +@dataclass(slots = True) +class Application(ApiObject): + client_id: str + client_secret: str + name: str + website: str | None + redirect_uri: str + token: str | None + created: Date + updated: Date + + +@dataclass(slots = True) +class Config(ApiObject): + approval_required: bool + log_level: logging.LogLevel + name: str + note: str + theme: str + whitelist_enabled: bool + + + @classmethod + def from_config(cls: type[Self], cfg: ConfigData) -> Self: + return cls( + cfg.approval_required, + cfg.log_level, + cfg.name, + cfg.note, + cfg.theme, + cfg.whitelist_enabled + ) + + +@dataclass(slots = True) +class ConfigItem(ApiObject): + key: str + value: Any + type: str + + +@dataclass(slots = True) +class DomainBan(ApiObject): + domain: str + reason: str | None + note: str | None + created: Date + + +@dataclass(slots = True) +class Instance(ApiObject): + domain: str + actor: str + inbox: str + followid: str + software: str + accepted: Date + created: Date + + + def __post_init__(self) -> None: + self.domain = utf_to_idna(self.domain) + self.actor = utf_to_idna(self.actor) + self.inbox = utf_to_idna(self.inbox) + self.followid = utf_to_idna(self.followid) + + +@dataclass(slots = True) +class Relay(ApiObject): + domain: str + name: str + description: str + version: str + whitelist_enabled: bool + email: str | None + admin: str | None + icon: str | None + instances: list[str] + + +@dataclass(slots = True) +class SoftwareBan(ApiObject): + name: str + reason: str | None + note: str | None + created: Date + + +@dataclass(slots = True) +class User(ApiObject): + username: str + handle: str | None + created: Date + + +@dataclass(slots = True) +class Whitelist(ApiObject): + domain: str + created: Date diff --git a/relay/application.py b/relay/application.py index 3ba73f9..6641b1d 100644 --- a/relay/application.py +++ b/relay/application.py @@ -29,8 +29,7 @@ from .database.schema import Instance from .http_client import HttpClient from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response from .template import Template -from .views import ROUTES, VIEWS -from .views.api import handle_api_path +from .views import ROUTES from .views.frontend import handle_frontend_path from .workers import PushWorkers @@ -59,8 +58,7 @@ class Application(web.Application): web.Application.__init__(self, middlewares = [ handle_response_headers, # type: ignore[list-item] - handle_frontend_path, # type: ignore[list-item] - handle_api_path # type: ignore[list-item] + handle_frontend_path # type: ignore[list-item] ] ) @@ -84,9 +82,6 @@ class Application(web.Application): self.cache.setup() self.on_cleanup.append(handle_cleanup) # type: ignore - for path, view in VIEWS: - self.router.add_view(path, view) - for method, path, handler in ROUTES: self.router.add_route(method, path, handler) diff --git a/relay/database/config.py b/relay/database/config.py index 206e757..acb6f1e 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -116,8 +116,10 @@ class ConfigData: @classmethod def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]: + parsed_key = key.replace('-', '_') + for field in fields(cls): - if field.name == key.replace('-', '_'): + if field.name == parsed_key: return field raise KeyError(key) diff --git a/relay/database/connection.py b/relay/database/connection.py index 2d5de61..8e94627 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -111,22 +111,25 @@ class Connection(SqlConnection): def put_config(self, key: str, value: Any) -> Any: field = ConfigData.FIELD(key) - key = field.name.replace('_', '-') - if key == 'private-key': - self.app.signer = value + match field.name: + case "private_key": + self.app.signer = value - elif key == 'log-level': - value = logging.LogLevel.parse(value) - logging.set_level(value) - self.app['workers'].set_log_level(value) + case "log_level": + value = logging.LogLevel.parse(value) + logging.set_level(value) + self.app['workers'].set_log_level(value) - elif key in {'approval-required', 'whitelist-enabled'}: - value = convert_to_boolean(value) + case "approval_required": + value = convert_to_boolean(value) - elif key == 'theme': - if value not in THEMES: - raise ValueError(f'"{value}" is not a valid theme') + case "whitelist_enabled": + value = convert_to_boolean(value) + + case "theme": + if value not in THEMES: + raise ValueError(f'"{value}" is not a valid theme') data = ConfigData() data.set(key, value) diff --git a/relay/frontend/static/functions.js b/relay/frontend/static/functions.js index 3063223..3083f45 100644 --- a/relay/frontend/static/functions.js +++ b/relay/frontend/static/functions.js @@ -1,3 +1,7 @@ +let a = `; ${document.cookie}`.match(";\\s*user-token=([^;]+)"); +const token = a ? a[1] : null; + + // toast notifications const notifications = document.querySelector("#notifications") @@ -60,6 +64,7 @@ for (const elem of document.querySelectorAll("#menu-open div")) { // misc + function get_date_string(date) { var year = date.getUTCFullYear().toString(); var month = (date.getUTCMonth() + 1).toString().padStart(2, "0"); @@ -94,9 +99,13 @@ async function request(method, path, body = null) { "Accept": "application/json" } + if (token !== null) { + headers["Authorization"] = `Bearer ${token}`; + } + if (body !== null) { - headers["Content-Type"] = "application/json" - body = JSON.stringify(body) + headers["Content-Type"] = "application/json"; + body = JSON.stringify(body); } const response = await fetch("/api/" + path, { diff --git a/relay/misc.py b/relay/misc.py index f5899da..4d250e5 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -8,7 +8,7 @@ import platform from aiohttp.web import Request, Response as AiohttpResponse from collections.abc import Sequence from datetime import datetime -from typing import TYPE_CHECKING, Any, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, overload from uuid import uuid4 if TYPE_CHECKING: @@ -85,6 +85,40 @@ def get_app() -> Application: return Application.DEFAULT +@overload +def idna_to_utf(string: str) -> str: + ... + + +@overload +def idna_to_utf(string: None) -> None: + ... + + +def idna_to_utf(string: str | None) -> str | None: + if string is None: + return None + + return string.encode("idna").decode("utf-8") + + +@overload +def utf_to_idna(string: str) -> str: + ... + + +@overload +def utf_to_idna(string: None) -> None: + ... + + +def utf_to_idna(string: str | None) -> str | None: + if string is None: + return None + + return string.encode("utf-8").decode("idna") + + class JsonEncoder(json.JSONEncoder): def default(self, o: Any) -> str: if isinstance(o, datetime): diff --git a/relay/views/__init__.py b/relay/views/__init__.py index 265c7ad..671a69d 100644 --- a/relay/views/__init__.py +++ b/relay/views/__init__.py @@ -1,4 +1,4 @@ from __future__ import annotations from . import activitypub, api, frontend, misc -from .base import ROUTES, VIEWS, View +from .base import ROUTES diff --git a/relay/views/api.py b/relay/views/api.py index 8b05fc6..fe78f4f 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -1,624 +1,596 @@ +from __future__ import annotations + import traceback -from aiohttp.web import Request, middleware +from aiohttp.web import Request from argon2.exceptions import VerifyMismatchError -from blib import HttpError, convert_to_boolean -from collections.abc import Awaitable, Callable, Sequence +from blib import HttpError, HttpMethod, convert_to_boolean +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from .base import View, register_view +from .base import DEFAULT_REDIRECT, Route -from .. import __version__ +from .. import api_objects as objects, __version__ from ..database import ConfigData, schema -from ..misc import Message, Response +from ..misc import Message, Response, idna_to_utf + +if TYPE_CHECKING: + from ..application import Application -DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' -ALLOWED_HEADERS: set[str] = { - 'accept', - 'authorization', - 'content-type' -} +@Route(HttpMethod.GET, "/oauth/authorize", "Authorization", False) +async def handle_authorize_get( + app: Application, + request: Request, + response_type: str, + client_id: str, + redirect_uri: str) -> Response: -PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( - ('GET', '/api/v1/relay'), - ('POST', '/api/v1/app'), - ('POST', '/api/v1/login'), - ('POST', '/api/v1/token') -) + if response_type != "code": + raise HttpError(400, "Response type is not 'code'") + + with app.database.session(True) as s: + with s.select("apps", client_id = client_id) as cur: + if (application := cur.one(schema.App)) is None: + raise HttpError(404, "Could not find app") + + if application.token is not None: + raise HttpError(400, "Application has already been authorized") + + if application.auth_code is not None: + page = "page/authorization_show.haml" + + else: + page = "page/authorize_new.haml" + + if redirect_uri != application.redirect_uri: + raise HttpError(400, "redirect_uri does not match application") + + context = {"application": application} + return Response.new_template(200, page, request, context) -def check_api_path(method: str, path: str) -> bool: - if path.startswith('/api/doc') or (method, path) in PUBLIC_API_PATHS: - return False +@Route(HttpMethod.POST, "/oauth/authorize", "Authorization", False) +async def handle_authorize_post( + app: Application, + request: Request, + client_id: str, + client_secret: str, + redirect_uri: str, + response: str) -> Response: - return path.startswith('/api') + with app.database.session(True) as s: + if (application := s.get_app(client_id, client_secret)) is None: + raise HttpError(404, "Could not find app") + + if convert_to_boolean(response): + if application.token is not None: + raise HttpError(400, "Application has already been authorized") + + if application.auth_code is None: + application = s.update_app(application, request["user"], True) + + if application.redirect_uri == DEFAULT_REDIRECT: + context = {"application": application} + return Response.new_template(200, "page/authorize_show.haml", request, context) + + return Response.new_redir(f"{application.redirect_uri}?code={application.auth_code}") + + if not s.del_app(application.client_id, application.client_secret): + raise HttpError(404, "App not found") + + return Response.new_redir("/") -@middleware -async def handle_api_path( +@Route(HttpMethod.POST, "/oauth/token", "Auth", False) +async def handle_new_token( + app: Application, request: Request, - handler: Callable[[Request], Awaitable[Response]]) -> Response: + grant_type: str, + code: str, + client_id: str, + client_secret: str, + redirect_uri: str) -> objects.Application: - if not request.path.startswith('/api') or request.path == '/api/doc': - return await handler(request) + if grant_type != "authorization_code": + raise HttpError(400, "Invalid grant type") - if request.method != "OPTIONS" and check_api_path(request.method, request.path): - if request['token'] is None: - raise HttpError(401, 'Missing token') + with app.database.session(True) as s: + if (application := s.get_app(client_id, client_secret)) is None: + raise HttpError(404, "Application not found") - if request['user'] is None: - raise HttpError(401, 'Invalid token') + if application.auth_code != code: + raise HttpError(400, "Invalid authentication code") - response = await handler(request) - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) + if application.redirect_uri != redirect_uri: + raise HttpError(400, "Invalid redirect uri") - return response + application = s.update_app(application, request["user"], False) + return objects.Application.from_row(application) -@register_view('/oauth/authorize') -@register_view('/api/oauth/authorize') -class OauthAuthorize(View): - async def get(self, request: Request) -> Response: - data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], []) - if data['response_type'] != 'code': - raise HttpError(400, 'Response type is not "code"') +@Route(HttpMethod.POST, "/api/oauth/revoke", "Auth", True) +async def handle_token_revoke( + app: Application, + request: Request, + client_id: str, + client_secret: str, + token: str) -> objects.Message: - with self.database.session(True) as conn: - with conn.select('apps', client_id = data['client_id']) as cur: - if (app := cur.one(schema.App)) is None: - raise HttpError(404, 'Could not find app') + with app.database.session(True) as conn: + if (application := conn.get_app(client_id, client_secret, token)) is None: + raise HttpError(404, "Could not find token") - if app.token is not None: - raise HttpError(400, 'Application has already been authorized') + if application.user != request["application"].username: + raise HttpError(403, "Invalid token") - if app.auth_code is not None: - page = "page/authorization_show.haml" + if not conn.del_app(client_id, client_secret, token): + raise HttpError(400, "Token not removed") - else: - page = "page/authorize_new.haml" + return objects.Message("Token deleted") - if data['redirect_uri'] != app.redirect_uri: - raise HttpError(400, 'redirect_uri does not match application') - context = {'application': app} - return Response.new_template(200, page, request, context) +@Route(HttpMethod.POST, "/api/v1/login", "Auth", False) +async def handle_login( + app: Application, + request: Request, + username: str, + password: str) -> objects.Application: - - async def post(self, request: Request) -> Response: - data = await self.get_api_data( - ['client_id', 'client_secret', 'redirect_uri', 'response'], [] - ) - - with self.database.session(True) as conn: - if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: - raise HttpError(404, 'Could not find app') - - if convert_to_boolean(data['response']): - if app.token is not None: - raise HttpError(400, 'Application has already been authorized') - - if app.auth_code is None: - app = conn.update_app(app, request['user'], True) - - if app.redirect_uri == DEFAULT_REDIRECT: - context = {'application': app} - return Response.new_template(200, "page/authorize_show.haml", request, context) - - return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}') - - if not conn.del_app(app.client_id, app.client_secret): - raise HttpError(404, 'App not found') - - return Response.new_redir('/') - - -@register_view('/oauth/token') -@register_view('/api/oauth/token') -class OauthToken(View): - async def post(self, request: Request) -> Response: - data = await self.get_api_data( - ['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], [] - ) - - if data['grant_type'] != 'authorization_code': - raise HttpError(400, 'Invalid grant type') - - with self.database.session(True) as conn: - if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: - raise HttpError(404, 'Application not found') - - if app.auth_code != data['code']: - raise HttpError(400, 'Invalid authentication code') - - if app.redirect_uri != data['redirect_uri']: - raise HttpError(400, 'Invalid redirect uri') - - app = conn.update_app(app, request['user'], False) - - return Response.new(app.get_api_data(True), ctype = 'json') - - -@register_view('/oauth/revoke') -@register_view('/api/oauth/revoke') -class OauthRevoke(View): - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['client_id', 'client_secret', 'token'], []) - - with self.database.session(True) as conn: - if (app := conn.get_app(**data)) is None: - raise HttpError(404, 'Could not find token') - - if app.user != request['token'].username: - raise HttpError(403, 'Invalid token') - - if not conn.del_app(**data): - raise HttpError(400, 'Token not removed') - - return Response.new({'msg': 'Token deleted'}, ctype = 'json') - - -@register_view('/api/v1/app') -class App(View): - async def get(self, request: Request) -> Response: - return Response.new(request['token'].get_api_data(), ctype = 'json') - - - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['name', 'redirect_uri'], ['website']) - - with self.database.session(True) as conn: - app = conn.put_app( - name = data['name'], - redirect_uri = data['redirect_uri'], - website = data.get('website') - ) - - return Response.new(app.get_api_data(), ctype = 'json') - - - async def delete(self, request: Request) -> Response: - data = await self.get_api_data(['client_id', 'client_secret'], []) - - with self.database.session(True) as conn: - if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code): - raise HttpError(400, 'Token not removed') - - return Response.new({'msg': 'Token deleted'}, ctype = 'json') - - -@register_view('/api/v1/login') -class Login(View): - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['username', 'password'], []) - - with self.database.session(True) as conn: - if not (user := conn.get_user(data['username'])): - raise HttpError(401, 'User not found') - - try: - conn.hasher.verify(user['hash'], data['password']) - - except VerifyMismatchError: - raise HttpError(401, 'Invalid password') - - app = conn.put_app_login(user) - - resp = Response.new(app.get_api_data(True), ctype = 'json') - resp.set_cookie( - 'user-token', - app.token, # type: ignore[arg-type] - max_age = 60 * 60 * 24 * 365, - domain = self.config.domain, - path = '/', - secure = True, - httponly = False, - samesite = 'lax' - ) - - return resp - - -@register_view('/api/v1/relay') -class RelayInfo(View): - async def get(self, request: Request) -> Response: - with self.database.session() as conn: - config = conn.get_config_all() - inboxes = [row.domain for row in conn.get_inboxes()] - - data = { - 'domain': self.config.domain, - 'name': config.name, - 'description': config.note, - 'version': __version__, - 'whitelist_enabled': config.whitelist_enabled, - 'email': None, - 'admin': None, - 'icon': None, - 'instances': inboxes - } - - return Response.new(data, ctype = 'json') - - -@register_view('/api/v1/config') -class Config(View): - async def get(self, request: Request) -> Response: - data = {} - - with self.database.session() as conn: - for key, value in conn.get_config_all().to_dict().items(): - if key in ConfigData.SYSTEM_KEYS(): - continue - - if key == 'log-level': - value = value.name - - data[key] = value - - return Response.new(data, ctype = 'json') - - - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['key', 'value'], []) - data['key'] = data['key'].replace('-', '_') - - if data['key'] not in ConfigData.USER_KEYS(): - raise HttpError(400, 'Invalid key') - - with self.database.session() as conn: - value = conn.put_config(data['key'], data['value']) - - if data['key'] == 'log-level': - self.app.workers.set_log_level(value) - - return Response.new({'message': 'Updated config'}, ctype = 'json') - - - async def delete(self, request: Request) -> Response: - data = await self.get_api_data(['key'], []) - - if data['key'] not in ConfigData.USER_KEYS(): - raise HttpError(400, 'Invalid key') - - with self.database.session() as conn: - value = conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) - - if data['key'] == 'log-level': - self.app.workers.set_log_level(value) - - return Response.new({'message': 'Updated config'}, ctype = 'json') - - -@register_view('/api/v1/instance') -class Inbox(View): - async def get(self, request: Request) -> Response: - with self.database.session() as conn: - data = tuple(conn.get_inboxes()) - - return Response.new(data, ctype = 'json') - - - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) - data['domain'] = urlparse(data["actor"]).netloc - - with self.database.session() as conn: - if conn.get_inbox(data['domain']) is not None: - raise HttpError(404, 'Instance already in database') - - data['domain'] = data['domain'].encode('idna').decode() - - if not data.get('inbox'): - try: - actor_data = await self.client.get(data['actor'], True, Message) - - except Exception: - traceback.print_exc() - raise HttpError(500, 'Failed to fetch actor') from None - - data['inbox'] = actor_data.shared_inbox - - if not data.get('software'): - try: - nodeinfo = await self.client.fetch_nodeinfo(data['domain']) - data['software'] = nodeinfo.sw_name - - except Exception: - pass - - row = conn.put_inbox( - domain = data['domain'], - actor = data['actor'], - inbox = data.get('inbox'), - software = data.get('software'), - followid = data.get('followid') - ) - - return Response.new(row, ctype = 'json') - - - async def patch(self, request: Request) -> Response: - with self.database.session() as conn: - data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) - data['domain'] = data['domain'].encode('idna').decode() - - if (instance := conn.get_inbox(data['domain'])) is None: - raise HttpError(404, 'Instance with domain not found') - - instance = conn.put_inbox( - instance.domain, - actor = data.get('actor'), - software = data.get('software'), - followid = data.get('followid') - ) - - return Response.new(instance, ctype = 'json') - - - async def delete(self, request: Request) -> Response: - with self.database.session() as conn: - data = await self.get_api_data(['domain'], []) - data['domain'] = data['domain'].encode('idna').decode() - - if not conn.get_inbox(data['domain']): - raise HttpError(404, 'Instance with domain not found') - - conn.del_inbox(data['domain']) - - return Response.new({'message': 'Deleted instance'}, ctype = 'json') - - -@register_view('/api/v1/request') -class RequestView(View): - async def get(self, request: Request) -> Response: - with self.database.session() as conn: - instances = tuple(conn.get_requests()) - - return Response.new(instances, ctype = 'json') - - - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['domain', 'accept'], []) - data['domain'] = data['domain'].encode('idna').decode() + with app.database.session(True) as s: + if not (user := s.get_user(username)): + raise HttpError(401, "User not found") try: - with self.database.session(True) as conn: - instance = conn.put_request_response( - data['domain'], - convert_to_boolean(data['accept']) - ) + s.hasher.verify(user.hash, password) - except KeyError: - raise HttpError(404, 'Request not found') from None + except VerifyMismatchError: + raise HttpError(401, "Invalid password") - message = Message.new_response( - host = self.config.domain, - actor = instance.actor, - followid = instance.followid, - accept = convert_to_boolean(data['accept']) + application = s.put_app_login(user) + + return objects.Application.from_row(application) + + +@Route(HttpMethod.GET, "/api/v1/app", "Application", True) +async def handle_get_app(app: Application, request: Request) -> objects.Application: + return objects.Application.from_row(request["application"]) + + +@Route(HttpMethod.POST, "/api/v1/app", "Application", True) +async def handle_create_app( + app: Application, + request: Request, + name: str, + redirect_uri: str, + website: str | None = None) -> objects.Application: + + with app.database.session(True) as conn: + application = conn.put_app( + name = name, + redirect_uri = redirect_uri, + website = website ) - self.app.push_message(instance.inbox, message, instance) + return objects.Application.from_row(application) - if data['accept'] and instance.software != 'mastodon': - message = Message.new_follow( - host = self.config.domain, - actor = instance.actor - ) - self.app.push_message(instance.inbox, message, instance) +@Route(HttpMethod.GET, "/api/v1/config", "Config", True) +async def handle_config_get(app: Application, request: Request) -> objects.Config: + with app.database.session(False) as conn: + return objects.Config.from_config(conn.get_config_all()) - resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'} - return Response.new(resp_message, ctype = 'json') +@Route(HttpMethod.GET, "/api/v2/config", "Config", True) +async def handle_config_get_v2(app: Application, request: Request) -> list[objects.ConfigItem]: + data: list[objects.ConfigItem] = [] + cfg = ConfigData() + user_keys = ConfigData.USER_KEYS() -@register_view('/api/v1/domain_ban') -class DomainBan(View): - async def get(self, request: Request) -> Response: - with self.database.session() as conn: - bans = tuple(conn.get_domain_bans()) + with app.database.session(False) as s: + for row in s.execute("SELECT * FROM \"config\"").all(schema.Config): + if row.key.replace("-", "_") not in user_keys: + continue - return Response.new(bans, ctype = 'json') + cfg.set(row.key, row.value) + data.append(objects.ConfigItem(row.key, cfg.get(row.key), row.type)) + return data - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['domain'], ['note', 'reason']) - data['domain'] = data['domain'].encode('idna').decode() - with self.database.session() as conn: - if conn.get_domain_ban(data['domain']) is not None: - raise HttpError(400, 'Domain already banned') +@Route(HttpMethod.POST, "/api/v1/config", "Config", True) +async def handle_config_update( + app: Application, + request: Request, + key: str, value: Any) -> objects.Message: - ban = conn.put_domain_ban( - domain = data['domain'], - reason = data.get('reason'), - note = data.get('note') - ) + if (field := ConfigData.FIELD(key)).name not in ConfigData.USER_KEYS(): + raise HttpError(400, "Invalid key") - return Response.new(ban, ctype = 'json') + with app.database.session() as conn: + value = conn.put_config(key, value) + if field.name == "log_level": + app.workers.set_log_level(value) - async def patch(self, request: Request) -> Response: - with self.database.session() as conn: - data = await self.get_api_data(['domain'], ['note', 'reason']) + return objects.Message("Updated config") - if not any([data.get('note'), data.get('reason')]): - raise HttpError(400, 'Must include note and/or reason parameters') - data['domain'] = data['domain'].encode('idna').decode() +@Route(HttpMethod.DELETE, "/api/v1/config", "Config", True) +async def handle_config_reset(app: Application, request: Request, key: str) -> objects.Message: + if (field := ConfigData.FIELD(key)).name not in ConfigData.USER_KEYS(): + raise HttpError(400, "Invalid key") - if conn.get_domain_ban(data['domain']) is None: - raise HttpError(404, 'Domain not banned') + with app.database.session() as conn: + value = conn.put_config(field.name, field.default) - ban = conn.update_domain_ban( - domain = data['domain'], - reason = data.get('reason'), - note = data.get('note') - ) + if field.name == "log_level": + app.workers.set_log_level(value) - return Response.new(ban, ctype = 'json') + return objects.Message("Updated config") - async def delete(self, request: Request) -> Response: - with self.database.session() as conn: - data = await self.get_api_data(['domain'], []) - data['domain'] = data['domain'].encode('idna').decode() +@Route(HttpMethod.GET, "/api/v1/relay", "Misc", False) +async def get(app: Application, request: Request) -> objects.Relay: + with app.database.session() as s: + config = s.get_config_all() + inboxes = [row.domain for row in s.get_inboxes()] - if conn.get_domain_ban(data['domain']) is None: - raise HttpError(404, 'Domain not banned') + return objects.Relay( + app.config.domain, + config.name, + config.note, + __version__, + config.whitelist_enabled, + None, + None, + None, + inboxes + ) - conn.del_domain_ban(data['domain']) - return Response.new({'message': 'Unbanned domain'}, ctype = 'json') +@Route(HttpMethod.GET, "/api/v1/instance", "Instance", True) +async def handle_instances_get(app: Application, request: Request) -> list[objects.Instance]: + data: list[objects.Instance] = [] + with app.database.session(False) as s: + for row in s.get_inboxes(): + data.append(objects.Instance.from_row(row)) -@register_view('/api/v1/software_ban') -class SoftwareBan(View): - async def get(self, request: Request) -> Response: - with self.database.session() as conn: - bans = tuple(conn.get_software_bans()) + return data - return Response.new(bans, ctype = 'json') +@Route(HttpMethod.POST, "/api/v1/instance", "Instance", True) +async def handle_instance_add( + app: Application, + request: Request, + actor: str, + inbox: str | None = None, + software: str | None = None, + followid: str | None = None) -> objects.Instance: - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['name'], ['note', 'reason']) + domain = idna_to_utf(urlparse(actor).netloc) - with self.database.session() as conn: - if conn.get_software_ban(data['name']) is not None: - raise HttpError(400, 'Domain already banned') + with app.database.session(False) as s: + if s.get_inbox(domain) is not None: + raise HttpError(404, 'Instance already in database') - ban = conn.put_software_ban( - name = data['name'], - reason = data.get('reason'), - note = data.get('note') - ) + if inbox is None: + try: + actor_data = await app.client.get(actor, True, Message) - return Response.new(ban, ctype = 'json') + except Exception: + traceback.print_exc() + raise HttpError(500, "Failed to fetch actor") from None + inbox = actor_data.shared_inbox - async def patch(self, request: Request) -> Response: - data = await self.get_api_data(['name'], ['note', 'reason']) + if software is None: + try: + software = (await app.client.fetch_nodeinfo(domain)).sw_name - if not any([data.get('note'), data.get('reason')]): - raise HttpError(400, 'Must include note and/or reason parameters') + except Exception: + traceback.print_exc() - with self.database.session() as conn: - if conn.get_software_ban(data['name']) is None: - raise HttpError(404, 'Software not banned') + row = s.put_inbox( + domain = domain, + actor = idna_to_utf(actor), + inbox = idna_to_utf(inbox), + software = idna_to_utf(software), + followid = idna_to_utf(followid) + ) - ban = conn.update_software_ban( - name = data['name'], - reason = data.get('reason'), - note = data.get('note') - ) + return objects.Instance.from_row(row) - return Response.new(ban, ctype = 'json') +@Route(HttpMethod.PATCH, "/api/v1/instance", "Instance", True) +async def handle_instance_update( + app: Application, + request: Request, + domain: str, + actor: str | None = None, + inbox: str | None = None, + software: str | None = None, + followid: str | None = None) -> objects.Instance: - async def delete(self, request: Request) -> Response: - data = await self.get_api_data(['name'], []) + domain = idna_to_utf(domain) - with self.database.session() as conn: - if conn.get_software_ban(data['name']) is None: - raise HttpError(404, 'Software not banned') + with app.database.session(False) as s: + if (instance := s.get_inbox(domain)) is None: + raise HttpError(404, 'Instance with domain not found') - conn.del_software_ban(data['name']) + row = s.put_inbox( + instance.domain, + actor = idna_to_utf(actor) or instance.actor, + inbox = idna_to_utf(inbox) or instance.inbox, + software = idna_to_utf(software) or instance.software, + followid = idna_to_utf(followid) or instance.followid + ) - return Response.new({'message': 'Unbanned software'}, ctype = 'json') + return objects.Instance.from_row(row) -@register_view('/api/v1/user') -class User(View): - async def get(self, request: Request) -> Response: - with self.database.session() as conn: - items = [] +@Route(HttpMethod.DELETE, "/api/v1/instance", "Instance", True) +async def handle_instance_del(app: Application, request: Request, domain: str) -> objects.Message: + domain = idna_to_utf(domain) - for row in conn.get_users(): - del row['hash'] - items.append(row) + with app.database.session(False) as s: + if not s.get_inbox(domain): + raise HttpError(404, "Instance with domain not found") - return Response.new(items, ctype = 'json') + s.del_inbox(domain) + return objects.Message("Removed instance") - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['username', 'password'], ['handle']) - with self.database.session() as conn: - if conn.get_user(data['username']) is not None: - raise HttpError(404, 'User already exists') +@Route(HttpMethod.GET, "/api/v1/request", "Request", True) +async def handle_requests_get(app: Application, request: Request) -> list[objects.Instance]: + data: list[objects.Instance] = [] - user = conn.put_user( - username = data['username'], - password = data['password'], - handle = data.get('handle') - ) + with app.database.session(False) as s: + for row in s.get_requests(): + data.append(objects.Instance.from_row(row)) - del user['hash'] - return Response.new(user, ctype = 'json') + return data - async def patch(self, request: Request) -> Response: - data = await self.get_api_data(['username'], ['password', 'handle']) +@Route(HttpMethod.POST, "/api/v1/request", "Request", True) +async def handle_request_response( + app: Application, + request: Request, + domain: str, + accept: bool) -> objects.Message: - with self.database.session(True) as conn: - user = conn.put_user( - username = data['username'], - password = data['password'], - handle = data.get('handle') - ) + try: + with app.database.session(True) as conn: + row = conn.put_request_response(domain, accept) - del user['hash'] - return Response.new(user, ctype = 'json') + except KeyError: + raise HttpError(404, "Request not found") from None + message = Message.new_response( + host = app.config.domain, + actor = row.actor, + followid = row.followid, + accept = accept + ) - async def delete(self, request: Request) -> Response: - data = await self.get_api_data(['username'], []) + app.push_message(row.inbox, message, row) - with self.database.session(True) as conn: - if conn.get_user(data['username']) is None: - raise HttpError(404, 'User does not exist') + if accept and row.software != "mastodon": + message = Message.new_follow( + host = app.config.domain, + actor = row.actor + ) - conn.del_user(data['username']) + app.push_message(row.inbox, message, row) - return Response.new({'message': 'Deleted user'}, ctype = 'json') + if accept: + return objects.Message("Request accepted") + return objects.Message("Request denied") -@register_view('/api/v1/whitelist') -class Whitelist(View): - async def get(self, request: Request) -> Response: - with self.database.session() as conn: - items = tuple(conn.get_domains_whitelist()) - return Response.new(items, ctype = 'json') +@Route(HttpMethod.GET, "/api/v1/domain_ban", "Domain Ban", True) +async def handle_domain_bans_get(app: Application, request: Request) -> list[objects.DomainBan]: + data: list[objects.DomainBan] = [] + with app.database.session(False) as s: + for row in s.get_domain_bans(): + data.append(objects.DomainBan.from_row(row)) - async def post(self, request: Request) -> Response: - data = await self.get_api_data(['domain'], []) + return data - domain = data['domain'].encode('idna').decode() - with self.database.session() as conn: - if conn.get_domain_whitelist(domain) is not None: - raise HttpError(400, 'Domain already added to whitelist') +@Route(HttpMethod.POST, "/api/v1/domain_ban", "Domain Ban", True) +async def handle_domain_ban_add( + app: Application, + request: Request, + domain: str, + note: str | None = None, + reason: str | None = None) -> objects.DomainBan: - item = conn.put_domain_whitelist(domain) + with app.database.session(False) as s: + if s.get_domain_ban(domain) is not None: + raise HttpError(400, "Domain already banned") - return Response.new(item, ctype = 'json') + row = s.put_domain_ban(domain, reason, note) + return objects.DomainBan.from_row(row) - async def delete(self, request: Request) -> Response: - data = await self.get_api_data(['domain'], []) +@Route(HttpMethod.PATCH, "/api/v1/domain_ban", "Domain Ban", True) +async def handle_domain_ban( + app: Application, + request: Request, + domain: str, + note: str | None = None, + reason: str | None = None) -> objects.DomainBan: - domain = data['domain'].encode('idna').decode() + with app.database.session(True) as s: + if not any([note, reason]): + raise HttpError(400, "Must include note and/or reason parameters") - with self.database.session() as conn: - if conn.get_domain_whitelist(domain) is None: - raise HttpError(404, 'Domain not in whitelist') + if s.get_domain_ban(domain) is None: + raise HttpError(404, "Domain not banned") - conn.del_domain_whitelist(domain) + row = s.update_domain_ban(domain, reason, note) + return objects.DomainBan.from_row(row) - return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') + +@Route(HttpMethod.PATCH, "/api/v1/domain_ban", "Domain Ban", True) +async def handle_domain_unban(app: Application, request: Request, domain: str) -> objects.Message: + with app.database.session(True) as s: + if s.get_domain_ban(domain) is None: + raise HttpError(404, "Domain not banned") + + s.del_domain_ban(domain) + + return objects.Message("Unbanned domain") + + +@Route(HttpMethod.GET, "/api/v1/software_ban", "Software Ban", True) +async def handle_software_bans_get(app: Application, request: Request) -> list[objects.SoftwareBan]: + data: list[objects.SoftwareBan] = [] + + with app.database.session(False) as s: + for row in s.get_software_bans(): + data.append(objects.SoftwareBan.from_row(row)) + + return data + + +@Route(HttpMethod.POST, "/api/v1/software_ban", "Software Ban", True) +async def handle_software_ban_add( + app: Application, + request: Request, + name: str, + note: str | None = None, + reason: str | None = None) -> objects.SoftwareBan: + + with app.database.session(True) as s: + if s.get_software_ban(name) is not None: + raise HttpError(400, "Software already banned") + + row = s.put_software_ban(name, reason, note) + return objects.SoftwareBan.from_row(row) + + +@Route(HttpMethod.PATCH, "/api/v1/software_ban", "Software Ban", True) +async def handle_software_ban( + app: Application, + request: Request, + name: str, + note: str | None = None, + reason: str | None = None) -> objects.SoftwareBan: + + with app.database.session(True) as s: + if not any([note, reason]): + raise HttpError(400, "Must include note and/or reason parameters") + + if s.get_software_ban(name) is None: + raise HttpError(404, "Software not banned") + + row = s.update_software_ban(name, reason, note) + return objects.SoftwareBan.from_row(row) + + +@Route(HttpMethod.PATCH, "/api/v1/software_ban", "Software Ban", True) +async def handle_software_unban(app: Application, request: Request, name: str) -> objects.Message: + with app.database.session(True) as s: + if s.get_software_ban(name) is None: + raise HttpError(404, "Software not banned") + + s.del_software_ban(name) + + return objects.Message("Unbanned software") + + +@Route(HttpMethod.GET, "/api/v1/user", "User", True) +async def handle_users_get(app: Application, request: Request) -> list[objects.User]: + with app.database.session(False) as s: + items = [] + + for row in s.get_users(): + items.append(objects.User.from_row(row, "hash")) + + return items + + +@Route(HttpMethod.POST, "/api/v1/user", "User", True) +async def post( + app: Application, + request: Request, + username: str, + password: str, + handle: str | None = None) -> objects.User: + + with app.database.session() as s: + if s.get_user(username) is not None: + raise HttpError(404, "User already exists") + + row = s.put_user(username, password, handle) + return objects.User.from_row(row, "hash") + + +@Route(HttpMethod.PATCH, "/api/v1/user", "User", True) +async def patch( + app: Application, + request: Request, + username: str, + password: str | None = None, + handle: str | None = None) -> objects.User: + + with app.database.session(True) as s: + if s.get_user(username) is None: + raise HttpError(404, "User does not exist") + + row = s.put_user(username, password, handle) + return objects.User.from_row(row, "hash") + + +@Route(HttpMethod.DELETE, "/api/v1/user", "User", True) +async def delete(app: Application, request: Request, username: str) -> objects.Message: + with app.database.session(True) as s: + if s.get_user(username) is None: + raise HttpError(404, "User does not exist") + + s.del_user(username) + + return objects.Message("Deleted user") + + +@Route(HttpMethod.GET, "/api/v1/whitelist", "Whitelist", True) +async def handle_whitelist_get(app: Application, request: Request) -> list[objects.Whitelist]: + data: list[objects.Whitelist] = [] + + with app.database.session(False) as s: + for row in s.get_domains_whitelist(): + data.append(objects.Whitelist.from_row(row)) + + return data + + +@Route(HttpMethod.POST, "/api/v1/whitelist", "Whitelist", True) +async def handle_whitelist_add( + app: Application, + request: Request, + domain: str) -> objects.Whitelist: + + with app.database.session(True) as s: + if s.get_domain_whitelist(domain) is not None: + raise HttpError(400, "Domain already added to whitelist") + + row = s.put_domain_whitelist(domain) + return objects.Whitelist.from_row(row) + + +@Route(HttpMethod.DELETE, "/api/v1/whitelist", "Whitelist", True) +async def handle_whitelist_del(app: Application, request: Request, domain: str) -> objects.Message: + with app.database.session(True) as s: + if s.get_domain_whitelist(domain) is None: + raise HttpError(404, "Domain not in whitelist") + + s.del_domain_whitelist(domain) + + return objects.Message("Removed domain from whitelist") diff --git a/relay/views/base.py b/relay/views/base.py index 49e0bfc..b4007d0 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -1,47 +1,42 @@ from __future__ import annotations -from aiohttp.abc import AbstractView -from aiohttp.hdrs import METH_ALL as METHODS -from aiohttp.web import Request +from aiohttp.web import Request, StreamResponse from blib import HttpError, HttpMethod -from bsql import Database -from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping -from functools import cached_property +from collections.abc import Awaitable, Callable, Mapping from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload -from ..cache import Cache -from ..config import Config -from ..database import Connection -from ..http_client import HttpClient +from ..api_objects import ApiObject from ..misc import Response, get_app if TYPE_CHECKING: - from typing import Self from ..application import Application - from ..template import Template + try: + from typing import Self + + except ImportError: + from typing_extensions import Self + + ApiRouteHandler = Callable[..., Awaitable[ApiObject | list[Any] | StreamResponse]] RouteHandler = Callable[[Application, Request], Awaitable[Response]] HandlerCallback = Callable[[Request], Awaitable[Response]] -VIEWS: list[tuple[str, type[View]]] = [] ROUTES: list[tuple[str, str, HandlerCallback]] = [] +DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' +ALLOWED_HEADERS: set[str] = { + 'accept', + 'authorization', + 'content-type' +} + def convert_data(data: Mapping[str, Any]) -> dict[str, str]: return {key: str(value) for key, value in data.items()} -def register_view(*paths: str) -> Callable[[type[View]], type[View]]: - def wrapper(view: type[View]) -> type[View]: - for path in paths: - VIEWS.append((path, view)) - - return view - return wrapper - - def register_route( method: HttpMethod | str, *paths: str) -> Callable[[RouteHandler], HandlerCallback]: @@ -56,108 +51,107 @@ def register_route( return wrapper -class View(AbstractView): - def __await__(self) -> Generator[Any, None, Response]: - if self.request.method not in METHODS: - raise HttpError(405, f'"{self.request.method}" method not allowed') +class Route: + handler: ApiRouteHandler - if not (handler := self.handlers.get(self.request.method)): - raise HttpError(405, f'"{self.request.method}" method not allowed') + def __init__(self, + method: HttpMethod, + path: str, + category: str, + require_token: bool) -> None: - return self._run_handler(handler).__await__() + self.method: HttpMethod = HttpMethod.parse(method) + self.path: str = path + self.category: str = category + self.require_token: bool = require_token + + ROUTES.append((self.method, self.path, self)) # type: ignore[arg-type] - @classmethod - async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Response: - view = cls(request) - return await view.handlers[method](request, **kwargs) + @overload + def __call__(self, obj: Request) -> Awaitable[StreamResponse]: + ... - async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: - return await handler(self.request, **self.request.match_info, **kwargs) + @overload + def __call__(self, obj: ApiRouteHandler) -> Self: + ... - async def options(self, request: Request) -> Response: - return Response.new() + def __call__(self, obj: Request | ApiRouteHandler) -> Self | Awaitable[StreamResponse]: + if isinstance(obj, Request): + return self.handle_request(obj) + + self.handler = obj + return self - @cached_property - def allowed_methods(self) -> Sequence[str]: - return tuple(self.handlers.keys()) + async def handle_request(self, request: Request) -> StreamResponse: + request["application"] = None + if request.method != "OPTIONS" and self.require_token: + if (auth := request.headers.getone("Authorization", None)) is None: + raise HttpError(401, 'Missing token') - @cached_property - def handlers(self) -> dict[str, HandlerCallback]: - data = {} - - for method in METHODS: try: - data[method] = getattr(self, method.lower()) + authtype, code = auth.split(" ", 1) - except AttributeError: - continue + except IndexError: + raise HttpError(401, "Invalid authorization heder format") - return data + if authtype != "Bearer": + raise HttpError(401, f"Invalid authorization type: {authtype}") + if not code: + raise HttpError(401, "Missing token") - @property - def app(self) -> Application: - return get_app() + with get_app().database.session(False) as s: + if (application := s.get_app_by_token(code)) is None: + raise HttpError(401, "Invalid token") + if application.auth_code is not None: + raise HttpError(401, "Invalid token") - @property - def cache(self) -> Cache: - return self.app.cache + request["application"] = application + if request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}: + post_data = {key: value for key, value in (await request.post()).items()} - @property - def client(self) -> HttpClient: - return self.app.client - - - @property - def config(self) -> Config: - return self.app.config - - - @property - def database(self) -> Database[Connection]: - return self.app.database - - - @property - def template(self) -> Template: - return self.app['template'] # type: ignore[no-any-return] - - - async def get_api_data(self, - required: list[str], - optional: list[str]) -> dict[str, str]: - - if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}: - post_data = convert_data(await self.request.post()) - # post_data = {key: value for key, value in parse_qsl(await self.request.text())} - - elif self.request.content_type == 'application/json': + elif request.content_type == 'application/json': try: - post_data = convert_data(await self.request.json()) + post_data = await request.json() except JSONDecodeError: raise HttpError(400, 'Invalid JSON data') else: - post_data = convert_data(self.request.query) - - data = {} + post_data = {key: str(value) for key, value in request.query.items()} try: - for key in required: - data[key] = post_data[key] + response = await self.handler(get_app(), request, **post_data) - except KeyError as e: - raise HttpError(400, f'Missing {str(e)} pararmeter') from None + except HttpError as error: + return Response.new({'error': error.message}, error.status, ctype = "json") - for key in optional: - data[key] = post_data.get(key) # type: ignore[assignment] + headers = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": ", ".join(ALLOWED_HEADERS) + } - return data + if isinstance(response, StreamResponse): + response.headers.update(headers) + return response + + if isinstance(response, ApiObject): + return Response.new(response.to_json(), headers = headers, ctype = "json") + + if isinstance(response, list): + data = [] + + for item in response: + if isinstance(item, ApiObject): + data.append(item.to_dict()) + + response = data + + return Response.new(response, headers = headers, ctype = "json")