various changes

* Add oauth login support
* Add `HttpError` class
* Add custom error handling
* Use `blib.Date` class for (de)serializing db timestamp values
* Add `db-maintenance` command
* Rework middleware route checking
* Fix fetching post data in api endpoints
This commit is contained in:
Izalia Mae 2024-07-04 20:36:04 -04:00
parent b22b5bbefa
commit f98ca54ab7
19 changed files with 748 additions and 231 deletions

View file

@ -4,11 +4,14 @@ import asyncio
import multiprocessing import multiprocessing
import signal import signal
import time import time
import traceback
from Crypto.Random import get_random_bytes
from aiohttp import web from aiohttp import web
from aiohttp.web import StaticResource from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from base64 import b64encode
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -23,7 +26,8 @@ from .config import Config
from .database import Connection, get_database from .database import Connection, get_database
from .database.schema import Instance from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import Message, Response, check_open_port, get_resource from .misc import HttpError, Message, Response, check_open_port, get_resource
from .misc import JSON_PATHS, TOKEN_PATHS
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
@ -53,9 +57,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False): def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self, web.Application.__init__(self,
middlewares = [ middlewares = [
handle_api_path, # type: ignore[list-item] handle_response_headers, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item] handle_frontend_path, # type: ignore[list-item]
handle_response_headers # type: ignore[list-item] handle_api_path # type: ignore[list-item]
] ]
) )
@ -282,19 +286,70 @@ class CacheCleanupThread(Thread):
self.running.clear() self.running.clear()
def format_error(request: web.Request, error: HttpError) -> Response:
app: Application = request.app # type: ignore[assignment]
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.body}, error.status, ctype = 'json')
else:
body = app.template.render('page/error.haml', request, e = error)
return Response.new(body, error.status, ctype = 'html')
@web.middleware @web.middleware
async def handle_response_headers( async def handle_response_headers(
request: web.Request, request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
request['token'] = None
request['user'] = None
app: Application = request.app # type: ignore[assignment]
if request.path == "/" or request.path.startswith(TOKEN_PATHS):
with app.database.session() as conn:
if (token := request.headers.get('Authorization')) is not None:
token = token.replace('Bearer', '').strip()
request['token'] = conn.get_app_by_token(token)
request['user'] = conn.get_user_by_app_token(token)
elif (token := request.cookies.get('user-token')) is not None:
request['token'] = conn.get_token(token)
request['user'] = conn.get_user_by_token(token)
try:
resp = await handler(request) resp = await handler(request)
except HttpError as e:
resp = format_error(request, e)
except HTTPException as ae:
if ae.status == 404:
try:
text = (ae.text or "").split(":")[1].strip()
except IndexError:
text = ae.text or ""
resp = format_error(request, HttpError(ae.status, text))
else:
raise
except Exception as e:
resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}'))
traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'
# Still have to figure out how csp headers work # Still have to figure out how csp headers work
if resp.content_type == 'text/html' and not request.path.startswith("/api"): if resp.content_type == 'text/html':
resp.headers['Content-Security-Policy'] = get_csp(request) resp.headers['Content-Security-Policy'] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js')): if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')):
# cache for 2 weeks # cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'

View file

@ -56,6 +56,14 @@ WHERE username = (
); );
-- name: get-user-by-app-token
SELECT * FROM users
WHERE username = (
SELECT user FROM app
WHERE code = :code
);
-- name: put-user -- name: put-user
INSERT INTO users (username, hash, handle, created) INSERT INTO users (username, hash, handle, created)
VALUES (:username, :hash, :handle, :created) VALUES (:username, :hash, :handle, :created)
@ -67,6 +75,30 @@ DELETE FROM users
WHERE username = :value or handle = :value; WHERE username = :value or handle = :value;
-- name: get-app
SELECT * FROM app
WHERE client_id = :id and client_secret = :secret;
-- name: get-app-token
SELECT * FROM app
WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: get-app-by-token
SELECT * FROM app
WHERE token = :token;
-- name: del-app
DELETE FROM users
WHERE client_id = :id and client_secret = :secret;
-- name: del-app-token
DELETE FROM users
WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: get-token -- name: get-token
SELECT * FROM tokens SELECT * FROM tokens
WHERE code = :code; WHERE code = :code;

View file

