diff --git a/relay/application.py b/relay/application.py
index d852f29..6ab481b 100644
--- a/relay/application.py
+++ b/relay/application.py
@@ -4,11 +4,14 @@ import asyncio
import multiprocessing
import signal
import time
+import traceback
+from Crypto.Random import get_random_bytes
from aiohttp import web
-from aiohttp.web import StaticResource
+from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger
from aputils.signer import Signer
+from base64 import b64encode
from bsql import Database
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
@@ -23,7 +26,8 @@ from .config import Config
from .database import Connection, get_database
from .database.schema import Instance
from .http_client import HttpClient
-from .misc import Message, Response, check_open_port, get_resource
+from .misc import HttpError, Message, Response, check_open_port, get_resource
+from .misc import JSON_PATHS, TOKEN_PATHS
from .template import Template
from .views import VIEWS
from .views.api import handle_api_path
@@ -53,9 +57,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self,
middlewares = [
- handle_api_path, # type: ignore[list-item]
+ handle_response_headers, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item]
- handle_response_headers # type: ignore[list-item]
+ handle_api_path # type: ignore[list-item]
]
)
@@ -282,19 +286,70 @@ class CacheCleanupThread(Thread):
self.running.clear()
+def format_error(request: web.Request, error: HttpError) -> Response:
+ app: Application = request.app # type: ignore[assignment]
+
+ if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
+ return Response.new({'error': error.body}, error.status, ctype = 'json')
+
+ else:
+ body = app.template.render('page/error.haml', request, e = error)
+ return Response.new(body, error.status, ctype = 'html')
+
+
@web.middleware
async def handle_response_headers(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
- resp = await handler(request)
+ request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
+ request['token'] = None
+ request['user'] = None
+
+ app: Application = request.app # type: ignore[assignment]
+
+ if request.path == "/" or request.path.startswith(TOKEN_PATHS):
+ with app.database.session() as conn:
+ if (token := request.headers.get('Authorization')) is not None:
+ token = token.replace('Bearer', '').strip()
+
+ request['token'] = conn.get_app_by_token(token)
+ request['user'] = conn.get_user_by_app_token(token)
+
+ elif (token := request.cookies.get('user-token')) is not None:
+ request['token'] = conn.get_token(token)
+ request['user'] = conn.get_user_by_token(token)
+
+ try:
+ resp = await handler(request)
+
+ except HttpError as e:
+ resp = format_error(request, e)
+
+ except HTTPException as ae:
+ if ae.status == 404:
+ try:
+ text = (ae.text or "").split(":")[1].strip()
+
+ except IndexError:
+ text = ae.text or ""
+
+ resp = format_error(request, HttpError(ae.status, text))
+
+ else:
+ raise
+
+ except Exception as e:
+ resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}'))
+ traceback.print_exc()
+
resp.headers['Server'] = 'ActivityRelay'
# Still have to figure out how csp headers work
- if resp.content_type == 'text/html' and not request.path.startswith("/api"):
+ if resp.content_type == 'text/html':
resp.headers['Content-Security-Policy'] = get_csp(request)
- if not request.app['dev'] and request.path.endswith(('.css', '.js')):
+ if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')):
# cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'
diff --git a/relay/data/statements.sql b/relay/data/statements.sql
index f06d4b5..e8694ae 100644
--- a/relay/data/statements.sql
+++ b/relay/data/statements.sql
@@ -56,6 +56,14 @@ WHERE username = (
);
+-- name: get-user-by-app-token
+SELECT * FROM users
+WHERE username = (
+ SELECT user FROM app
+ WHERE code = :code
+);
+
+
-- name: put-user
INSERT INTO users (username, hash, handle, created)
VALUES (:username, :hash, :handle, :created)
@@ -67,6 +75,30 @@ DELETE FROM users
WHERE username = :value or handle = :value;
+-- name: get-app
+SELECT * FROM app
+WHERE client_id = :id and client_secret = :secret;
+
+
+-- name: get-app-token
+SELECT * FROM app
+WHERE client_id = :id and client_secret = :secret and token = :token;
+
+
+-- name: get-app-by-token
+SELECT * FROM app
+WHERE token = :token;
+
+-- name: del-app
+DELETE FROM users
+WHERE client_id = :id and client_secret = :secret;
+
+
+-- name: del-app-token
+DELETE FROM users
+WHERE client_id = :id and client_secret = :secret and token = :token;
+
+
-- name: get-token
SELECT * FROM tokens
WHERE code = :code;
diff --git a/relay/database/config.py b/relay/database/config.py
index 2be3ecc..3f3c7e0 100644
--- a/relay/database/config.py
+++ b/relay/database/config.py
@@ -11,11 +11,7 @@ from .. import logger as logging
from ..misc import boolean
if TYPE_CHECKING:
- try:
- from typing import Self
-
- except ImportError:
- from typing_extensions import Self
+ from typing import Self
THEMES = {
@@ -77,7 +73,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
@dataclass()
class ConfigData:
- schema_version: int = 20240310
+ schema_version: int = 20240625
private_key: str = ''
approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO
@@ -115,11 +111,11 @@ class ConfigData:
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
- return cls.FIELD(key.replace('-', '_')).default # type: ignore
+ return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value]
@classmethod
- def FIELD(cls: type[Self], key: str) -> Field[Any]:
+ def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
for field in fields(cls):
if field.name == key.replace('-', '_'):
return field
diff --git a/relay/database/connection.py b/relay/database/connection.py
index 006a907..3c973b8 100644
--- a/relay/database/connection.py
+++ b/relay/database/connection.py
@@ -1,6 +1,9 @@
from __future__ import annotations
+import secrets
+
from argon2 import PasswordHasher
+from blib import Date
from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator
from datetime import datetime, timezone
@@ -49,6 +52,40 @@ class Connection(SqlConnection):
yield instance
+ def fix_timestamps(self) -> None:
+ for app in self.select('apps').all(schema.App):
+ data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
+ self.update('apps', data, client_id = app.client_id)
+
+ for item in self.select('cache'):
+ data = {'updated': Date.parse(item['updated']).timestamp()}
+ self.update('cache', data, id = item['id'])
+
+ for dban in self.select('domain_bans').all(schema.DomainBan):
+ data = {'created': dban.created.timestamp()}
+ self.update('domain_bans', data, domain = dban.domain)
+
+ for instance in self.select('inboxes').all(schema.Instance):
+ data = {'created': instance.created.timestamp()}
+ self.update('inboxes', data, domain = instance.domain)
+
+ for sban in self.select('software_bans').all(schema.SoftwareBan):
+ data = {'created': sban.created.timestamp()}
+ self.update('software_bans', data, name = sban.name)
+
+ for token in self.select('tokens').all(schema.Token):
+ data = {'created': token.created.timestamp(), 'accessed': token.accessed.timestamp()}
+ self.update('tokens', data, code = token.code)
+
+ for user in self.select('users').all(schema.User):
+ data = {'created': user.created.timestamp()}
+ self.update('users', data, username = user.username)
+
+ for wlist in self.select('whitelist').all(schema.Whitelist):
+ data = {'created': wlist.created.timestamp()}
+ self.update('whitelist', data, domain = wlist.domain)
+
+
def get_config(self, key: str) -> Any:
key = key.replace('_', '-')
@@ -198,6 +235,11 @@ class Connection(SqlConnection):
return cur.one(schema.User)
+ def get_user_by_app_token(self, code: str) -> schema.User | None:
+ with self.run('get-user-by-app-token', {'code': code}) as cur:
+ return cur.one(schema.User)
+
+
def get_users(self) -> Iterator[schema.User]:
return self.execute("SELECT * FROM users").all(schema.User)
@@ -249,13 +291,102 @@ class Connection(SqlConnection):
pass
+ def get_app(self,
+ client_id: str,
+ client_secret: str,
+ token: str | None = None) -> schema.App | None:
+
+ params = {
+ 'id': client_id,
+ 'secret': client_secret
+ }
+
+ if token is not None:
+ command = 'get-app-with-token'
+ params['token'] = token
+
+ else:
+ command = 'get-app'
+
+ with self.run(command, params) as cur:
+ return cur.one(schema.App)
+
+
+ def get_app_by_token(self, token: str) -> schema.App | None:
+ with self.run('get-app-by-token', {'token': token}) as cur:
+ return cur.one(schema.App)
+
+
+ def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
+ params = {
+ 'name': name,
+ 'redirect_uri': redirect_uri,
+ 'website': website,
+ 'client_id': secrets.token_hex(20),
+ 'client_secret': secrets.token_hex(20),
+ 'created': Date.new_utc().timestamp(),
+ 'accessed': Date.new_utc().timestamp()
+ }
+
+ with self.insert('app', params) as cur:
+ if (row := cur.one(schema.App)) is None:
+ raise RuntimeError(f'Failed to insert app: {name}')
+
+ return row
+
+
+ def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App:
+ data: dict[str, str | None] = {}
+
+ if user is not None:
+ data['user'] = user.username
+
+ if set_auth:
+ data['auth_code'] = secrets.token_hex(20)
+
+ else:
+ data['token'] = secrets.token_hex(20)
+ data['auth_code'] = None
+
+ params = {
+ 'client_id': app.client_id,
+ 'client_secret': app.client_secret
+ }
+
+ with self.update('app', data, **params) as cur: # type: ignore[arg-type]
+ if (row := cur.one(schema.App)) is None:
+ raise RuntimeError('Failed to update row')
+
+ return row
+
+
+ def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
+ params = {
+ 'id': client_id,
+ 'secret': client_secret
+ }
+
+ if token is not None:
+ command = 'del-app-token'
+ params['token'] = token
+
+ else:
+ command = 'del-app'
+
+ with self.run(command, params) as cur:
+ if cur.row_count > 1:
+ raise RuntimeError('More than 1 row was deleted')
+
+ return cur.row_count == 0
+
+
def get_token(self, code: str) -> schema.Token | None:
with self.run('get-token', {'code': code}) as cur:
return cur.one(schema.Token)
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
- if username is not None:
+ if username is None:
return self.select('tokens').all(schema.Token)
return self.select('tokens', username = username).all(schema.Token)
diff --git a/relay/database/schema.py b/relay/database/schema.py
index 1fd7003..660e527 100644
--- a/relay/database/schema.py
+++ b/relay/database/schema.py
@@ -1,14 +1,14 @@
from __future__ import annotations
-import typing
-
+from blib import Date
from bsql import Column, Row, Tables
from collections.abc import Callable
-from datetime import datetime
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any
from .config import ConfigData
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
from .connection import Connection
@@ -16,6 +16,16 @@ VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES = Tables()
+def deserialize_timestamp(value: Any) -> Date:
+ try:
+ return Date.parse(value)
+
+ except ValueError:
+ pass
+
+ return Date.fromisoformat(value)
+
+
@TABLES.add_row
class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
@@ -27,62 +37,125 @@ class Config(Row):
class Instance(Row):
table_name: str = 'inboxes'
+
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
- accepted: Column[datetime] = Column('accepted', 'boolean')
- created: Column[datetime] = Column('created', 'timestamp', nullable = False)
+ accepted: Column[Date] = Column('accepted', 'boolean')
+ created: Column[Date] = Column(
+ 'created', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
- created: Column[datetime] = Column('created', 'timestamp')
+ created: Column[Date] = Column(
+ 'created', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
@TABLES.add_row
class DomainBan(Row):
table_name: str = 'domain_bans'
+
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
- created: Column[datetime] = Column('created', 'timestamp')
+ created: Column[Date] = Column(
+ 'created', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
@TABLES.add_row
class SoftwareBan(Row):
table_name: str = 'software_bans'
+
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
- created: Column[datetime] = Column('created', 'timestamp')
+ created: Column[Date] = Column(
+ 'created', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
@TABLES.add_row
class User(Row):
table_name: str = 'users'
+
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
- created: Column[datetime] = Column('created', 'timestamp')
+ created: Column[Date] = Column(
+ 'created', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
@TABLES.add_row
class Token(Row):
table_name: str = 'tokens'
+
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
user: Column[str] = Column('user', 'text', nullable = False)
- created: Column[datetime] = Column('created', 'timestamp')
+ created: Column[Date] = Column(
+ 'created', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
+ accessed: Column[Date] = Column(
+ 'accessed', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
+
+
+@TABLES.add_row
+class App(Row):
+ table_name: str = 'apps'
+
+
+ client_id: Column[str] = Column(
+ 'client_id', 'text', primary_key = True, unique = True, nullable = False)
+ client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
+ name: Column[str] = Column('name', 'text')
+ website: Column[str] = Column('website', 'text')
+ redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
+ token: Column[str | None] = Column('token', 'text')
+ auth_code: Column[str | None] = Column('auth_code', 'text')
+ user: Column[str | None] = Column('user', 'text')
+ created: Column[Date] = Column(
+ 'created', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
+ accessed: Column[Date] = Column(
+ 'accessed', 'timestamp', nullable = False,
+ deserializer = deserialize_timestamp, serializer = Date.timestamp
+ )
+
+
+ def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
+ data = deepcopy(self)
+ data.pop('auth_code')
+ data.pop('created')
+ data.pop('accessed')
+
+ if not include_token:
+ data.pop('token')
+
+ return data
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
@@ -103,5 +176,15 @@ def migrate_20240206(conn: Connection) -> None:
@migration
def migrate_20240310(conn: Connection) -> None:
- conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN")
- conn.execute("UPDATE inboxes SET accepted = 1")
+ conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN')
+ conn.execute('UPDATE "inboxes" SET accepted = 1')
+
+
+@migration
+def migrate_20240625(conn: Connection) -> None:
+ conn.execute('ALTER TABLE "tokens" ADD "accessed" timestamp')
+
+ for token in conn.get_tokens():
+ conn.update('tokens', {'accessed': token.created}, code = token.code).one()
+
+ conn.create_tables()
diff --git a/relay/frontend/base.haml b/relay/frontend/base.haml
index d3d8bb6..dd1e3e2 100644
--- a/relay/frontend/base.haml
+++ b/relay/frontend/base.haml
@@ -1,5 +1,5 @@
-macro menu_item(name, path)
- -if view.request.path == path or (path != "/" and view.request.path.startswith(path))
+ -if request.path == path or (path != "/" and request.path.startswith(path))
%a.button(href="{{path}}" active="true") -> =name
-else
@@ -10,12 +10,12 @@
%head
%title << {{config.name}}: {{page}}
%meta(charset="UTF-8")
- %meta(name="viewport" content="width=device-width, initial-scale=1")
- %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme")
- %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}")
- %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}")
+ %meta(name="ort" content="width=device-width, initial-scale=1")
+ %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme")
+ %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}")
+ %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}")
%link(rel="manifest" href="/manifest.json?{{version}}")
- %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer)
+ %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer)
-block head
%body
@@ -26,7 +26,7 @@
{{menu_item("Home", "/")}}
- -if view.request["user"]
+ -if request["user"]
{{menu_item("Instances", "/admin/instances")}}
{{menu_item("Whitelist", "/admin/whitelist")}}
{{menu_item("Domain Bans", "/admin/domain_bans")}}
@@ -61,11 +61,11 @@
#footer.section
.col1
- -if not view.request["user"]
+ -if not request["user"]
%a(href="/login") << Login
-else
- =view.request["user"]["username"]
+ =request["user"]["username"]
(
%a(href="/logout") << Logout
)
diff --git a/relay/frontend/page/authorize_new.haml b/relay/frontend/page/authorize_new.haml
new file mode 100644
index 0000000..4f07df3
--- /dev/null
+++ b/relay/frontend/page/authorize_new.haml
@@ -0,0 +1,31 @@
+-extends "base.haml"
+-set page="App Authorization"
+
+-block content
+ %fieldset.section
+ %legend << App Authorization
+
+ -if application.website
+ #title << Application "{{application.name}}" wants full API access
+
+ -else
+ #title << Application "{{application.name}}" wants full API access
+
+ #buttons
+ .spacer
+
+ %form(action="/oauth/authorize" method="POST")
+ %input(type="hidden" name="client_id" value="{{application.client_id}}")
+ %input(type="hidden" name="client_secret" value="{{application.client_secret}}")
+ %input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
+ %input(type="hidden" name="response" value="true")
+ %input.button(type="submit" value="Allow")
+
+ %form(action="/oauth/authorize" method="POST")
+ %input(type="hidden" name="client_id" value="{{application.client_id}}")
+ %input(type="hidden" name="client_secret" value="{{application.client_secret}}")
+ %input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
+ %input(type="hidden" name="response" value="false")
+ %input.button(type="submit" value="Deny")
+
+ .spacer
diff --git a/relay/frontend/page/authorize_show.haml b/relay/frontend/page/authorize_show.haml
new file mode 100644
index 0000000..19cde40
--- /dev/null
+++ b/relay/frontend/page/authorize_show.haml
@@ -0,0 +1,18 @@
+-extends "base.haml"
+-set page="App Authorization"
+
+-block content
+ %fieldset.section
+ %legend << App Authorization Code
+
+ -if application.website
+ %p
+ Copy the following code into
+ %a(href="{{application.website}}" target="_main") -> %code -> =application.name
+
+ -else
+ %p
+ Copy the following code info
+ %code -> =application.name
+
+ %pre#code -> =application.auth_code
diff --git a/relay/frontend/page/error.haml b/relay/frontend/page/error.haml
new file mode 100644
index 0000000..4d4bf95
--- /dev/null
+++ b/relay/frontend/page/error.haml
@@ -0,0 +1,7 @@
+-extends "base.haml"
+-set page="Error"
+
+-block content
+ .section.error
+ .title << HTTP Error {{e.status}}
+ .body -> =e.body
diff --git a/relay/frontend/page/login.haml b/relay/frontend/page/login.haml
index c32160f..4f29746 100644
--- a/relay/frontend/page/login.haml
+++ b/relay/frontend/page/login.haml
@@ -12,4 +12,6 @@
%label(for="password") << Password
%input(id="password" name="password" placeholder="Password" type="password")
+
+ %input#redir(type="hidden" name="redir" value="{{redir}}")
%input.submit(type="button" value="Login")
diff --git a/relay/frontend/static/functions.js b/relay/frontend/static/functions.js
index b0e4db5..3063223 100644
--- a/relay/frontend/static/functions.js
+++ b/relay/frontend/static/functions.js
@@ -483,13 +483,15 @@ function page_instance() {
function page_login() {
const fields = {
username: document.querySelector("#username"),
- password: document.querySelector("#password")
- }
+ password: document.querySelector("#password"),
+ redir: document.querySelector("#redir")
+ };
async function login(event) {
const values = {
username: fields.username.value.trim(),
- password: fields.password.value.trim()
+ password: fields.password.value.trim(),
+ redir: fields.redir.value.trim()
}
if (values.username === "" | values.password === "") {
@@ -498,14 +500,14 @@ function page_login() {
}
try {
- await request("POST", "v1/token", values);
+ await request("POST", "v1/login", values);
} catch (error) {
toast(error);
return;
}
- document.location = "/";
+ document.location = values.redir;
}
@@ -848,9 +850,6 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/instances")) {
page_instance();
-} else if (location.pathname.startsWith("/admin/login")) {
- page_login();
-
} else if (location.pathname.startsWith("/admin/software_bans")) {
page_software_ban();
@@ -859,4 +858,7 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/whitelist")) {
page_whitelist();
+
+} else if (location.pathname.startsWith("/login")) {
+ page_login();
}
diff --git a/relay/frontend/static/style.css b/relay/frontend/static/style.css
index f0d72f5..c9bcd43 100644
--- a/relay/frontend/static/style.css
+++ b/relay/frontend/static/style.css
@@ -338,6 +338,44 @@ textarea {
}
+/* error */
+#content.page-error {
+ text-align: center;
+}
+
+#content.page-error .title {
+ font-size: 24px;
+ font-weight: bold;
+}
+
+
+/* auth */
+#content.page-app_authorization {
+ text-align: center;
+}
+
+#content.page-app_authorization #code {
+ background: var(--background);
+ border: 1px solid var(--border);
+ font-size: 18px;
+ margin: 0 auto;
+ width: max-content;
+ padding: 5px;
+}
+
+#content.page-app_authorization #title {
+ font-size: 24px;
+}
+
+#content.page-app_authorization #buttons {
+ display: grid;
+ grid-template-columns: auto max-content max-content auto;
+ grid-gap: var(--spacing);
+ justify-items: center;
+ margin: var(--spacing) 0;
+}
+
+
@keyframes show_toast {
0% {
transform: translateX(100%);
diff --git a/relay/manage.py b/relay/manage.py
index 81f546e..5ae8238 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -212,6 +212,21 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
os._exit(0)
+@cli.command('db-maintenance')
+@click.option('--fix-timestamps', '-t', is_flag = True,
+ help = 'Make sure timestamps in the database are float values')
+@click.pass_context
+def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
+ 'Perform maintenance tasks on the database'
+
+ if fix_timestamps:
+ with ctx.obj.database.session(True) as conn:
+ conn.fix_timestamps()
+
+ with ctx.obj.database.session(False) as conn:
+ with conn.execute("VACUUM"):
+ pass
+
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@@ -239,18 +254,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save()
+ # fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7
with get_database(ctx.obj.config) as db:
with db.session(True) as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
- with click.progressbar( # type: ignore
+ with click.progressbar(
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
-
for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay'
@@ -269,7 +284,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software']
)
- with click.progressbar( # type: ignore
+ with click.progressbar(
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
@@ -281,7 +296,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None
)
- with click.progressbar( # type: ignore
+ with click.progressbar(
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
@@ -290,7 +305,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software:
conn.put_domain_ban(domain)
- with click.progressbar( # type: ignore
+ with click.progressbar(
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
diff --git a/relay/misc.py b/relay/misc.py
index 6995bc4..b27c89a 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -62,6 +62,28 @@ SOFTWARE = (
'gotosocial'
)
+JSON_PATHS: tuple[str, ...] = (
+ '/api/v1',
+ '/actor',
+ '/inbox',
+ '/outbox',
+ '/following',
+ '/followers',
+ '/.well-known',
+ '/nodeinfo',
+ '/oauth/token',
+ '/oauth/revoke'
+)
+
+TOKEN_PATHS: tuple[str, ...] = (
+ '/api',
+ '/login',
+ '/logout',
+ '/oauth/authorize',
+ '/oauth/revoke',
+ '/admin'
+)
+
def boolean(value: Any) -> bool:
if isinstance(value, str):
@@ -113,6 +135,17 @@ def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path)
+class HttpError(Exception):
+ def __init__(self,
+ status: int,
+ body: str) -> None:
+
+ self.body: str = body
+ self.status: int = status
+
+ Exception.__init__(self, f"HTTP Error {status}: {body}")
+
+
class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str:
if isinstance(o, datetime):
@@ -242,9 +275,9 @@ class Response(AiohttpResponse):
@classmethod
- def new_redir(cls: type[Self], path: str) -> Self:
+ def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to {path}'
- return cls.new(body, 302, {'Location': path})
+ return cls.new(body, status, {'Location': path}, ctype = 'html')
@property
diff --git a/relay/template.py b/relay/template.py
index 7e3f657..3ee2855 100644
--- a/relay/template.py
+++ b/relay/template.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import textwrap
+from aiohttp.web import Request
from collections.abc import Callable
from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader
@@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any
from . import __version__
from .misc import get_resource
-from .views.base import View
if TYPE_CHECKING:
from .application import Application
@@ -43,12 +43,12 @@ class Template(Environment):
self.hamlish_mode = 'indented'
- def render(self, path: str, view: View | None = None, **context: Any) -> str:
+ def render(self, path: str, request: Request, **context: Any) -> str:
with self.app.database.session(False) as conn:
config = conn.get_config_all()
new_context = {
- 'view': view,
+ 'request': request,
'domain': self.app.config.domain,
'version': __version__,
'config': config,
diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py
index 74b01c6..aa672f2 100644
--- a/relay/views/activitypub.py
+++ b/relay/views/activitypub.py
@@ -7,7 +7,7 @@ from .base import View, register_route
from .. import logger as logging
from ..database import schema
-from ..misc import Message, Response
+from ..misc import HttpError, Message, Response
from ..processors import run_processor
@@ -39,8 +39,7 @@ class ActorView(View):
async def post(self, request: Request) -> Response:
- if response := await self.get_post_data():
- return response
+ await self.get_post_data()
with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
@@ -65,13 +64,13 @@ class ActorView(View):
return Response.new(status = 202)
- async def get_post_data(self) -> Response | None:
+ async def get_post_data(self) -> None:
try:
self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
- return Response.new_error(400, 'missing signature header', 'json')
+ raise HttpError(400, 'missing signature header')
try:
message: Message | None = await self.request.json(loads = Message.parse)
@@ -79,17 +78,17 @@ class ActorView(View):
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
- return Response.new_error(400, 'failed to parse message', 'json')
+ raise HttpError(400, 'failed to parse message')
if message is None:
logging.verbose('empty message')
- return Response.new_error(400, 'missing message', 'json')
+ raise HttpError(400, 'missing message')
self.message = message
if 'actor' not in self.message:
logging.verbose('actor not in message')
- return Response.new_error(400, 'no actor in message', 'json')
+ raise HttpError(400, 'no actor in message')
try:
self.actor = await self.client.get(self.signature.keyid, True, Message)
@@ -98,26 +97,24 @@ class ActorView(View):
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
- return Response.new(status=202)
+ raise HttpError(202, '')
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
- return Response.new_error(400, 'failed to fetch actor', 'json')
+ raise HttpError(400, 'failed to fetch actor')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
- return Response.new_error(400, 'actor missing public key', 'json')
+ raise HttpError(400, 'actor missing public key')
try:
await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
- return Response.new_error(401, str(e), 'json')
-
- return None
+ raise HttpError(401, str(e))
@register_route('/outbox')
diff --git a/relay/views/api.py b/relay/views/api.py
index 73b6a16..76cd1e5 100644
--- a/relay/views/api.py
+++ b/relay/views/api.py
@@ -1,16 +1,17 @@
+import secrets
import traceback
from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError
+from blib import convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence
-from typing import Any
from urllib.parse import urlparse
from .base import View, register_route
from .. import __version__
-from ..database import ConfigData
-from ..misc import Message, Response, boolean, get_app
+from ..database import ConfigData, schema
+from ..misc import HttpError, Message, Response, boolean
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@@ -22,6 +23,8 @@ ALLOWED_HEADERS: set[str] = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'),
+ ('POST', '/api/v1/app'),
+ ('POST', '/api/v1/login'),
('POST', '/api/v1/token')
)
@@ -37,57 +40,174 @@ def check_api_path(method: str, path: str) -> bool:
async def handle_api_path(
request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response:
- try:
- if (token := request.cookies.get('user-token')):
- request['token'] = token
- else:
- request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
-
- with get_app().database.session() as conn:
- request['user'] = conn.get_user_by_token(request['token'])
-
- except (KeyError, ValueError):
- request['token'] = None
- request['user'] = None
+ if not request.path.startswith('/api'):
+ return await handler(request)
if request.method != "OPTIONS" and check_api_path(request.method, request.path):
- if not request['token']:
- return Response.new_error(401, 'Missing token', 'json')
+ if request['token'] is None:
+ raise HttpError(401, 'Missing token')
- if not request['user']:
- return Response.new_error(401, 'Invalid token', 'json')
+ if request['user'] is None:
+ raise HttpError(401, 'Invalid token')
response = await handler(request)
-
- if request.path.startswith('/api'):
- response.headers['Access-Control-Allow-Origin'] = '*'
- response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
+ response.headers['Access-Control-Allow-Origin'] = '*'
+ response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
return response
-@register_route('/api/v1/token')
-class Login(View):
+@register_route('/oauth/authorize')
+class OauthAuthorize(View):
async def get(self, request: Request) -> Response:
- return Response.new({'message': 'Token valid'}, ctype = 'json')
+ data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], [])
+
+ if data['response_type'] != 'code':
+ raise HttpError(400, 'Response type is not "code"')
+
+ with self.database.session(True) as conn:
+ with conn.select('app', client_id = data['client_id']) as cur:
+ if (app := cur.one(schema.App)) is None:
+ raise HttpError(404, 'Could not find app')
+
+ if app.token is not None or app.auth_code is not None:
+ context = {'application': app}
+ html = self.template.render(
+ 'page/authorize_show.haml', self.request, **context
+ )
+
+ return Response.new(html, ctype = 'html')
+
+ if data['redirect_uri'] != app.redirect_uri:
+ raise HttpError(400, 'redirect_uri does not match application')
+
+ context = {'application': app}
+ html = self.template.render('page/authorize_new.haml', self.request, **context)
+ return Response.new(html, ctype = 'html')
async def post(self, request: Request) -> Response:
- data = await self.get_api_data(['username', 'password'], [])
+ data = await self.get_api_data(
+ ['client_id', 'client_secret', 'redirect_uri', 'response'], []
+ )
- if isinstance(data, Response):
- return data
+ with self.database.session(True) as conn:
+ if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
+ return Response.new_error(404, 'Could not find app', 'json')
+
+ if convert_to_boolean(data['response']):
+ if app.auth_code is None:
+ app = conn.update_app(app, request['user'], True)
+
+ if app.redirect_uri == DEFAULT_REDIRECT:
+ context = {'application': app}
+ html = self.template.render(
+ 'page/authorize_show.haml', self.request, **context
+ )
+
+ return Response.new(html, ctype = 'html')
+
+ return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}')
+
+ if not conn.del_app(app.client_id, app.client_secret):
+ raise HttpError(404, 'App not found')
+
+ return Response.new_redir('/')
+
+
+@register_route('/oauth/token')
+class OauthToken(View):
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(
+ ['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], []
+ )
+
+ if data['grant_type'] != 'authorization_code':
+ raise HttpError(400, 'Invalid grant type')
+
+ with self.database.session(True) as conn:
+ if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
+ raise HttpError(404, 'Application not found')
+
+ if app.auth_code != data['code']:
+ raise HttpError(400, 'Invalid authentication code')
+
+ if app.redirect_uri != data['redirect_uri']:
+ raise HttpError(400, 'Invalid redirect uri')
+
+ app = conn.update_app(app, request['user'], False)
+
+ return Response.new(app.get_api_data(True), ctype = 'json')
+
+
+@register_route('/oauth/revoke')
+class OauthRevoke(View):
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(['client_id', 'client_secret', 'token'], [])
+
+ with self.database.session(True) as conn:
+ if (app := conn.get_app(**data)) is None:
+ raise HttpError(404, 'Could not find token')
+
+ if app.user != request['token'].username:
+ raise HttpError(403, 'Invalid token')
+
+ if not conn.del_app(**data):
+ raise HttpError(400, 'Token not removed')
+
+ return Response.new({'msg': 'Token deleted'}, ctype = 'json')
+
+
+@register_route('/api/v1/app')
+class App(View):
+ async def get(self, request: Request) -> Response:
+ data = await self.get_api_data(['client_id', 'client_secret'], [])
+
+ with self.database.session(False) as conn:
+ if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
+ raise HttpError(404, 'Application cannot be found')
+
+ return Response.new(app.get_api_data(), ctype = 'json')
+
+
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(['name', 'redirect_uri'], ['website'])
+
+ with self.database.session(True) as conn:
+ app = conn.put_app(
+ name = data['name'],
+ redirect_uri = data['redirect_uri'],
+ website = data.get('website')
+ )
+
+ return Response.new(app.get_api_data(), ctype = 'json')
+
+
+ async def delete(self, request: Request) -> Response:
+ data = await self.get_api_data(['client_id', 'client_secret'], [])
+
+ with self.database.session(True) as conn:
+ if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code):
+ raise HttpError(400, 'Token not removed')
+
+ return Response.new({'msg': 'Token deleted'}, ctype = 'json')
+
+
+@register_route('/api/v1/login')
+class Login(View):
+ async def post(self, request: Request) -> Response:
+ data = await self.get_api_data(['username', 'password'], [])
with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])):
- return Response.new_error(401, 'User not found', 'json')
+ raise HttpError(401, 'User not found')
try:
conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError:
- return Response.new_error(401, 'Invalid password', 'json')
+ raise HttpError(401, 'Invalid password')
token = conn.put_token(data['username'])
@@ -106,11 +226,36 @@ class Login(View):
return resp
- async def delete(self, request: Request) -> Response:
- with self.database.session() as conn:
- conn.del_token(request['token'])
- return Response.new({'message': 'Token revoked'}, ctype = 'json')
+ async def post2(self, request: Request) -> Response:
+ data = await self.get_api_data(['username', 'password'], [])
+
+ with self.database.session(True) as conn:
+ if not (user := conn.get_user(data['username'])):
+ raise HttpError(401, 'User not found')
+
+ try:
+ conn.hasher.verify(user['hash'], data['password'])
+
+ except VerifyMismatchError:
+ raise HttpError(401, 'Invalid password')
+
+ app = conn.put_app(
+ data['app_name'],
+ DEFAULT_REDIRECT,
+ data.get('website')
+ )
+
+ params = {
+ 'code': secrets.token_hex(20),
+ 'user': user.username
+ }
+
+ with conn.update('app', params, client_id = app.client_id) as cur:
+ if (row := cur.one(schema.App)) is None:
+ raise HttpError(500, 'Failed to create app')
+
+ return Response.new(row.get_api_data(True), ctype = 'json')
@register_route('/api/v1/relay')
@@ -155,14 +300,10 @@ class Config(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['key', 'value'], [])
-
- if isinstance(data, Response):
- return data
-
data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS():
- return Response.new_error(400, 'Invalid key', 'json')
+ raise HttpError(400, 'Invalid key')
with self.database.session() as conn:
conn.put_config(data['key'], data['value'])
@@ -173,11 +314,8 @@ class Config(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['key'], [])
- if isinstance(data, Response):
- return data
-
if data['key'] not in ConfigData.USER_KEYS():
- return Response.new_error(400, 'Invalid key', 'json')
+ raise HttpError(400, 'Invalid key')
with self.database.session() as conn:
conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
@@ -196,15 +334,11 @@ class Inbox(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn:
if conn.get_inbox(data['domain']) is not None:
- return Response.new_error(404, 'Instance already in database', 'json')
+ raise HttpError(404, 'Instance already in database')
data['domain'] = data['domain'].encode('idna').decode()
@@ -214,7 +348,7 @@ class Inbox(View):
except Exception:
traceback.print_exc()
- return Response.new_error(500, 'Failed to fetch actor', 'json')
+ raise HttpError(500, 'Failed to fetch actor')
data['inbox'] = actor_data.shared_inbox
@@ -240,14 +374,10 @@ class Inbox(View):
async def patch(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = data['domain'].encode('idna').decode()
if (instance := conn.get_inbox(data['domain'])) is None:
- return Response.new_error(404, 'Instance with domain not found', 'json')
+ raise HttpError(404, 'Instance with domain not found')
instance = conn.put_inbox(
instance.domain,
@@ -262,14 +392,10 @@ class Inbox(View):
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']):
- return Response.new_error(404, 'Instance with domain not found', 'json')
+ raise HttpError(404, 'Instance with domain not found')
conn.del_inbox(data['domain'])
@@ -286,26 +412,21 @@ class RequestView(View):
async def post(self, request: Request) -> Response:
- data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
-
- if isinstance(data, Response):
- return data
-
- data['accept'] = boolean(data['accept'])
+ data = await self.get_api_data(['domain', 'accept'], [])
data['domain'] = data['domain'].encode('idna').decode()
try:
with self.database.session(True) as conn:
- instance = conn.put_request_response(data['domain'], data['accept'])
+ instance = conn.put_request_response(data['domain'], boolean(data['accept']))
except KeyError:
- return Response.new_error(404, 'Request not found', 'json')
+ raise HttpError(404, 'Request not found')
message = Message.new_response(
host = self.config.domain,
actor = instance.actor,
followid = instance.followid,
- accept = data['accept']
+ accept = boolean(data['accept'])
)
self.app.push_message(instance.inbox, message, instance)
@@ -333,15 +454,11 @@ class DomainBan(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], ['note', 'reason'])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_ban(data['domain']) is not None:
- return Response.new_error(400, 'Domain already banned', 'json')
+ raise HttpError(400, 'Domain already banned')
ban = conn.put_domain_ban(
domain = data['domain'],
@@ -356,16 +473,13 @@ class DomainBan(View):
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason'])
- if isinstance(data, Response):
- return data
-
if not any([data.get('note'), data.get('reason')]):
- return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
+ raise HttpError(400, 'Must include note and/or reason parameters')
data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None:
- return Response.new_error(404, 'Domain not banned', 'json')
+ raise HttpError(404, 'Domain not banned')
ban = conn.update_domain_ban(
domain = data['domain'],
@@ -379,14 +493,10 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
-
- if isinstance(data, Response):
- return data
-
data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None:
- return Response.new_error(404, 'Domain not banned', 'json')
+ raise HttpError(404, 'Domain not banned')
conn.del_domain_ban(data['domain'])
@@ -405,12 +515,9 @@ class SoftwareBan(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason'])
- if isinstance(data, Response):
- return data
-
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is not None:
- return Response.new_error(400, 'Domain already banned', 'json')
+ raise HttpError(400, 'Domain already banned')
ban = conn.put_software_ban(
name = data['name'],
@@ -424,15 +531,12 @@ class SoftwareBan(View):
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason'])
- if isinstance(data, Response):
- return data
-
if not any([data.get('note'), data.get('reason')]):
- return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
+ raise HttpError(400, 'Must include note and/or reason parameters')
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None:
- return Response.new_error(404, 'Software not banned', 'json')
+ raise HttpError(404, 'Software not banned')
ban = conn.update_software_ban(
name = data['name'],
@@ -446,12 +550,9 @@ class SoftwareBan(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['name'], [])
- if isinstance(data, Response):
- return data
-
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None:
- return Response.new_error(404, 'Software not banned', 'json')
+ raise HttpError(404, 'Software not banned')
conn.del_software_ban(data['name'])
@@ -474,12 +575,9 @@ class User(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], ['handle'])
- if isinstance(data, Response):
- return data
-
with self.database.session() as conn:
if conn.get_user(data['username']) is not None:
- return Response.new_error(404, 'User already exists', 'json')
+ raise HttpError(404, 'User already exists')
user = conn.put_user(
username = data['username'],
@@ -494,9 +592,6 @@ class User(View):
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', 'handle'])
- if isinstance(data, Response):
- return data
-
with self.database.session(True) as conn:
user = conn.put_user(
username = data['username'],
@@ -511,12 +606,9 @@ class User(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['username'], [])
- if isinstance(data, Response):
- return data
-
with self.database.session(True) as conn:
if conn.get_user(data['username']) is None:
- return Response.new_error(404, 'User does not exist', 'json')
+ raise HttpError(404, 'User does not exist')
conn.del_user(data['username'])
@@ -535,14 +627,11 @@ class Whitelist(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], [])
- if isinstance(data, Response):
- return data
-
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is not None:
- return Response.new_error(400, 'Domain already added to whitelist', 'json')
+ raise HttpError(400, 'Domain already added to whitelist')
item = conn.put_domain_whitelist(domain)
@@ -552,14 +641,11 @@ class Whitelist(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], [])
- if isinstance(data, Response):
- return data
-
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is None:
- return Response.new_error(404, 'Domain not in whitelist', 'json')
+ raise HttpError(404, 'Domain not in whitelist')
conn.del_domain_whitelist(domain)
diff --git a/relay/views/base.py b/relay/views/base.py
index e102896..1b2d405 100644
--- a/relay/views/base.py
+++ b/relay/views/base.py
@@ -1,10 +1,8 @@
from __future__ import annotations
-from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
-from aiohttp.web import HTTPMethodNotAllowed, Request
-from base64 import b64encode
+from aiohttp.web import Request
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property
@@ -15,7 +13,7 @@ from ..cache import Cache
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
-from ..misc import Response, get_app
+from ..misc import HttpError, Response, get_app
if TYPE_CHECKING:
from typing import Self
@@ -43,10 +41,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS:
- raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
+ raise HttpError(405, f'"{self.request.method}" method not allowed')
if not (handler := self.handlers.get(self.request.method)):
- raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
+ raise HttpError(405, f'"{self.request.method}" method not allowed')
return self._run_handler(handler).__await__()
@@ -58,7 +56,6 @@ class View(AbstractView):
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
- self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs)
@@ -117,17 +114,18 @@ class View(AbstractView):
async def get_api_data(self,
required: list[str],
- optional: list[str]) -> dict[str, str] | Response:
+ optional: list[str]) -> dict[str, str]:
- if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
+ if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post())
+ # post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json':
try:
post_data = convert_data(await self.request.json())
except JSONDecodeError:
- return Response.new_error(400, 'Invalid JSON data', 'json')
+ raise HttpError(400, 'Invalid JSON data')
else:
post_data = convert_data(self.request.query)
@@ -139,9 +137,9 @@ class View(AbstractView):
data[key] = post_data[key]
except KeyError as e:
- return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
+ raise HttpError(400, f'Missing {str(e)} pararmeter')
for key in optional:
- data[key] = post_data.get(key, '')
+ data[key] = post_data.get(key) # type: ignore[assignment]
return data
diff --git a/relay/views/frontend.py b/relay/views/frontend.py
index cf6b338..a383d20 100644
--- a/relay/views/frontend.py
+++ b/relay/views/frontend.py
@@ -1,18 +1,13 @@
from aiohttp import web
from collections.abc import Awaitable, Callable
from typing import Any
+from urllib.parse import unquote
from .base import View, register_route
from ..database import THEMES
from ..logger import LogLevel
-from ..misc import Response, get_app
-
-
-UNAUTH_ROUTES = {
- '/',
- '/login'
-}
+from ..misc import TOKEN_PATHS, Response
@web.middleware
@@ -20,28 +15,25 @@ async def handle_frontend_path(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
- app = get_app()
+ if request['user'] is not None and request.path == '/login':
+ return Response.new_redir('/')
- if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
- request['token'] = request.cookies.get('user-token')
- request['user'] = None
+ if request.path.startswith(TOKEN_PATHS) and request['user'] is None:
+ if request.path == '/logout':
+ return Response.new_redir('/')
- if request['token']:
- with app.database.session(False) as conn:
- request['user'] = conn.get_user_by_token(request['token'])
+ response = Response.new_redir(f'/login?redir={request.path}')
- if request['user'] and request.path == '/login':
- return Response.new('', 302, {'Location': '/'})
-
- if not request['user'] and request.path.startswith('/admin'):
- response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
+ if request['token'] is not None:
response.del_cookie('user-token')
- return response
+
+ return response
response = await handler(request)
- if not request.path.startswith('/api') and not request['user'] and request['token']:
- response.del_cookie('user-token')
+ if not request.path.startswith('/api'):
+ if request['user'] is None and request['token'] is not None:
+ response.del_cookie('user-token')
return response
@@ -54,14 +46,15 @@ class HomeView(View):
'instances': tuple(conn.get_inboxes())
}
- data = self.template.render('page/home.haml', self, **context)
+ data = self.template.render('page/home.haml', self.request, **context)
return Response.new(data, ctype='html')
@register_route('/login')
class Login(View):
async def get(self, request: web.Request) -> Response:
- data = self.template.render('page/login.haml', self)
+ redir = unquote(request.query.get('redir', '/'))
+ data = self.template.render('page/login.haml', self.request, redir = redir)
return Response.new(data, ctype = 'html')
@@ -69,7 +62,7 @@ class Login(View):
class Logout(View):
async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn:
- conn.del_token(request['token'])
+ conn.del_token(request['token'].code)
resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = self.config.domain, path = '/')
@@ -79,7 +72,7 @@ class Logout(View):
@register_route('/admin')
class Admin(View):
async def get(self, request: web.Request) -> Response:
- return Response.new('', 302, {'Location': '/admin/instances'})
+ return Response.new_redir(f'/login?redir={request.path}', 301)
@register_route('/admin/instances')
@@ -101,7 +94,7 @@ class AdminInstances(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-instances.haml', self, **context)
+ data = self.template.render('page/admin-instances.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -123,7 +116,7 @@ class AdminWhitelist(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-whitelist.haml', self, **context)
+ data = self.template.render('page/admin-whitelist.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -145,7 +138,7 @@ class AdminDomainBans(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-domain_bans.haml', self, **context)
+ data = self.template.render('page/admin-domain_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -167,7 +160,7 @@ class AdminSoftwareBans(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-software_bans.haml', self, **context)
+ data = self.template.render('page/admin-software_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -189,7 +182,7 @@ class AdminUsers(View):
if message:
context['message'] = message
- data = self.template.render('page/admin-users.haml', self, **context)
+ data = self.template.render('page/admin-users.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -213,7 +206,7 @@ class AdminConfig(View):
}
}
- data = self.template.render('page/admin-config.haml', self, **context)
+ data = self.template.render('page/admin-config.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@@ -251,5 +244,5 @@ class ThemeCss(View):
except KeyError:
return Response.new('Invalid theme', 404)
- data = self.template.render('variables.css', self, **context)
+ data = self.template.render('variables.css', self.request, **context)
return Response.new(data, ctype = 'css')