From f98ca54ab7c80e0a5ca0a96ae158f4c1258dd402 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Thu, 4 Jul 2024 20:36:04 -0400 Subject: [PATCH] various changes * Add oauth login support * Add `HttpError` class * Add custom error handling * Use `blib.Date` class for (de)serializing db timestamp values * Add `db-maintenance` command * Rework middleware route checking * Fix fetching post data in api endpoints --- relay/application.py | 69 ++++- relay/data/statements.sql | 32 +++ relay/database/config.py | 12 +- relay/database/connection.py | 133 +++++++++- relay/database/schema.py | 109 +++++++- relay/frontend/base.haml | 18 +- relay/frontend/page/authorize_new.haml | 31 +++ relay/frontend/page/authorize_show.haml | 18 ++ relay/frontend/page/error.haml | 7 + relay/frontend/page/login.haml | 2 + relay/frontend/static/functions.js | 18 +- relay/frontend/static/style.css | 38 +++ relay/manage.py | 25 +- relay/misc.py | 37 ++- relay/template.py | 6 +- relay/views/activitypub.py | 25 +- relay/views/api.py | 318 +++++++++++++++--------- relay/views/base.py | 22 +- relay/views/frontend.py | 59 ++--- 19 files changed, 748 insertions(+), 231 deletions(-) create mode 100644 relay/frontend/page/authorize_new.haml create mode 100644 relay/frontend/page/authorize_show.haml create mode 100644 relay/frontend/page/error.haml diff --git a/relay/application.py b/relay/application.py index d852f29..6ab481b 100644 --- a/relay/application.py +++ b/relay/application.py @@ -4,11 +4,14 @@ import asyncio import multiprocessing import signal import time +import traceback +from Crypto.Random import get_random_bytes from aiohttp import web -from aiohttp.web import StaticResource +from aiohttp.web import HTTPException, StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer +from base64 import b64encode from bsql import Database from collections.abc import Awaitable, Callable from datetime import datetime, timedelta @@ -23,7 +26,8 @@ from .config import Config from .database import Connection, get_database from .database.schema import Instance from .http_client import HttpClient -from .misc import Message, Response, check_open_port, get_resource +from .misc import HttpError, Message, Response, check_open_port, get_resource +from .misc import JSON_PATHS, TOKEN_PATHS from .template import Template from .views import VIEWS from .views.api import handle_api_path @@ -53,9 +57,9 @@ class Application(web.Application): def __init__(self, cfgpath: Path | None, dev: bool = False): web.Application.__init__(self, middlewares = [ - handle_api_path, # type: ignore[list-item] + handle_response_headers, # type: ignore[list-item] handle_frontend_path, # type: ignore[list-item] - handle_response_headers # type: ignore[list-item] + handle_api_path # type: ignore[list-item] ] ) @@ -282,19 +286,70 @@ class CacheCleanupThread(Thread): self.running.clear() +def format_error(request: web.Request, error: HttpError) -> Response: + app: Application = request.app # type: ignore[assignment] + + if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''): + return Response.new({'error': error.body}, error.status, ctype = 'json') + + else: + body = app.template.render('page/error.haml', request, e = error) + return Response.new(body, error.status, ctype = 'html') + + @web.middleware async def handle_response_headers( request: web.Request, handler: Callable[[web.Request], Awaitable[Response]]) -> Response: - resp = await handler(request) + request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') + request['token'] = None + request['user'] = None + + app: Application = request.app # type: ignore[assignment] + + if request.path == "/" or request.path.startswith(TOKEN_PATHS): + with app.database.session() as conn: + if (token := request.headers.get('Authorization')) is not None: + token = token.replace('Bearer', '').strip() + + request['token'] = conn.get_app_by_token(token) + request['user'] = conn.get_user_by_app_token(token) + + elif (token := request.cookies.get('user-token')) is not None: + request['token'] = conn.get_token(token) + request['user'] = conn.get_user_by_token(token) + + try: + resp = await handler(request) + + except HttpError as e: + resp = format_error(request, e) + + except HTTPException as ae: + if ae.status == 404: + try: + text = (ae.text or "").split(":")[1].strip() + + except IndexError: + text = ae.text or "" + + resp = format_error(request, HttpError(ae.status, text)) + + else: + raise + + except Exception as e: + resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}')) + traceback.print_exc() + resp.headers['Server'] = 'ActivityRelay' # Still have to figure out how csp headers work - if resp.content_type == 'text/html' and not request.path.startswith("/api"): + if resp.content_type == 'text/html': resp.headers['Content-Security-Policy'] = get_csp(request) - if not request.app['dev'] and request.path.endswith(('.css', '.js')): + if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')): # cache for 2 weeks resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' diff --git a/relay/data/statements.sql b/relay/data/statements.sql index f06d4b5..e8694ae 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -56,6 +56,14 @@ WHERE username = ( ); +-- name: get-user-by-app-token +SELECT * FROM users +WHERE username = ( + SELECT user FROM app + WHERE code = :code +); + + -- name: put-user INSERT INTO users (username, hash, handle, created) VALUES (:username, :hash, :handle, :created) @@ -67,6 +75,30 @@ DELETE FROM users WHERE username = :value or handle = :value; +-- name: get-app +SELECT * FROM app +WHERE client_id = :id and client_secret = :secret; + + +-- name: get-app-token +SELECT * FROM app +WHERE client_id = :id and client_secret = :secret and token = :token; + + +-- name: get-app-by-token +SELECT * FROM app +WHERE token = :token; + +-- name: del-app +DELETE FROM users +WHERE client_id = :id and client_secret = :secret; + + +-- name: del-app-token +DELETE FROM users +WHERE client_id = :id and client_secret = :secret and token = :token; + + -- name: get-token SELECT * FROM tokens WHERE code = :code; diff --git a/relay/database/config.py b/relay/database/config.py index 2be3ecc..3f3c7e0 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -11,11 +11,7 @@ from .. import logger as logging from ..misc import boolean if TYPE_CHECKING: - try: - from typing import Self - - except ImportError: - from typing_extensions import Self + from typing import Self THEMES = { @@ -77,7 +73,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { @dataclass() class ConfigData: - schema_version: int = 20240310 + schema_version: int = 20240625 private_key: str = '' approval_required: bool = False log_level: logging.LogLevel = logging.LogLevel.INFO @@ -115,11 +111,11 @@ class ConfigData: @classmethod def DEFAULT(cls: type[Self], key: str) -> str | int | bool: - return cls.FIELD(key.replace('-', '_')).default # type: ignore + return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value] @classmethod - def FIELD(cls: type[Self], key: str) -> Field[Any]: + def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]: for field in fields(cls): if field.name == key.replace('-', '_'): return field diff --git a/relay/database/connection.py b/relay/database/connection.py index 006a907..3c973b8 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -1,6 +1,9 @@ from __future__ import annotations +import secrets + from argon2 import PasswordHasher +from blib import Date from bsql import Connection as SqlConnection, Row, Update from collections.abc import Iterator from datetime import datetime, timezone @@ -49,6 +52,40 @@ class Connection(SqlConnection): yield instance + def fix_timestamps(self) -> None: + for app in self.select('apps').all(schema.App): + data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()} + self.update('apps', data, client_id = app.client_id) + + for item in self.select('cache'): + data = {'updated': Date.parse(item['updated']).timestamp()} + self.update('cache', data, id = item['id']) + + for dban in self.select('domain_bans').all(schema.DomainBan): + data = {'created': dban.created.timestamp()} + self.update('domain_bans', data, domain = dban.domain) + + for instance in self.select('inboxes').all(schema.Instance): + data = {'created': instance.created.timestamp()} + self.update('inboxes', data, domain = instance.domain) + + for sban in self.select('software_bans').all(schema.SoftwareBan): + data = {'created': sban.created.timestamp()} + self.update('software_bans', data, name = sban.name) + + for token in self.select('tokens').all(schema.Token): + data = {'created': token.created.timestamp(), 'accessed': token.accessed.timestamp()} + self.update('tokens', data, code = token.code) + + for user in self.select('users').all(schema.User): + data = {'created': user.created.timestamp()} + self.update('users', data, username = user.username) + + for wlist in self.select('whitelist').all(schema.Whitelist): + data = {'created': wlist.created.timestamp()} + self.update('whitelist', data, domain = wlist.domain) + + def get_config(self, key: str) -> Any: key = key.replace('_', '-') @@ -198,6 +235,11 @@ class Connection(SqlConnection): return cur.one(schema.User) + def get_user_by_app_token(self, code: str) -> schema.User | None: + with self.run('get-user-by-app-token', {'code': code}) as cur: + return cur.one(schema.User) + + def get_users(self) -> Iterator[schema.User]: return self.execute("SELECT * FROM users").all(schema.User) @@ -249,13 +291,102 @@ class Connection(SqlConnection): pass + def get_app(self, + client_id: str, + client_secret: str, + token: str | None = None) -> schema.App | None: + + params = { + 'id': client_id, + 'secret': client_secret + } + + if token is not None: + command = 'get-app-with-token' + params['token'] = token + + else: + command = 'get-app' + + with self.run(command, params) as cur: + return cur.one(schema.App) + + + def get_app_by_token(self, token: str) -> schema.App | None: + with self.run('get-app-by-token', {'token': token}) as cur: + return cur.one(schema.App) + + + def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App: + params = { + 'name': name, + 'redirect_uri': redirect_uri, + 'website': website, + 'client_id': secrets.token_hex(20), + 'client_secret': secrets.token_hex(20), + 'created': Date.new_utc().timestamp(), + 'accessed': Date.new_utc().timestamp() + } + + with self.insert('app', params) as cur: + if (row := cur.one(schema.App)) is None: + raise RuntimeError(f'Failed to insert app: {name}') + + return row + + + def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App: + data: dict[str, str | None] = {} + + if user is not None: + data['user'] = user.username + + if set_auth: + data['auth_code'] = secrets.token_hex(20) + + else: + data['token'] = secrets.token_hex(20) + data['auth_code'] = None + + params = { + 'client_id': app.client_id, + 'client_secret': app.client_secret + } + + with self.update('app', data, **params) as cur: # type: ignore[arg-type] + if (row := cur.one(schema.App)) is None: + raise RuntimeError('Failed to update row') + + return row + + + def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool: + params = { + 'id': client_id, + 'secret': client_secret + } + + if token is not None: + command = 'del-app-token' + params['token'] = token + + else: + command = 'del-app' + + with self.run(command, params) as cur: + if cur.row_count > 1: + raise RuntimeError('More than 1 row was deleted') + + return cur.row_count == 0 + + def get_token(self, code: str) -> schema.Token | None: with self.run('get-token', {'code': code}) as cur: return cur.one(schema.Token) def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]: - if username is not None: + if username is None: return self.select('tokens').all(schema.Token) return self.select('tokens', username = username).all(schema.Token) diff --git a/relay/database/schema.py b/relay/database/schema.py index 1fd7003..660e527 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -1,14 +1,14 @@ from __future__ import annotations -import typing - +from blib import Date from bsql import Column, Row, Tables from collections.abc import Callable -from datetime import datetime +from copy import deepcopy +from typing import TYPE_CHECKING, Any from .config import ConfigData -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from .connection import Connection @@ -16,6 +16,16 @@ VERSIONS: dict[int, Callable[[Connection], None]] = {} TABLES = Tables() +def deserialize_timestamp(value: Any) -> Date: + try: + return Date.parse(value) + + except ValueError: + pass + + return Date.fromisoformat(value) + + @TABLES.add_row class Config(Row): key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) @@ -27,62 +37,125 @@ class Config(Row): class Instance(Row): table_name: str = 'inboxes' + domain: Column[str] = Column( 'domain', 'text', primary_key = True, unique = True, nullable = False) actor: Column[str] = Column('actor', 'text', unique = True) inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) followid: Column[str] = Column('followid', 'text') software: Column[str] = Column('software', 'text') - accepted: Column[datetime] = Column('accepted', 'boolean') - created: Column[datetime] = Column('created', 'timestamp', nullable = False) + accepted: Column[Date] = Column('accepted', 'boolean') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class Whitelist(Row): domain: Column[str] = Column( 'domain', 'text', primary_key = True, unique = True, nullable = True) - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class DomainBan(Row): table_name: str = 'domain_bans' + domain: Column[str] = Column( 'domain', 'text', primary_key = True, unique = True, nullable = True) reason: Column[str] = Column('reason', 'text') note: Column[str] = Column('note', 'text') - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class SoftwareBan(Row): table_name: str = 'software_bans' + name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) reason: Column[str] = Column('reason', 'text') note: Column[str] = Column('note', 'text') - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class User(Row): table_name: str = 'users' + username: Column[str] = Column( 'username', 'text', primary_key = True, unique = True, nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False) handle: Column[str] = Column('handle', 'text') - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) @TABLES.add_row class Token(Row): table_name: str = 'tokens' + code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False) user: Column[str] = Column('user', 'text', nullable = False) - created: Column[datetime] = Column('created', 'timestamp') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) + accessed: Column[Date] = Column( + 'accessed', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) + + +@TABLES.add_row +class App(Row): + table_name: str = 'apps' + + + client_id: Column[str] = Column( + 'client_id', 'text', primary_key = True, unique = True, nullable = False) + client_secret: Column[str] = Column('client_secret', 'text', nullable = False) + name: Column[str] = Column('name', 'text') + website: Column[str] = Column('website', 'text') + redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False) + token: Column[str | None] = Column('token', 'text') + auth_code: Column[str | None] = Column('auth_code', 'text') + user: Column[str | None] = Column('user', 'text') + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) + accessed: Column[Date] = Column( + 'accessed', 'timestamp', nullable = False, + deserializer = deserialize_timestamp, serializer = Date.timestamp + ) + + + def get_api_data(self, include_token: bool = False) -> dict[str, Any]: + data = deepcopy(self) + data.pop('auth_code') + data.pop('created') + data.pop('accessed') + + if not include_token: + data.pop('token') + + return data def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: @@ -103,5 +176,15 @@ def migrate_20240206(conn: Connection) -> None: @migration def migrate_20240310(conn: Connection) -> None: - conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN") - conn.execute("UPDATE inboxes SET accepted = 1") + conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN') + conn.execute('UPDATE "inboxes" SET accepted = 1') + + +@migration +def migrate_20240625(conn: Connection) -> None: + conn.execute('ALTER TABLE "tokens" ADD "accessed" timestamp') + + for token in conn.get_tokens(): + conn.update('tokens', {'accessed': token.created}, code = token.code).one() + + conn.create_tables() diff --git a/relay/frontend/base.haml b/relay/frontend/base.haml index d3d8bb6..dd1e3e2 100644 --- a/relay/frontend/base.haml +++ b/relay/frontend/base.haml @@ -1,5 +1,5 @@ -macro menu_item(name, path) - -if view.request.path == path or (path != "/" and view.request.path.startswith(path)) + -if request.path == path or (path != "/" and request.path.startswith(path)) %a.button(href="{{path}}" active="true") -> =name -else @@ -10,12 +10,12 @@ %head %title << {{config.name}}: {{page}} %meta(charset="UTF-8") - %meta(name="viewport" content="width=device-width, initial-scale=1") - %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme") - %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}") - %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}") + %meta(name="ort" content="width=device-width, initial-scale=1") + %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme") + %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}") + %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}") %link(rel="manifest" href="/manifest.json?{{version}}") - %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer) + %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer) -block head %body @@ -26,7 +26,7 @@ {{menu_item("Home", "/")}} - -if view.request["user"] + -if request["user"] {{menu_item("Instances", "/admin/instances")}} {{menu_item("Whitelist", "/admin/whitelist")}} {{menu_item("Domain Bans", "/admin/domain_bans")}} @@ -61,11 +61,11 @@ #footer.section .col1 - -if not view.request["user"] + -if not request["user"] %a(href="/login") << Login -else - =view.request["user"]["username"] + =request["user"]["username"] ( %a(href="/logout") << Logout ) diff --git a/relay/frontend/page/authorize_new.haml b/relay/frontend/page/authorize_new.haml new file mode 100644 index 0000000..4f07df3 --- /dev/null +++ b/relay/frontend/page/authorize_new.haml @@ -0,0 +1,31 @@ +-extends "base.haml" +-set page="App Authorization" + +-block content + %fieldset.section + %legend << App Authorization + + -if application.website + #title << Application "{{application.name}}" wants full API access + + -else + #title << Application "{{application.name}}" wants full API access + + #buttons + .spacer + + %form(action="/oauth/authorize" method="POST") + %input(type="hidden" name="client_id" value="{{application.client_id}}") + %input(type="hidden" name="client_secret" value="{{application.client_secret}}") + %input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}") + %input(type="hidden" name="response" value="true") + %input.button(type="submit" value="Allow") + + %form(action="/oauth/authorize" method="POST") + %input(type="hidden" name="client_id" value="{{application.client_id}}") + %input(type="hidden" name="client_secret" value="{{application.client_secret}}") + %input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}") + %input(type="hidden" name="response" value="false") + %input.button(type="submit" value="Deny") + + .spacer diff --git a/relay/frontend/page/authorize_show.haml b/relay/frontend/page/authorize_show.haml new file mode 100644 index 0000000..19cde40 --- /dev/null +++ b/relay/frontend/page/authorize_show.haml @@ -0,0 +1,18 @@ +-extends "base.haml" +-set page="App Authorization" + +-block content + %fieldset.section + %legend << App Authorization Code + + -if application.website + %p + Copy the following code into + %a(href="{{application.website}}" target="_main") -> %code -> =application.name + + -else + %p + Copy the following code info + %code -> =application.name + + %pre#code -> =application.auth_code diff --git a/relay/frontend/page/error.haml b/relay/frontend/page/error.haml new file mode 100644 index 0000000..4d4bf95 --- /dev/null +++ b/relay/frontend/page/error.haml @@ -0,0 +1,7 @@ +-extends "base.haml" +-set page="Error" + +-block content + .section.error + .title << HTTP Error {{e.status}} + .body -> =e.body diff --git a/relay/frontend/page/login.haml b/relay/frontend/page/login.haml index c32160f..4f29746 100644 --- a/relay/frontend/page/login.haml +++ b/relay/frontend/page/login.haml @@ -12,4 +12,6 @@ %label(for="password") << Password %input(id="password" name="password" placeholder="Password" type="password") + + %input#redir(type="hidden" name="redir" value="{{redir}}") %input.submit(type="button" value="Login") diff --git a/relay/frontend/static/functions.js b/relay/frontend/static/functions.js index b0e4db5..3063223 100644 --- a/relay/frontend/static/functions.js +++ b/relay/frontend/static/functions.js @@ -483,13 +483,15 @@ function page_instance() { function page_login() { const fields = { username: document.querySelector("#username"), - password: document.querySelector("#password") - } + password: document.querySelector("#password"), + redir: document.querySelector("#redir") + }; async function login(event) { const values = { username: fields.username.value.trim(), - password: fields.password.value.trim() + password: fields.password.value.trim(), + redir: fields.redir.value.trim() } if (values.username === "" | values.password === "") { @@ -498,14 +500,14 @@ function page_login() { } try { - await request("POST", "v1/token", values); + await request("POST", "v1/login", values); } catch (error) { toast(error); return; } - document.location = "/"; + document.location = values.redir; } @@ -848,9 +850,6 @@ if (location.pathname.startsWith("/admin/config")) { } else if (location.pathname.startsWith("/admin/instances")) { page_instance(); -} else if (location.pathname.startsWith("/admin/login")) { - page_login(); - } else if (location.pathname.startsWith("/admin/software_bans")) { page_software_ban(); @@ -859,4 +858,7 @@ if (location.pathname.startsWith("/admin/config")) { } else if (location.pathname.startsWith("/admin/whitelist")) { page_whitelist(); + +} else if (location.pathname.startsWith("/login")) { + page_login(); } diff --git a/relay/frontend/static/style.css b/relay/frontend/static/style.css index f0d72f5..c9bcd43 100644 --- a/relay/frontend/static/style.css +++ b/relay/frontend/static/style.css @@ -338,6 +338,44 @@ textarea { } +/* error */ +#content.page-error { + text-align: center; +} + +#content.page-error .title { + font-size: 24px; + font-weight: bold; +} + + +/* auth */ +#content.page-app_authorization { + text-align: center; +} + +#content.page-app_authorization #code { + background: var(--background); + border: 1px solid var(--border); + font-size: 18px; + margin: 0 auto; + width: max-content; + padding: 5px; +} + +#content.page-app_authorization #title { + font-size: 24px; +} + +#content.page-app_authorization #buttons { + display: grid; + grid-template-columns: auto max-content max-content auto; + grid-gap: var(--spacing); + justify-items: center; + margin: var(--spacing) 0; +} + + @keyframes show_toast { 0% { transform: translateX(100%); diff --git a/relay/manage.py b/relay/manage.py index 81f546e..5ae8238 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -212,6 +212,21 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None: os._exit(0) +@cli.command('db-maintenance') +@click.option('--fix-timestamps', '-t', is_flag = True, + help = 'Make sure timestamps in the database are float values') +@click.pass_context +def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None: + 'Perform maintenance tasks on the database' + + if fix_timestamps: + with ctx.obj.database.session(True) as conn: + conn.fix_timestamps() + + with ctx.obj.database.session(False) as conn: + with conn.execute("VACUUM"): + pass + @cli.command('convert') @click.option('--old-config', '-o', help = 'Path to the config file to convert from') @@ -239,18 +254,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: ctx.obj.config.set('domain', config['host']) ctx.obj.config.save() + # fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7 with get_database(ctx.obj.config) as db: with db.session(True) as conn: conn.put_config('private-key', database['private-key']) conn.put_config('note', config['note']) conn.put_config('whitelist-enabled', config['whitelist_enabled']) - with click.progressbar( # type: ignore + with click.progressbar( database['relay-list'].values(), label = 'Inboxes'.ljust(15), width = 0 ) as inboxes: - for inbox in inboxes: if inbox['software'] in {'akkoma', 'pleroma'}: actor = f'https://{inbox["domain"]}/relay' @@ -269,7 +284,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: software = inbox['software'] ) - with click.progressbar( # type: ignore + with click.progressbar( config['blocked_software'], label = 'Banned software'.ljust(15), width = 0 @@ -281,7 +296,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: reason = 'relay' if software in RELAY_SOFTWARE else None ) - with click.progressbar( # type: ignore + with click.progressbar( config['blocked_instances'], label = 'Banned domains'.ljust(15), width = 0 @@ -290,7 +305,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None: for domain in banned_software: conn.put_domain_ban(domain) - with click.progressbar( # type: ignore + with click.progressbar( config['whitelist'], label = 'Whitelist'.ljust(15), width = 0 diff --git a/relay/misc.py b/relay/misc.py index 6995bc4..b27c89a 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -62,6 +62,28 @@ SOFTWARE = ( 'gotosocial' ) +JSON_PATHS: tuple[str, ...] = ( + '/api/v1', + '/actor', + '/inbox', + '/outbox', + '/following', + '/followers', + '/.well-known', + '/nodeinfo', + '/oauth/token', + '/oauth/revoke' +) + +TOKEN_PATHS: tuple[str, ...] = ( + '/api', + '/login', + '/logout', + '/oauth/authorize', + '/oauth/revoke', + '/admin' +) + def boolean(value: Any) -> bool: if isinstance(value, str): @@ -113,6 +135,17 @@ def get_resource(path: str) -> Path: return Path(str(pkgfiles('relay'))).joinpath(path) +class HttpError(Exception): + def __init__(self, + status: int, + body: str) -> None: + + self.body: str = body + self.status: int = status + + Exception.__init__(self, f"HTTP Error {status}: {body}") + + class JsonEncoder(json.JSONEncoder): def default(self, o: Any) -> str: if isinstance(o, datetime): @@ -242,9 +275,9 @@ class Response(AiohttpResponse): @classmethod - def new_redir(cls: type[Self], path: str) -> Self: + def new_redir(cls: type[Self], path: str, status: int = 307) -> Self: body = f'Redirect to {path}' - return cls.new(body, 302, {'Location': path}) + return cls.new(body, status, {'Location': path}, ctype = 'html') @property diff --git a/relay/template.py b/relay/template.py index 7e3f657..3ee2855 100644 --- a/relay/template.py +++ b/relay/template.py @@ -2,6 +2,7 @@ from __future__ import annotations import textwrap +from aiohttp.web import Request from collections.abc import Callable from hamlish_jinja import HamlishExtension from jinja2 import Environment, FileSystemLoader @@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any from . import __version__ from .misc import get_resource -from .views.base import View if TYPE_CHECKING: from .application import Application @@ -43,12 +43,12 @@ class Template(Environment): self.hamlish_mode = 'indented' - def render(self, path: str, view: View | None = None, **context: Any) -> str: + def render(self, path: str, request: Request, **context: Any) -> str: with self.app.database.session(False) as conn: config = conn.get_config_all() new_context = { - 'view': view, + 'request': request, 'domain': self.app.config.domain, 'version': __version__, 'config': config, diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index 74b01c6..aa672f2 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -7,7 +7,7 @@ from .base import View, register_route from .. import logger as logging from ..database import schema -from ..misc import Message, Response +from ..misc import HttpError, Message, Response from ..processors import run_processor @@ -39,8 +39,7 @@ class ActorView(View): async def post(self, request: Request) -> Response: - if response := await self.get_post_data(): - return response + await self.get_post_data() with self.database.session() as conn: self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment] @@ -65,13 +64,13 @@ class ActorView(View): return Response.new(status = 202) - async def get_post_data(self) -> Response | None: + async def get_post_data(self) -> None: try: self.signature = aputils.Signature.parse(self.request.headers['signature']) except KeyError: logging.verbose('Missing signature header') - return Response.new_error(400, 'missing signature header', 'json') + raise HttpError(400, 'missing signature header') try: message: Message | None = await self.request.json(loads = Message.parse) @@ -79,17 +78,17 @@ class ActorView(View): except Exception: traceback.print_exc() logging.verbose('Failed to parse inbox message') - return Response.new_error(400, 'failed to parse message', 'json') + raise HttpError(400, 'failed to parse message') if message is None: logging.verbose('empty message') - return Response.new_error(400, 'missing message', 'json') + raise HttpError(400, 'missing message') self.message = message if 'actor' not in self.message: logging.verbose('actor not in message') - return Response.new_error(400, 'no actor in message', 'json') + raise HttpError(400, 'no actor in message') try: self.actor = await self.client.get(self.signature.keyid, True, Message) @@ -98,26 +97,24 @@ class ActorView(View): # ld signatures aren't handled atm, so just ignore it if self.message.type == 'Delete': logging.verbose('Instance sent a delete which cannot be handled') - return Response.new(status=202) + raise HttpError(202, '') logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') - return Response.new_error(400, 'failed to fetch actor', 'json') + raise HttpError(400, 'failed to fetch actor') try: self.signer = self.actor.signer except KeyError: logging.verbose('Actor missing public key: %s', self.signature.keyid) - return Response.new_error(400, 'actor missing public key', 'json') + raise HttpError(400, 'actor missing public key') try: await self.signer.validate_request_async(self.request) except aputils.SignatureFailureError as e: logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) - return Response.new_error(401, str(e), 'json') - - return None + raise HttpError(401, str(e)) @register_route('/outbox') diff --git a/relay/views/api.py b/relay/views/api.py index 73b6a16..76cd1e5 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -1,16 +1,17 @@ +import secrets import traceback from aiohttp.web import Request, middleware from argon2.exceptions import VerifyMismatchError +from blib import convert_to_boolean from collections.abc import Awaitable, Callable, Sequence -from typing import Any from urllib.parse import urlparse from .base import View, register_route from .. import __version__ -from ..database import ConfigData -from ..misc import Message, Response, boolean, get_app +from ..database import ConfigData, schema +from ..misc import HttpError, Message, Response, boolean DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' @@ -22,6 +23,8 @@ ALLOWED_HEADERS: set[str] = { PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( ('GET', '/api/v1/relay'), + ('POST', '/api/v1/app'), + ('POST', '/api/v1/login'), ('POST', '/api/v1/token') ) @@ -37,57 +40,174 @@ def check_api_path(method: str, path: str) -> bool: async def handle_api_path( request: Request, handler: Callable[[Request], Awaitable[Response]]) -> Response: - try: - if (token := request.cookies.get('user-token')): - request['token'] = token - else: - request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() - - with get_app().database.session() as conn: - request['user'] = conn.get_user_by_token(request['token']) - - except (KeyError, ValueError): - request['token'] = None - request['user'] = None + if not request.path.startswith('/api'): + return await handler(request) if request.method != "OPTIONS" and check_api_path(request.method, request.path): - if not request['token']: - return Response.new_error(401, 'Missing token', 'json') + if request['token'] is None: + raise HttpError(401, 'Missing token') - if not request['user']: - return Response.new_error(401, 'Invalid token', 'json') + if request['user'] is None: + raise HttpError(401, 'Invalid token') response = await handler(request) - - if request.path.startswith('/api'): - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) + response.headers['Access-Control-Allow-Origin'] = '*' + response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) return response -@register_route('/api/v1/token') -class Login(View): +@register_route('/oauth/authorize') +class OauthAuthorize(View): async def get(self, request: Request) -> Response: - return Response.new({'message': 'Token valid'}, ctype = 'json') + data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], []) + + if data['response_type'] != 'code': + raise HttpError(400, 'Response type is not "code"') + + with self.database.session(True) as conn: + with conn.select('app', client_id = data['client_id']) as cur: + if (app := cur.one(schema.App)) is None: + raise HttpError(404, 'Could not find app') + + if app.token is not None or app.auth_code is not None: + context = {'application': app} + html = self.template.render( + 'page/authorize_show.haml', self.request, **context + ) + + return Response.new(html, ctype = 'html') + + if data['redirect_uri'] != app.redirect_uri: + raise HttpError(400, 'redirect_uri does not match application') + + context = {'application': app} + html = self.template.render('page/authorize_new.haml', self.request, **context) + return Response.new(html, ctype = 'html') async def post(self, request: Request) -> Response: - data = await self.get_api_data(['username', 'password'], []) + data = await self.get_api_data( + ['client_id', 'client_secret', 'redirect_uri', 'response'], [] + ) - if isinstance(data, Response): - return data + with self.database.session(True) as conn: + if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: + return Response.new_error(404, 'Could not find app', 'json') + + if convert_to_boolean(data['response']): + if app.auth_code is None: + app = conn.update_app(app, request['user'], True) + + if app.redirect_uri == DEFAULT_REDIRECT: + context = {'application': app} + html = self.template.render( + 'page/authorize_show.haml', self.request, **context + ) + + return Response.new(html, ctype = 'html') + + return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}') + + if not conn.del_app(app.client_id, app.client_secret): + raise HttpError(404, 'App not found') + + return Response.new_redir('/') + + +@register_route('/oauth/token') +class OauthToken(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data( + ['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], [] + ) + + if data['grant_type'] != 'authorization_code': + raise HttpError(400, 'Invalid grant type') + + with self.database.session(True) as conn: + if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: + raise HttpError(404, 'Application not found') + + if app.auth_code != data['code']: + raise HttpError(400, 'Invalid authentication code') + + if app.redirect_uri != data['redirect_uri']: + raise HttpError(400, 'Invalid redirect uri') + + app = conn.update_app(app, request['user'], False) + + return Response.new(app.get_api_data(True), ctype = 'json') + + +@register_route('/oauth/revoke') +class OauthRevoke(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['client_id', 'client_secret', 'token'], []) + + with self.database.session(True) as conn: + if (app := conn.get_app(**data)) is None: + raise HttpError(404, 'Could not find token') + + if app.user != request['token'].username: + raise HttpError(403, 'Invalid token') + + if not conn.del_app(**data): + raise HttpError(400, 'Token not removed') + + return Response.new({'msg': 'Token deleted'}, ctype = 'json') + + +@register_route('/api/v1/app') +class App(View): + async def get(self, request: Request) -> Response: + data = await self.get_api_data(['client_id', 'client_secret'], []) + + with self.database.session(False) as conn: + if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: + raise HttpError(404, 'Application cannot be found') + + return Response.new(app.get_api_data(), ctype = 'json') + + + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['name', 'redirect_uri'], ['website']) + + with self.database.session(True) as conn: + app = conn.put_app( + name = data['name'], + redirect_uri = data['redirect_uri'], + website = data.get('website') + ) + + return Response.new(app.get_api_data(), ctype = 'json') + + + async def delete(self, request: Request) -> Response: + data = await self.get_api_data(['client_id', 'client_secret'], []) + + with self.database.session(True) as conn: + if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code): + raise HttpError(400, 'Token not removed') + + return Response.new({'msg': 'Token deleted'}, ctype = 'json') + + +@register_route('/api/v1/login') +class Login(View): + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['username', 'password'], []) with self.database.session(True) as conn: if not (user := conn.get_user(data['username'])): - return Response.new_error(401, 'User not found', 'json') + raise HttpError(401, 'User not found') try: conn.hasher.verify(user['hash'], data['password']) except VerifyMismatchError: - return Response.new_error(401, 'Invalid password', 'json') + raise HttpError(401, 'Invalid password') token = conn.put_token(data['username']) @@ -106,11 +226,36 @@ class Login(View): return resp - async def delete(self, request: Request) -> Response: - with self.database.session() as conn: - conn.del_token(request['token']) - return Response.new({'message': 'Token revoked'}, ctype = 'json') + async def post2(self, request: Request) -> Response: + data = await self.get_api_data(['username', 'password'], []) + + with self.database.session(True) as conn: + if not (user := conn.get_user(data['username'])): + raise HttpError(401, 'User not found') + + try: + conn.hasher.verify(user['hash'], data['password']) + + except VerifyMismatchError: + raise HttpError(401, 'Invalid password') + + app = conn.put_app( + data['app_name'], + DEFAULT_REDIRECT, + data.get('website') + ) + + params = { + 'code': secrets.token_hex(20), + 'user': user.username + } + + with conn.update('app', params, client_id = app.client_id) as cur: + if (row := cur.one(schema.App)) is None: + raise HttpError(500, 'Failed to create app') + + return Response.new(row.get_api_data(True), ctype = 'json') @register_route('/api/v1/relay') @@ -155,14 +300,10 @@ class Config(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['key', 'value'], []) - - if isinstance(data, Response): - return data - data['key'] = data['key'].replace('-', '_') if data['key'] not in ConfigData.USER_KEYS(): - return Response.new_error(400, 'Invalid key', 'json') + raise HttpError(400, 'Invalid key') with self.database.session() as conn: conn.put_config(data['key'], data['value']) @@ -173,11 +314,8 @@ class Config(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['key'], []) - if isinstance(data, Response): - return data - if data['key'] not in ConfigData.USER_KEYS(): - return Response.new_error(400, 'Invalid key', 'json') + raise HttpError(400, 'Invalid key') with self.database.session() as conn: conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) @@ -196,15 +334,11 @@ class Inbox(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) - - if isinstance(data, Response): - return data - data['domain'] = urlparse(data["actor"]).netloc with self.database.session() as conn: if conn.get_inbox(data['domain']) is not None: - return Response.new_error(404, 'Instance already in database', 'json') + raise HttpError(404, 'Instance already in database') data['domain'] = data['domain'].encode('idna').decode() @@ -214,7 +348,7 @@ class Inbox(View): except Exception: traceback.print_exc() - return Response.new_error(500, 'Failed to fetch actor', 'json') + raise HttpError(500, 'Failed to fetch actor') data['inbox'] = actor_data.shared_inbox @@ -240,14 +374,10 @@ class Inbox(View): async def patch(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() if (instance := conn.get_inbox(data['domain'])) is None: - return Response.new_error(404, 'Instance with domain not found', 'json') + raise HttpError(404, 'Instance with domain not found') instance = conn.put_inbox( instance.domain, @@ -262,14 +392,10 @@ class Inbox(View): async def delete(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], []) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() if not conn.get_inbox(data['domain']): - return Response.new_error(404, 'Instance with domain not found', 'json') + raise HttpError(404, 'Instance with domain not found') conn.del_inbox(data['domain']) @@ -286,26 +412,21 @@ class RequestView(View): async def post(self, request: Request) -> Response: - data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], []) - - if isinstance(data, Response): - return data - - data['accept'] = boolean(data['accept']) + data = await self.get_api_data(['domain', 'accept'], []) data['domain'] = data['domain'].encode('idna').decode() try: with self.database.session(True) as conn: - instance = conn.put_request_response(data['domain'], data['accept']) + instance = conn.put_request_response(data['domain'], boolean(data['accept'])) except KeyError: - return Response.new_error(404, 'Request not found', 'json') + raise HttpError(404, 'Request not found') message = Message.new_response( host = self.config.domain, actor = instance.actor, followid = instance.followid, - accept = data['accept'] + accept = boolean(data['accept']) ) self.app.push_message(instance.inbox, message, instance) @@ -333,15 +454,11 @@ class DomainBan(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['domain'], ['note', 'reason']) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() with self.database.session() as conn: if conn.get_domain_ban(data['domain']) is not None: - return Response.new_error(400, 'Domain already banned', 'json') + raise HttpError(400, 'Domain already banned') ban = conn.put_domain_ban( domain = data['domain'], @@ -356,16 +473,13 @@ class DomainBan(View): with self.database.session() as conn: data = await self.get_api_data(['domain'], ['note', 'reason']) - if isinstance(data, Response): - return data - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + raise HttpError(400, 'Must include note and/or reason parameters') data['domain'] = data['domain'].encode('idna').decode() if conn.get_domain_ban(data['domain']) is None: - return Response.new_error(404, 'Domain not banned', 'json') + raise HttpError(404, 'Domain not banned') ban = conn.update_domain_ban( domain = data['domain'], @@ -379,14 +493,10 @@ class DomainBan(View): async def delete(self, request: Request) -> Response: with self.database.session() as conn: data = await self.get_api_data(['domain'], []) - - if isinstance(data, Response): - return data - data['domain'] = data['domain'].encode('idna').decode() if conn.get_domain_ban(data['domain']) is None: - return Response.new_error(404, 'Domain not banned', 'json') + raise HttpError(404, 'Domain not banned') conn.del_domain_ban(data['domain']) @@ -405,12 +515,9 @@ class SoftwareBan(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['name'], ['note', 'reason']) - if isinstance(data, Response): - return data - with self.database.session() as conn: if conn.get_software_ban(data['name']) is not None: - return Response.new_error(400, 'Domain already banned', 'json') + raise HttpError(400, 'Domain already banned') ban = conn.put_software_ban( name = data['name'], @@ -424,15 +531,12 @@ class SoftwareBan(View): async def patch(self, request: Request) -> Response: data = await self.get_api_data(['name'], ['note', 'reason']) - if isinstance(data, Response): - return data - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + raise HttpError(400, 'Must include note and/or reason parameters') with self.database.session() as conn: if conn.get_software_ban(data['name']) is None: - return Response.new_error(404, 'Software not banned', 'json') + raise HttpError(404, 'Software not banned') ban = conn.update_software_ban( name = data['name'], @@ -446,12 +550,9 @@ class SoftwareBan(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['name'], []) - if isinstance(data, Response): - return data - with self.database.session() as conn: if conn.get_software_ban(data['name']) is None: - return Response.new_error(404, 'Software not banned', 'json') + raise HttpError(404, 'Software not banned') conn.del_software_ban(data['name']) @@ -474,12 +575,9 @@ class User(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['username', 'password'], ['handle']) - if isinstance(data, Response): - return data - with self.database.session() as conn: if conn.get_user(data['username']) is not None: - return Response.new_error(404, 'User already exists', 'json') + raise HttpError(404, 'User already exists') user = conn.put_user( username = data['username'], @@ -494,9 +592,6 @@ class User(View): async def patch(self, request: Request) -> Response: data = await self.get_api_data(['username'], ['password', 'handle']) - if isinstance(data, Response): - return data - with self.database.session(True) as conn: user = conn.put_user( username = data['username'], @@ -511,12 +606,9 @@ class User(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['username'], []) - if isinstance(data, Response): - return data - with self.database.session(True) as conn: if conn.get_user(data['username']) is None: - return Response.new_error(404, 'User does not exist', 'json') + raise HttpError(404, 'User does not exist') conn.del_user(data['username']) @@ -535,14 +627,11 @@ class Whitelist(View): async def post(self, request: Request) -> Response: data = await self.get_api_data(['domain'], []) - if isinstance(data, Response): - return data - domain = data['domain'].encode('idna').decode() with self.database.session() as conn: if conn.get_domain_whitelist(domain) is not None: - return Response.new_error(400, 'Domain already added to whitelist', 'json') + raise HttpError(400, 'Domain already added to whitelist') item = conn.put_domain_whitelist(domain) @@ -552,14 +641,11 @@ class Whitelist(View): async def delete(self, request: Request) -> Response: data = await self.get_api_data(['domain'], []) - if isinstance(data, Response): - return data - domain = data['domain'].encode('idna').decode() with self.database.session() as conn: if conn.get_domain_whitelist(domain) is None: - return Response.new_error(404, 'Domain not in whitelist', 'json') + raise HttpError(404, 'Domain not in whitelist') conn.del_domain_whitelist(domain) diff --git a/relay/views/base.py b/relay/views/base.py index e102896..1b2d405 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -1,10 +1,8 @@ from __future__ import annotations -from Crypto.Random import get_random_bytes from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS -from aiohttp.web import HTTPMethodNotAllowed, Request -from base64 import b64encode +from aiohttp.web import Request from bsql import Database from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from functools import cached_property @@ -15,7 +13,7 @@ from ..cache import Cache from ..config import Config from ..database import Connection from ..http_client import HttpClient -from ..misc import Response, get_app +from ..misc import HttpError, Response, get_app if TYPE_CHECKING: from typing import Self @@ -43,10 +41,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]: class View(AbstractView): def __await__(self) -> Generator[Any, None, Response]: if self.request.method not in METHODS: - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + raise HttpError(405, f'"{self.request.method}" method not allowed') if not (handler := self.handlers.get(self.request.method)): - raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) + raise HttpError(405, f'"{self.request.method}" method not allowed') return self._run_handler(handler).__await__() @@ -58,7 +56,6 @@ class View(AbstractView): async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: - self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') return await handler(self.request, **self.request.match_info, **kwargs) @@ -117,17 +114,18 @@ class View(AbstractView): async def get_api_data(self, required: list[str], - optional: list[str]) -> dict[str, str] | Response: + optional: list[str]) -> dict[str, str]: - if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: + if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}: post_data = convert_data(await self.request.post()) + # post_data = {key: value for key, value in parse_qsl(await self.request.text())} elif self.request.content_type == 'application/json': try: post_data = convert_data(await self.request.json()) except JSONDecodeError: - return Response.new_error(400, 'Invalid JSON data', 'json') + raise HttpError(400, 'Invalid JSON data') else: post_data = convert_data(self.request.query) @@ -139,9 +137,9 @@ class View(AbstractView): data[key] = post_data[key] except KeyError as e: - return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') + raise HttpError(400, f'Missing {str(e)} pararmeter') for key in optional: - data[key] = post_data.get(key, '') + data[key] = post_data.get(key) # type: ignore[assignment] return data diff --git a/relay/views/frontend.py b/relay/views/frontend.py index cf6b338..a383d20 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -1,18 +1,13 @@ from aiohttp import web from collections.abc import Awaitable, Callable from typing import Any +from urllib.parse import unquote from .base import View, register_route from ..database import THEMES from ..logger import LogLevel -from ..misc import Response, get_app - - -UNAUTH_ROUTES = { - '/', - '/login' -} +from ..misc import TOKEN_PATHS, Response @web.middleware @@ -20,28 +15,25 @@ async def handle_frontend_path( request: web.Request, handler: Callable[[web.Request], Awaitable[Response]]) -> Response: - app = get_app() + if request['user'] is not None and request.path == '/login': + return Response.new_redir('/') - if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): - request['token'] = request.cookies.get('user-token') - request['user'] = None + if request.path.startswith(TOKEN_PATHS) and request['user'] is None: + if request.path == '/logout': + return Response.new_redir('/') - if request['token']: - with app.database.session(False) as conn: - request['user'] = conn.get_user_by_token(request['token']) + response = Response.new_redir(f'/login?redir={request.path}') - if request['user'] and request.path == '/login': - return Response.new('', 302, {'Location': '/'}) - - if not request['user'] and request.path.startswith('/admin'): - response = Response.new('', 302, {'Location': f'/login?redir={request.path}'}) + if request['token'] is not None: response.del_cookie('user-token') - return response + + return response response = await handler(request) - if not request.path.startswith('/api') and not request['user'] and request['token']: - response.del_cookie('user-token') + if not request.path.startswith('/api'): + if request['user'] is None and request['token'] is not None: + response.del_cookie('user-token') return response @@ -54,14 +46,15 @@ class HomeView(View): 'instances': tuple(conn.get_inboxes()) } - data = self.template.render('page/home.haml', self, **context) + data = self.template.render('page/home.haml', self.request, **context) return Response.new(data, ctype='html') @register_route('/login') class Login(View): async def get(self, request: web.Request) -> Response: - data = self.template.render('page/login.haml', self) + redir = unquote(request.query.get('redir', '/')) + data = self.template.render('page/login.haml', self.request, redir = redir) return Response.new(data, ctype = 'html') @@ -69,7 +62,7 @@ class Login(View): class Logout(View): async def get(self, request: web.Request) -> Response: with self.database.session(True) as conn: - conn.del_token(request['token']) + conn.del_token(request['token'].code) resp = Response.new_redir('/') resp.del_cookie('user-token', domain = self.config.domain, path = '/') @@ -79,7 +72,7 @@ class Logout(View): @register_route('/admin') class Admin(View): async def get(self, request: web.Request) -> Response: - return Response.new('', 302, {'Location': '/admin/instances'}) + return Response.new_redir(f'/login?redir={request.path}', 301) @register_route('/admin/instances') @@ -101,7 +94,7 @@ class AdminInstances(View): if message: context['message'] = message - data = self.template.render('page/admin-instances.haml', self, **context) + data = self.template.render('page/admin-instances.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -123,7 +116,7 @@ class AdminWhitelist(View): if message: context['message'] = message - data = self.template.render('page/admin-whitelist.haml', self, **context) + data = self.template.render('page/admin-whitelist.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -145,7 +138,7 @@ class AdminDomainBans(View): if message: context['message'] = message - data = self.template.render('page/admin-domain_bans.haml', self, **context) + data = self.template.render('page/admin-domain_bans.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -167,7 +160,7 @@ class AdminSoftwareBans(View): if message: context['message'] = message - data = self.template.render('page/admin-software_bans.haml', self, **context) + data = self.template.render('page/admin-software_bans.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -189,7 +182,7 @@ class AdminUsers(View): if message: context['message'] = message - data = self.template.render('page/admin-users.haml', self, **context) + data = self.template.render('page/admin-users.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -213,7 +206,7 @@ class AdminConfig(View): } } - data = self.template.render('page/admin-config.haml', self, **context) + data = self.template.render('page/admin-config.haml', self.request, **context) return Response.new(data, ctype = 'html') @@ -251,5 +244,5 @@ class ThemeCss(View): except KeyError: return Response.new('Invalid theme', 404) - data = self.template.render('variables.css', self, **context) + data = self.template.render('variables.css', self.request, **context) return Response.new(data, ctype = 'css')