@ -11,12 +11,8 @@ from .. import logger as logging
from ..misc import boolean from ..misc import boolean
if TYPE_CHECKING: if TYPE_CHECKING:
try:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
THEMES = { THEMES = {
'default': { 'default': {
@ -77,7 +73,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
@dataclass() @dataclass()
class ConfigData: class ConfigData:
schema_version: int = 20240310 schema_version: int = 20240625
private_key: str = '' private_key: str = ''
approval_required: bool = False approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO log_level: logging.LogLevel = logging.LogLevel.INFO
@ -115,11 +111,11 @@ class ConfigData:
@classmethod @classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool: def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value]
@classmethod @classmethod
def FIELD(cls: type[Self], key: str) -> Field[Any]: def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
for field in fields(cls): for field in fields(cls):
if field.name == key.replace('-', '_'): if field.name == key.replace('-', '_'):
return field return field

View file

@ -1,6 +1,9 @@
from __future__ import annotations from __future__ import annotations
import secrets
from argon2 import PasswordHasher from argon2 import PasswordHasher
from blib import Date
from bsql import Connection as SqlConnection, Row, Update from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator from collections.abc import Iterator
from datetime import datetime, timezone from datetime import datetime, timezone
@ -49,6 +52,40 @@ class Connection(SqlConnection):
yield instance yield instance
def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
self.update('apps', data, client_id = app.client_id)
for item in self.select('cache'):
data = {'updated': Date.parse(item['updated']).timestamp()}
self.update('cache', data, id = item['id'])
for dban in self.select('domain_bans').all(schema.DomainBan):
data = {'created': dban.created.timestamp()}
self.update('domain_bans', data, domain = dban.domain)
for instance in self.select('inboxes').all(schema.Instance):
data = {'created': instance.created.timestamp()}
self.update('inboxes', data, domain = instance.domain)
for sban in self.select('software_bans').all(schema.SoftwareBan):
data = {'created': sban.created.timestamp()}
self.update('software_bans', data, name = sban.name)
for token in self.select('tokens').all(schema.Token):
data = {'created': token.created.timestamp(), 'accessed': token.accessed.timestamp()}
self.update('tokens', data, code = token.code)
for user in self.select('users').all(schema.User):
data = {'created': user.created.timestamp()}
self.update('users', data, username = user.username)
for wlist in self.select('whitelist').all(schema.Whitelist):
data = {'created': wlist.created.timestamp()}
self.update('whitelist', data, domain = wlist.domain)
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
key = key.replace('_', '-') key = key.replace('_', '-')
@ -198,6 +235,11 @@ class Connection(SqlConnection):
return cur.one(schema.User) return cur.one(schema.User)
def get_user_by_app_token(self, code: str) -> schema.User | None:
with self.run('get-user-by-app-token', {'code': code}) as cur:
return cur.one(schema.User)
def get_users(self) -> Iterator[schema.User]: def get_users(self) -> Iterator[schema.User]:
return self.execute("SELECT * FROM users").all(schema.User) return self.execute("SELECT * FROM users").all(schema.User)
@ -249,13 +291,102 @@ class Connection(SqlConnection):
pass pass
def get_app(self,
client_id: str,
client_secret: str,
token: str | None = None) -> schema.App | None:
params = {
'id': client_id,
'secret': client_secret
}
if token is not None:
command = 'get-app-with-token'
params['token'] = token
else:
command = 'get-app'
with self.run(command, params) as cur:
return cur.one(schema.App)
def get_app_by_token(self, token: str) -> schema.App | None:
with self.run('get-app-by-token', {'token': token}) as cur:
return cur.one(schema.App)
def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
params = {
'name': name,
'redirect_uri': redirect_uri,
'website': website,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'created': Date.new_utc().timestamp(),
'accessed': Date.new_utc().timestamp()
}
with self.insert('app', params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to insert app: {name}')
return row
def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App:
data: dict[str, str | None] = {}
if user is not None:
data['user'] = user.username
if set_auth:
data['auth_code'] = secrets.token_hex(20)
else:
data['token'] = secrets.token_hex(20)
data['auth_code'] = None
params = {
'client_id': app.client_id,
'client_secret': app.client_secret
}
with self.update('app', data, **params) as cur: # type: ignore[arg-type]
if (row := cur.one(schema.App)) is None:
raise RuntimeError('Failed to update row')
return row
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
params = {
'id': client_id,
'secret': client_secret
}
if token is not None:
command = 'del-app-token'
params['token'] = token
else:
command = 'del-app'
with self.run(command, params) as cur:
if cur.row_count > 1:
raise RuntimeError('More than 1 row was deleted')
return cur.row_count == 0
def get_token(self, code: str) -> schema.Token | None: def get_token(self, code: str) -> schema.Token | None:
with self.run('get-token', {'code': code}) as cur: with self.run('get-token', {'code': code}) as cur:
return cur.one(schema.Token) return cur.one(schema.Token)
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]: def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
if username is not None: if username is None:
return self.select('tokens').all(schema.Token) return self.select('tokens').all(schema.Token)
return self.select('tokens', username = username).all(schema.Token) return self.select('tokens', username = username).all(schema.Token)

View file

@ -1,14 +1,14 @@
from __future__ import annotations from __future__ import annotations
import typing from blib import Date
from bsql import Column, Row, Tables from bsql import Column, Row, Tables
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime from copy import deepcopy
from typing import TYPE_CHECKING, Any
from .config import ConfigData from .config import ConfigData
if typing.TYPE_CHECKING: if TYPE_CHECKING:
from .connection import Connection from .connection import Connection
@ -16,6 +16,16 @@ VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES = Tables() TABLES = Tables()
def deserialize_timestamp(value: Any) -> Date:
try:
return Date.parse(value)
except ValueError:
pass
return Date.fromisoformat(value)
@TABLES.add_row @TABLES.add_row
class Config(Row): class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
@ -27,62 +37,125 @@ class Config(Row):
class Instance(Row): class Instance(Row):
table_name: str = 'inboxes' table_name: str = 'inboxes'
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False) 'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True) actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text') followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text') software: Column[str] = Column('software', 'text')
accepted: Column[datetime] = Column('accepted', 'boolean') accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[datetime] = Column('created', 'timestamp', nullable = False) created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class Whitelist(Row): class Whitelist(Row):
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[datetime] = Column('created', 'timestamp') created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class DomainBan(Row): class DomainBan(Row):
table_name: str = 'domain_bans' table_name: str = 'domain_bans'
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[datetime] = Column('created', 'timestamp') created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class SoftwareBan(Row): class SoftwareBan(Row):
table_name: str = 'software_bans' table_name: str = 'software_bans'
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[datetime] = Column('created', 'timestamp') created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class User(Row): class User(Row):
table_name: str = 'users' table_name: str = 'users'
username: Column[str] = Column( username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False) 'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text') handle: Column[str] = Column('handle', 'text')
created: Column[datetime] = Column('created', 'timestamp') created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class Token(Row): class Token(Row):
table_name: str = 'tokens' table_name: str = 'tokens'
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False) code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
user: Column[str] = Column('user', 'text', nullable = False) user: Column[str] = Column('user', 'text', nullable = False)
created: Column[datetime] = Column('created', 'timestamp') created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row
class App(Row):
table_name: str = 'apps'
client_id: Column[str] = Column(
'client_id', 'text', primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
name: Column[str] = Column('name', 'text')
website: Column[str] = Column('website', 'text')
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
data = deepcopy(self)
data.pop('auth_code')
data.pop('created')
data.pop('accessed')
if not include_token:
data.pop('token')
return data
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
@ -103,5 +176,15 @@ def migrate_20240206(conn: Connection) -> None:
@migration @migration
def migrate_20240310(conn: Connection) -> None: def migrate_20240310(conn: Connection) -> None:
conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN") conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN')
conn.execute("UPDATE inboxes SET accepted = 1") conn.execute('UPDATE "inboxes" SET accepted = 1')
@migration
def migrate_20240625(conn: Connection) -> None:
conn.execute('ALTER TABLE "tokens" ADD "accessed" timestamp')
for token in conn.get_tokens():
conn.update('tokens', {'accessed': token.created}, code = token.code).one()
conn.create_tables()

View file

@ -1,5 +1,5 @@
-macro menu_item(name, path) -macro menu_item(name, path)
-if view.request.path == path or (path != "/" and view.request.path.startswith(path)) -if request.path == path or (path != "/" and request.path.startswith(path))
%a.button(href="{{path}}" active="true") -> =name %a.button(href="{{path}}" active="true") -> =name
-else -else
@ -10,12 +10,12 @@
%head %head
%title << {{config.name}}: {{page}} %title << {{config.name}}: {{page}}
%meta(charset="UTF-8") %meta(charset="UTF-8")
%meta(name="viewport" content="width=device-width, initial-scale=1") %meta(name="ort" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme") %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}")
%link(rel="manifest" href="/manifest.json?{{version}}") %link(rel="manifest" href="/manifest.json?{{version}}")
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer) %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer)
-block head -block head
%body %body
@ -26,7 +26,7 @@
{{menu_item("Home", "/")}} {{menu_item("Home", "/")}}
-if view.request["user"] -if request["user"]
{{menu_item("Instances", "/admin/instances")}} {{menu_item("Instances", "/admin/instances")}}
{{menu_item("Whitelist", "/admin/whitelist")}} {{menu_item("Whitelist", "/admin/whitelist")}}
{{menu_item("Domain Bans", "/admin/domain_bans")}} {{menu_item("Domain Bans", "/admin/domain_bans")}}
@ -61,11 +61,11 @@
#footer.section #footer.section
.col1 .col1
-if not view.request["user"] -if not request["user"]
%a(href="/login") << Login %a(href="/login") << Login
-else -else
=view.request["user"]["username"] =request["user"]["username"]
( (
%a(href="/logout") << Logout %a(href="/logout") << Logout
) )

View file

@ -0,0 +1,31 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization
-if application.website
#title << Application "<a href="{{application.website}}" target="_new">{{application.name}}</a>" wants full API access
-else
#title << Application "{{application.name}}" wants full API access
#buttons
.spacer
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="true")
%input.button(type="submit" value="Allow")
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="false")
%input.button(type="submit" value="Deny")
.spacer

View file

@ -0,0 +1,18 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization Code
-if application.website
%p
Copy the following code into
%a(href="{{application.website}}" target="_main") -> %code -> =application.name
-else
%p
Copy the following code info
%code -> =application.name
%pre#code -> =application.auth_code

View file

@ -0,0 +1,7 @@
-extends "base.haml"
-set page="Error"
-block content
.section.error
.title << HTTP Error {{e.status}}
.body -> =e.body

View file

@ -12,4 +12,6 @@
%label(for="password") << Password %label(for="password") << Password
%input(id="password" name="password" placeholder="Password" type="password") %input(id="password" name="password" placeholder="Password" type="password")
%input#redir(type="hidden" name="redir" value="{{redir}}")
%input.submit(type="button" value="Login") %input.submit(type="button" value="Login")

View file

@ -483,13 +483,15 @@ function page_instance() {
function page_login() { function page_login() {
const fields = { const fields = {
username: document.querySelector("#username"), username: document.querySelector("#username"),
password: document.querySelector("#password") password: document.querySelector("#password"),
} redir: document.querySelector("#redir")
};
async function login(event) { async function login(event) {
const values = { const values = {
username: fields.username.value.trim(), username: fields.username.value.trim(),
password: fields.password.value.trim() password: fields.password.value.trim(),
redir: fields.redir.value.trim()
} }
if (values.username === "" | values.password === "") { if (values.username === "" | values.password === "") {
@ -498,14 +500,14 @@ function page_login() {
} }
try { try {
await request("POST", "v1/token", values); await request("POST", "v1/login", values);
} catch (error) { } catch (error) {
toast(error); toast(error);
return; return;
} }
document.location = "/"; document.location = values.redir;
} }
@ -848,9 +850,6 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/instances")) { } else if (location.pathname.startsWith("/admin/instances")) {
page_instance(); page_instance();
} else if (location.pathname.startsWith("/admin/login")) {
page_login();
} else if (location.pathname.startsWith("/admin/software_bans")) { } else if (location.pathname.startsWith("/admin/software_bans")) {
page_software_ban(); page_software_ban();
@ -859,4 +858,7 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/whitelist")) { } else if (location.pathname.startsWith("/admin/whitelist")) {
page_whitelist(); page_whitelist();
} else if (location.pathname.startsWith("/login")) {
page_login();
} }

View file

@ -338,6 +338,44 @@ textarea {
} }
/* error */
#content.page-error {
text-align: center;
}
#content.page-error .title {
font-size: 24px;
font-weight: bold;
}
/* auth */
#content.page-app_authorization {
text-align: center;
}
#content.page-app_authorization #code {
background: var(--background);
border: 1px solid var(--border);
font-size: 18px;
margin: 0 auto;
width: max-content;
padding: 5px;
}
#content.page-app_authorization #title {
font-size: 24px;
}
#content.page-app_authorization #buttons {
display: grid;
grid-template-columns: auto max-content max-content auto;
grid-gap: var(--spacing);
justify-items: center;
margin: var(--spacing) 0;
}
@keyframes show_toast { @keyframes show_toast {
0% { 0% {
transform: translateX(100%); transform: translateX(100%);

View file

@ -212,6 +212,21 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
os._exit(0) os._exit(0)
@cli.command('db-maintenance')
@click.option('--fix-timestamps', '-t', is_flag = True,
help = 'Make sure timestamps in the database are float values')
@click.pass_context
def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
'Perform maintenance tasks on the database'
if fix_timestamps:
with ctx.obj.database.session(True) as conn:
conn.fix_timestamps()
with ctx.obj.database.session(False) as conn:
with conn.execute("VACUUM"):
pass
@cli.command('convert') @cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from') @click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@ -239,18 +254,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.set('domain', config['host']) ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save() ctx.obj.config.save()
# fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7
with get_database(ctx.obj.config) as db: with get_database(ctx.obj.config) as db:
with db.session(True) as conn: with db.session(True) as conn:
conn.put_config('private-key', database['private-key']) conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note']) conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled']) conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar( # type: ignore with click.progressbar(
database['relay-list'].values(), database['relay-list'].values(),
label = 'Inboxes'.ljust(15), label = 'Inboxes'.ljust(15),
width = 0 width = 0
) as inboxes: ) as inboxes:
for inbox in inboxes: for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}: if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay' actor = f'https://{inbox["domain"]}/relay'
@ -269,7 +284,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software'] software = inbox['software']
) )
with click.progressbar( # type: ignore with click.progressbar(
config['blocked_software'], config['blocked_software'],
label = 'Banned software'.ljust(15), label = 'Banned software'.ljust(15),
width = 0 width = 0
@ -281,7 +296,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None reason = 'relay' if software in RELAY_SOFTWARE else None
) )
with click.progressbar( # type: ignore with click.progressbar(
config['blocked_instances'], config['blocked_instances'],
label = 'Banned domains'.ljust(15), label = 'Banned domains'.ljust(15),
width = 0 width = 0
@ -290,7 +305,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software: for domain in banned_software:
conn.put_domain_ban(domain) conn.put_domain_ban(domain)
with click.progressbar( # type: ignore with click.progressbar(
config['whitelist'], config['whitelist'],
label = 'Whitelist'.ljust(15), label = 'Whitelist'.ljust(15),
width = 0 width = 0

View file

@ -62,6 +62,28 @@ SOFTWARE = (
'gotosocial' 'gotosocial'
) )
JSON_PATHS: tuple[str, ...] = (
'/api/v1',
'/actor',
'/inbox',
'/outbox',
'/following',
'/followers',
'/.well-known',
'/nodeinfo',
'/oauth/token',
'/oauth/revoke'
)
TOKEN_PATHS: tuple[str, ...] = (
'/api',
'/login',
'/logout',
'/oauth/authorize',
'/oauth/revoke',
'/admin'
)
def boolean(value: Any) -> bool: def boolean(value: Any) -> bool:
if isinstance(value, str): if isinstance(value, str):
@ -113,6 +135,17 @@ def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path) return Path(str(pkgfiles('relay'))).joinpath(path)
class HttpError(Exception):
def __init__(self,
status: int,
body: str) -> None:
self.body: str = body
self.status: int = status
Exception.__init__(self, f"HTTP Error {status}: {body}")
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str: def default(self, o: Any) -> str:
if isinstance(o, datetime): if isinstance(o, datetime):
@ -242,9 +275,9 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_redir(cls: type[Self], path: str) -> Self: def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>' body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, 302, {'Location': path}) return cls.new(body, status, {'Location': path}, ctype = 'html')
@property @property

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import textwrap import textwrap
from aiohttp.web import Request
from collections.abc import Callable from collections.abc import Callable
from hamlish_jinja import HamlishExtension from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any
from . import __version__ from . import __version__
from .misc import get_resource from .misc import get_resource
from .views.base import View
if TYPE_CHECKING: if TYPE_CHECKING:
from .application import Application from .application import Application
@ -43,12 +43,12 @@ class Template(Environment):
self.hamlish_mode = 'indented' self.hamlish_mode = 'indented'
def render(self, path: str, view: View | None = None, **context: Any) -> str: def render(self, path: str, request: Request, **context: Any) -> str:
with self.app.database.session(False) as conn: with self.app.database.session(False) as conn:
config = conn.get_config_all() config = conn.get_config_all()
new_context = { new_context = {
'view': view, 'request': request,
'domain': self.app.config.domain, 'domain': self.app.config.domain,
'version': __version__, 'version': __version__,
'config': config, 'config': config,

View file

@ -7,7 +7,7 @@ from .base import View, register_route
from .. import logger as logging from .. import logger as logging
from ..database import schema from ..database import schema
from ..misc import Message, Response from ..misc import HttpError, Message, Response
from ..processors import run_processor from ..processors import run_processor
@ -39,8 +39,7 @@ class ActorView(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
if response := await self.get_post_data(): await self.get_post_data()
return response
with self.database.session() as conn: with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment] self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
@ -65,13 +64,13 @@ class ActorView(View):
return Response.new(status = 202) return Response.new(status = 202)
async def get_post_data(self) -> Response | None: async def get_post_data(self) -> None:
try: try:
self.signature = aputils.Signature.parse(self.request.headers['signature']) self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError: except KeyError:
logging.verbose('Missing signature header') logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json') raise HttpError(400, 'missing signature header')
try: try:
message: Message | None = await self.request.json(loads = Message.parse) message: Message | None = await self.request.json(loads = Message.parse)
@ -79,17 +78,17 @@ class ActorView(View):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
logging.verbose('Failed to parse inbox message') logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json') raise HttpError(400, 'failed to parse message')
if message is None: if message is None:
logging.verbose('empty message') logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json') raise HttpError(400, 'missing message')
self.message = message self.message = message
if 'actor' not in self.message: if 'actor' not in self.message:
logging.verbose('actor not in message') logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json') raise HttpError(400, 'no actor in message')
try: try:
self.actor = await self.client.get(self.signature.keyid, True, Message) self.actor = await self.client.get(self.signature.keyid, True, Message)
@ -98,26 +97,24 @@ class ActorView(View):
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete': if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202) raise HttpError(202, '')
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json') raise HttpError(400, 'failed to fetch actor')
try: try:
self.signer = self.actor.signer self.signer = self.actor.signer
except KeyError: except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid) logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json') raise HttpError(400, 'actor missing public key')
try: try:
await self.signer.validate_request_async(self.request) await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e: except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json') raise HttpError(401, str(e))
return None
@register_route('/outbox') @register_route('/outbox')

View file

@ -1,16 +1,17 @@
import secrets
import traceback import traceback
from aiohttp.web import Request, middleware from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from blib import convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
from .. import __version__ from .. import __version__
from ..database import ConfigData from ..database import ConfigData, schema
from ..misc import Message, Response, boolean, get_app from ..misc import HttpError, Message, Response, boolean
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -22,6 +23,8 @@ ALLOWED_HEADERS: set[str] = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'), ('GET', '/api/v1/relay'),
('POST', '/api/v1/app'),
('POST', '/api/v1/login'),
('POST', '/api/v1/token') ('POST', '/api/v1/token')
) )
@ -37,57 +40,174 @@ def check_api_path(method: str, path: str) -> bool:
async def handle_api_path( async def handle_api_path(
request: Request, request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response: handler: Callable[[Request], Awaitable[Response]]) -> Response:
try:
if (token := request.cookies.get('user-token')):
request['token'] = token
else: if not request.path.startswith('/api'):
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() return await handler(request)
with get_app().database.session() as conn:
request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError):
request['token'] = None
request['user'] = None
if request.method != "OPTIONS" and check_api_path(request.method, request.path): if request.method != "OPTIONS" and check_api_path(request.method, request.path):
if not request['token']: if request['token'] is None:
return Response.new_error(401, 'Missing token', 'json') raise HttpError(401, 'Missing token')
if not request['user']: if request['user'] is None:
return Response.new_error(401, 'Invalid token', 'json') raise HttpError(401, 'Invalid token')
response = await handler(request) response = await handler(request)
if request.path.startswith('/api'):
response.headers['Access-Control-Allow-Origin'] = '*' response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
return response return response
@register_route('/api/v1/token') @register_route('/oauth/authorize')
class Login(View): class OauthAuthorize(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
return Response.new({'message': 'Token valid'}, ctype = 'json') data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], [])
if data['response_type'] != 'code':
raise HttpError(400, 'Response type is not "code"')
with self.database.session(True) as conn:
with conn.select('app', client_id = data['client_id']) as cur:
if (app := cur.one(schema.App)) is None:
raise HttpError(404, 'Could not find app')
if app.token is not None or app.auth_code is not None:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
if data['redirect_uri'] != app.redirect_uri:
raise HttpError(400, 'redirect_uri does not match application')
context = {'application': app}
html = self.template.render('page/authorize_new.haml', self.request, **context)
return Response.new(html, ctype = 'html')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], []) data = await self.get_api_data(
['client_id', 'client_secret', 'redirect_uri', 'response'], []
)
if isinstance(data, Response): with self.database.session(True) as conn:
return data if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
return Response.new_error(404, 'Could not find app', 'json')
if convert_to_boolean(data['response']):
if app.auth_code is None:
app = conn.update_app(app, request['user'], True)
if app.redirect_uri == DEFAULT_REDIRECT:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}')
if not conn.del_app(app.client_id, app.client_secret):
raise HttpError(404, 'App not found')
return Response.new_redir('/')
@register_route('/oauth/token')
class OauthToken(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(
['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], []
)
if data['grant_type'] != 'authorization_code':
raise HttpError(400, 'Invalid grant type')
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application not found')
if app.auth_code != data['code']:
raise HttpError(400, 'Invalid authentication code')
if app.redirect_uri != data['redirect_uri']:
raise HttpError(400, 'Invalid redirect uri')
app = conn.update_app(app, request['user'], False)
return Response.new(app.get_api_data(True), ctype = 'json')
@register_route('/oauth/revoke')
class OauthRevoke(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret', 'token'], [])
with self.database.session(True) as conn:
if (app := conn.get_app(**data)) is None:
raise HttpError(404, 'Could not find token')
if app.user != request['token'].username:
raise HttpError(403, 'Invalid token')
if not conn.del_app(**data):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/app')
class App(View):
async def get(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(False) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application cannot be found')
return Response.new(app.get_api_data(), ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name', 'redirect_uri'], ['website'])
with self.database.session(True) as conn:
app = conn.put_app(
name = data['name'],
redirect_uri = data['redirect_uri'],
website = data.get('website')
)
return Response.new(app.get_api_data(), ctype = 'json')
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(True) as conn:
if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/login')
class Login(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], [])
with self.database.session(True) as conn: with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])): if not (user := conn.get_user(data['username'])):
return Response.new_error(401, 'User not found', 'json') raise HttpError(401, 'User not found')
try: try:
conn.hasher.verify(user['hash'], data['password']) conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError: except VerifyMismatchError:
return Response.new_error(401, 'Invalid password', 'json') raise HttpError(401, 'Invalid password')
token = conn.put_token(data['username']) token = conn.put_token(data['username'])
@ -106,11 +226,36 @@ class Login(View):
return resp return resp
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
conn.del_token(request['token'])
return Response.new({'message': 'Token revoked'}, ctype = 'json') async def post2(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], [])
with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])):
raise HttpError(401, 'User not found')
try:
conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError:
raise HttpError(401, 'Invalid password')
app = conn.put_app(
data['app_name'],
DEFAULT_REDIRECT,
data.get('website')
)
params = {
'code': secrets.token_hex(20),
'user': user.username
}
with conn.update('app', params, client_id = app.client_id) as cur:
if (row := cur.one(schema.App)) is None:
raise HttpError(500, 'Failed to create app')
return Response.new(row.get_api_data(True), ctype = 'json')
@register_route('/api/v1/relay') @register_route('/api/v1/relay')
@ -155,14 +300,10 @@ class Config(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['key', 'value'], []) data = await self.get_api_data(['key', 'value'], [])
if isinstance(data, Response):
return data
data['key'] = data['key'].replace('-', '_') data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json') raise HttpError(400, 'Invalid key')
with self.database.session() as conn: with self.database.session() as conn:
conn.put_config(data['key'], data['value']) conn.put_config(data['key'], data['value'])
@ -173,11 +314,8 @@ class Config(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['key'], []) data = await self.get_api_data(['key'], [])
if isinstance(data, Response):
return data
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json') raise HttpError(400, 'Invalid key')
with self.database.session() as conn: with self.database.session() as conn:
conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
@ -196,15 +334,11 @@ class Inbox(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = urlparse(data["actor"]).netloc data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_inbox(data['domain']) is not None: if conn.get_inbox(data['domain']) is not None:
return Response.new_error(404, 'Instance already in database', 'json') raise HttpError(404, 'Instance already in database')
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
@ -214,7 +348,7 @@ class Inbox(View):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
return Response.new_error(500, 'Failed to fetch actor', 'json') raise HttpError(500, 'Failed to fetch actor')
data['inbox'] = actor_data.shared_inbox data['inbox'] = actor_data.shared_inbox
@ -240,14 +374,10 @@ class Inbox(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if (instance := conn.get_inbox(data['domain'])) is None: if (instance := conn.get_inbox(data['domain'])) is None:
return Response.new_error(404, 'Instance with domain not found', 'json') raise HttpError(404, 'Instance with domain not found')
instance = conn.put_inbox( instance = conn.put_inbox(
instance.domain, instance.domain,
@ -262,14 +392,10 @@ class Inbox(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']): if not conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance with domain not found', 'json') raise HttpError(404, 'Instance with domain not found')
conn.del_inbox(data['domain']) conn.del_inbox(data['domain'])
@ -286,26 +412,21 @@ class RequestView(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], []) data = await self.get_api_data(['domain', 'accept'], [])
if isinstance(data, Response):
return data
data['accept'] = boolean(data['accept'])
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
try: try:
with self.database.session(True) as conn: with self.database.session(True) as conn:
instance = conn.put_request_response(data['domain'], data['accept']) instance = conn.put_request_response(data['domain'], boolean(data['accept']))
except KeyError: except KeyError:
return Response.new_error(404, 'Request not found', 'json') raise HttpError(404, 'Request not found')
message = Message.new_response( message = Message.new_response(
host = self.config.domain, host = self.config.domain,
actor = instance.actor, actor = instance.actor,
followid = instance.followid, followid = instance.followid,
accept = data['accept'] accept = boolean(data['accept'])
) )
self.app.push_message(instance.inbox, message, instance) self.app.push_message(instance.inbox, message, instance)
@ -333,15 +454,11 @@ class DomainBan(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_ban(data['domain']) is not None: if conn.get_domain_ban(data['domain']) is not None:
return Response.new_error(400, 'Domain already banned', 'json') raise HttpError(400, 'Domain already banned')
ban = conn.put_domain_ban( ban = conn.put_domain_ban(
domain = data['domain'], domain = data['domain'],
@ -356,16 +473,13 @@ class DomainBan(View):
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
if not any([data.get('note'), data.get('reason')]): if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json') raise HttpError(400, 'Must include note and/or reason parameters')
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None: if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json') raise HttpError(404, 'Domain not banned')
ban = conn.update_domain_ban( ban = conn.update_domain_ban(
domain = data['domain'], domain = data['domain'],
@ -379,14 +493,10 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None: if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json') raise HttpError(404, 'Domain not banned')
conn.del_domain_ban(data['domain']) conn.del_domain_ban(data['domain'])
@ -405,12 +515,9 @@ class SoftwareBan(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason']) data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is not None: if conn.get_software_ban(data['name']) is not None:
return Response.new_error(400, 'Domain already banned', 'json') raise HttpError(400, 'Domain already banned')
ban = conn.put_software_ban( ban = conn.put_software_ban(
name = data['name'], name = data['name'],
@ -424,15 +531,12 @@ class SoftwareBan(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason']) data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
if not any([data.get('note'), data.get('reason')]): if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json') raise HttpError(400, 'Must include note and/or reason parameters')
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None: if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json') raise HttpError(404, 'Software not banned')
ban = conn.update_software_ban( ban = conn.update_software_ban(
name = data['name'], name = data['name'],
@ -446,12 +550,9 @@ class SoftwareBan(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['name'], []) data = await self.get_api_data(['name'], [])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None: if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json') raise HttpError(404, 'Software not banned')
conn.del_software_ban(data['name']) conn.del_software_ban(data['name'])
@ -474,12 +575,9 @@ class User(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], ['handle']) data = await self.get_api_data(['username', 'password'], ['handle'])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_user(data['username']) is not None: if conn.get_user(data['username']) is not None:
return Response.new_error(404, 'User already exists', 'json') raise HttpError(404, 'User already exists')
user = conn.put_user( user = conn.put_user(
username = data['username'], username = data['username'],
@ -494,9 +592,6 @@ class User(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', 'handle']) data = await self.get_api_data(['username'], ['password', 'handle'])
if isinstance(data, Response):
return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
user = conn.put_user( user = conn.put_user(
username = data['username'], username = data['username'],
@ -511,12 +606,9 @@ class User(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['username'], []) data = await self.get_api_data(['username'], [])
if isinstance(data, Response):
return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
if conn.get_user(data['username']) is None: if conn.get_user(data['username']) is None:
return Response.new_error(404, 'User does not exist', 'json') raise HttpError(404, 'User does not exist')
conn.del_user(data['username']) conn.del_user(data['username'])
@ -535,14 +627,11 @@ class Whitelist(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
domain = data['domain'].encode('idna').decode() domain = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is not None: if conn.get_domain_whitelist(domain) is not None:
return Response.new_error(400, 'Domain already added to whitelist', 'json') raise HttpError(400, 'Domain already added to whitelist')
item = conn.put_domain_whitelist(domain) item = conn.put_domain_whitelist(domain)
@ -552,14 +641,11 @@ class Whitelist(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
domain = data['domain'].encode('idna').decode() domain = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is None: if conn.get_domain_whitelist(domain) is None:
return Response.new_error(404, 'Domain not in whitelist', 'json') raise HttpError(404, 'Domain not in whitelist')
conn.del_domain_whitelist(domain) conn.del_domain_whitelist(domain)

View file

@ -1,10 +1,8 @@
from __future__ import annotations from __future__ import annotations
from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import HTTPMethodNotAllowed, Request from aiohttp.web import Request
from base64 import b64encode
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property from functools import cached_property
@ -15,7 +13,7 @@ from ..cache import Cache
from ..config import Config from ..config import Config
from ..database import Connection from ..database import Connection
from ..http_client import HttpClient from ..http_client import HttpClient
from ..misc import Response, get_app from ..misc import HttpError, Response, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self from typing import Self
@ -43,10 +41,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
class View(AbstractView): class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]: def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS: if self.request.method not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) raise HttpError(405, f'"{self.request.method}" method not allowed')
if not (handler := self.handlers.get(self.request.method)): if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) raise HttpError(405, f'"{self.request.method}" method not allowed')
return self._run_handler(handler).__await__() return self._run_handler(handler).__await__()
@ -58,7 +56,6 @@ class View(AbstractView):
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs) return await handler(self.request, **self.request.match_info, **kwargs)
@ -117,17 +114,18 @@ class View(AbstractView):
async def get_api_data(self, async def get_api_data(self,
required: list[str], required: list[str],
optional: list[str]) -> dict[str, str] | Response: optional: list[str]) -> dict[str, str]:
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post()) post_data = convert_data(await self.request.post())
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json': elif self.request.content_type == 'application/json':
try: try:
post_data = convert_data(await self.request.json()) post_data = convert_data(await self.request.json())
except JSONDecodeError: except JSONDecodeError:
return Response.new_error(400, 'Invalid JSON data', 'json') raise HttpError(400, 'Invalid JSON data')
else: else:
post_data = convert_data(self.request.query) post_data = convert_data(self.request.query)
@ -139,9 +137,9 @@ class View(AbstractView):
data[key] = post_data[key] data[key] = post_data[key]
except KeyError as e: except KeyError as e:
return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') raise HttpError(400, f'Missing {str(e)} pararmeter')
for key in optional: for key in optional:
data[key] = post_data.get(key, '') data[key] = post_data.get(key) # type: ignore[assignment]
return data return data

View file

@ -1,18 +1,13 @@
from aiohttp import web from aiohttp import web
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
from urllib.parse import unquote
from .base import View, register_route from .base import View, register_route
from ..database import THEMES from ..database import THEMES
from ..logger import LogLevel from ..logger import LogLevel
from ..misc import Response, get_app from ..misc import TOKEN_PATHS, Response
UNAUTH_ROUTES = {
'/',
'/login'
}
@web.middleware @web.middleware
@ -20,27 +15,24 @@ async def handle_frontend_path(
request: web.Request, request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
app = get_app() if request['user'] is not None and request.path == '/login':
return Response.new_redir('/')
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): if request.path.startswith(TOKEN_PATHS) and request['user'] is None:
request['token'] = request.cookies.get('user-token') if request.path == '/logout':
request['user'] = None return Response.new_redir('/')
if request['token']: response = Response.new_redir(f'/login?redir={request.path}')
with app.database.session(False) as conn:
request['user'] = conn.get_user_by_token(request['token'])
if request['user'] and request.path == '/login': if request['token'] is not None:
return Response.new('', 302, {'Location': '/'})
if not request['user'] and request.path.startswith('/admin'):
response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
response.del_cookie('user-token') response.del_cookie('user-token')
return response return response
response = await handler(request) response = await handler(request)
if not request.path.startswith('/api') and not request['user'] and request['token']: if not request.path.startswith('/api'):
if request['user'] is None and request['token'] is not None:
response.del_cookie('user-token') response.del_cookie('user-token')
return response return response
@ -54,14 +46,15 @@ class HomeView(View):
'instances': tuple(conn.get_inboxes()) 'instances': tuple(conn.get_inboxes())
} }
data = self.template.render('page/home.haml', self, **context) data = self.template.render('page/home.haml', self.request, **context)
return Response.new(data, ctype='html') return Response.new(data, ctype='html')
@register_route('/login') @register_route('/login')
class Login(View): class Login(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
data = self.template.render('page/login.haml', self) redir = unquote(request.query.get('redir', '/'))
data = self.template.render('page/login.haml', self.request, redir = redir)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -69,7 +62,7 @@ class Login(View):
class Logout(View): class Logout(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn: with self.database.session(True) as conn:
conn.del_token(request['token']) conn.del_token(request['token'].code)
resp = Response.new_redir('/') resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = self.config.domain, path = '/') resp.del_cookie('user-token', domain = self.config.domain, path = '/')
@ -79,7 +72,7 @@ class Logout(View):
@register_route('/admin') @register_route('/admin')
class Admin(View): class Admin(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
return Response.new('', 302, {'Location': '/admin/instances'}) return Response.new_redir(f'/login?redir={request.path}', 301)
@register_route('/admin/instances') @register_route('/admin/instances')
@ -101,7 +94,7 @@ class AdminInstances(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-instances.haml', self, **context) data = self.template.render('page/admin-instances.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -123,7 +116,7 @@ class AdminWhitelist(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-whitelist.haml', self, **context) data = self.template.render('page/admin-whitelist.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -145,7 +138,7 @@ class AdminDomainBans(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-domain_bans.haml', self, **context) data = self.template.render('page/admin-domain_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -167,7 +160,7 @@ class AdminSoftwareBans(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-software_bans.haml', self, **context) data = self.template.render('page/admin-software_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -189,7 +182,7 @@ class AdminUsers(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-users.haml', self, **context) data = self.template.render('page/admin-users.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -213,7 +206,7 @@ class AdminConfig(View):
} }
} }
data = self.template.render('page/admin-config.haml', self, **context) data = self.template.render('page/admin-config.haml', self.request, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -251,5 +244,5 @@ class ThemeCss(View):
except KeyError: except KeyError:
return Response.new('Invalid theme', 404) return Response.new('Invalid theme', 404)
data = self.template.render('variables.css', self, **context) data = self.template.render('variables.css', self.request, **context)
return Response.new(data, ctype = 'css') return Response.new(data, ctype = 'css')