Compare commits

..

No commits in common. "773922e2630ad355340f5497ac5a86f899817985" and "b22b5bbefaa1b6cf13deaeb65396b135dc3fb192" have entirely different histories.

19 changed files with 283 additions and 701 deletions

View file

@ -4,14 +4,11 @@ import asyncio
import multiprocessing import multiprocessing
import signal import signal
import time import time
import traceback
from Crypto.Random import get_random_bytes
from aiohttp import web from aiohttp import web
from aiohttp.web import HTTPException, StaticResource from aiohttp.web import StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from base64 import b64encode
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -26,8 +23,7 @@ from .config import Config
from .database import Connection, get_database from .database import Connection, get_database
from .database.schema import Instance from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import HttpError, Message, Response, check_open_port, get_resource from .misc import Message, Response, check_open_port, get_resource
from .misc import JSON_PATHS, TOKEN_PATHS
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
@ -57,9 +53,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False): def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self, web.Application.__init__(self,
middlewares = [ middlewares = [
handle_response_headers, # type: ignore[list-item] handle_api_path, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item] handle_frontend_path, # type: ignore[list-item]
handle_api_path # type: ignore[list-item] handle_response_headers # type: ignore[list-item]
] ]
) )
@ -286,76 +282,19 @@ class CacheCleanupThread(Thread):
self.running.clear() self.running.clear()
def format_error(request: web.Request, error: HttpError) -> Response:
app: Application = request.app # type: ignore[assignment]
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.body}, error.status, ctype = 'json')
else:
body = app.template.render('page/error.haml', request, e = error)
return Response.new(body, error.status, ctype = 'html')
@web.middleware @web.middleware
async def handle_response_headers( async def handle_response_headers(
request: web.Request, request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
request['token'] = None
request['user'] = None
app: Application = request.app # type: ignore[assignment]
if request.path == "/" or request.path.startswith(TOKEN_PATHS):
with app.database.session() as conn:
tokens = (
request.headers.get('Authorization', '').replace('Bearer', '').strip(),
request.cookies.get('user-token')
)
for token in tokens:
if not token:
continue
request['token'] = conn.get_app_by_token(token)
if request['token'] is not None:
request['user'] = conn.get_user(request['token'].user)
break
try:
resp = await handler(request) resp = await handler(request)
except HttpError as e:
resp = format_error(request, e)
except HTTPException as ae:
if ae.status == 404:
try:
text = (ae.text or "").split(":")[1].strip()
except IndexError:
text = ae.text or ""
resp = format_error(request, HttpError(ae.status, text))
else:
raise
except Exception as e:
resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}'))
traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'
# Still have to figure out how csp headers work # Still have to figure out how csp headers work
if resp.content_type == 'text/html': if resp.content_type == 'text/html' and not request.path.startswith("/api"):
resp.headers['Content-Security-Policy'] = get_csp(request) resp.headers['Content-Security-Policy'] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')): if not request.app['dev'] and request.path.endswith(('.css', '.js')):
# cache for 2 weeks # cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable' resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'

View file

@ -51,8 +51,8 @@ WHERE username = :value or handle = :value;
-- name: get-user-by-token -- name: get-user-by-token
SELECT * FROM users SELECT * FROM users
WHERE username = ( WHERE username = (
SELECT user FROM app SELECT user FROM tokens
WHERE token = :token WHERE code = :code
); );
@ -67,28 +67,25 @@ DELETE FROM users
WHERE username = :value or handle = :value; WHERE username = :value or handle = :value;
-- name: get-app -- name: get-token
SELECT * FROM app SELECT * FROM tokens
WHERE client_id = :id and client_secret = :secret; WHERE code = :code;
-- name: get-app-with-token -- name: put-token
SELECT * FROM app INSERT INTO tokens (code, user, created)
WHERE client_id = :id and client_secret = :secret and token = :token; VALUES (:code, :user, :created)
RETURNING *;
-- name: get-app-by-token -- name: del-token
SELECT * FROM apps DELETE FROM tokens
WHERE token = :token; WHERE code = :code;
-- name: del-app
DELETE FROM apps
WHERE client_id = :id and client_secret = :secret;
-- name: del-app-with-token -- name: del-token-user
DELETE FROM apps DELETE FROM tokens
WHERE client_id = :id and client_secret = :secret and token = :token; WHERE user = :username;
-- name: get-software-ban -- name: get-software-ban

View file

@ -11,8 +11,12 @@ from .. import logger as logging
from ..misc import boolean from ..misc import boolean
if TYPE_CHECKING: if TYPE_CHECKING:
try:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
THEMES = { THEMES = {
'default': { 'default': {
@ -73,7 +77,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
@dataclass() @dataclass()
class ConfigData: class ConfigData:
schema_version: int = 20240625 schema_version: int = 20240310
private_key: str = '' private_key: str = ''
approval_required: bool = False approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO log_level: logging.LogLevel = logging.LogLevel.INFO
@ -111,11 +115,11 @@ class ConfigData:
@classmethod @classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool: def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value] return cls.FIELD(key.replace('-', '_')).default # type: ignore
@classmethod @classmethod
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]: def FIELD(cls: type[Self], key: str) -> Field[Any]:
for field in fields(cls): for field in fields(cls):
if field.name == key.replace('-', '_'): if field.name == key.replace('-', '_'):
return field return field

View file

@ -1,14 +1,12 @@
from __future__ import annotations from __future__ import annotations
import secrets
from argon2 import PasswordHasher from argon2 import PasswordHasher
from blib import Date
from bsql import Connection as SqlConnection, Row, Update from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator from collections.abc import Iterator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4
from . import schema from . import schema
from .config import ( from .config import (
@ -51,36 +49,6 @@ class Connection(SqlConnection):
yield instance yield instance
def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
self.update('apps', data, client_id = app.client_id)
for item in self.select('cache'):
data = {'updated': Date.parse(item['updated']).timestamp()}
self.update('cache', data, id = item['id'])
for dban in self.select('domain_bans').all(schema.DomainBan):
data = {'created': dban.created.timestamp()}
self.update('domain_bans', data, domain = dban.domain)
for instance in self.select('inboxes').all(schema.Instance):
data = {'created': instance.created.timestamp()}
self.update('inboxes', data, domain = instance.domain)
for sban in self.select('software_bans').all(schema.SoftwareBan):
data = {'created': sban.created.timestamp()}
self.update('software_bans', data, name = sban.name)
for user in self.select('users').all(schema.User):
data = {'created': user.created.timestamp()}
self.update('users', data, username = user.username)
for wlist in self.select('whitelist').all(schema.Whitelist):
data = {'created': wlist.created.timestamp()}
self.update('whitelist', data, domain = wlist.domain)
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
key = key.replace('_', '-') key = key.replace('_', '-')
@ -225,8 +193,8 @@ class Connection(SqlConnection):
return cur.one(schema.User) return cur.one(schema.User)
def get_user_by_token(self, token: str) -> schema.User | None: def get_user_by_token(self, code: str) -> schema.User | None:
with self.run('get-user-by-token', {'token': token}) as cur: with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one(schema.User) return cur.one(schema.User)
@ -281,114 +249,35 @@ class Connection(SqlConnection):
pass pass
def get_app(self, def get_token(self, code: str) -> schema.Token | None:
client_id: str, with self.run('get-token', {'code': code}) as cur:
client_secret: str, return cur.one(schema.Token)
token: str | None = None) -> schema.App | None:
params = {
'id': client_id, def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
'secret': client_secret if username is not None:
return self.select('tokens').all(schema.Token)
return self.select('tokens', username = username).all(schema.Token)
def put_token(self, username: str) -> schema.Token:
data = {
'code': uuid4().hex,
'user': username,
'created': datetime.now(tz = timezone.utc)
} }
if token is not None: with self.run('put-token', data) as cur:
command = 'get-app-with-token' if (row := cur.one(schema.Token)) is None:
params['token'] = token raise RuntimeError(f"Failed to insert token for user: {username}")
else:
command = 'get-app'
with self.run(command, params) as cur:
return cur.one(schema.App)
def get_app_by_token(self, token: str) -> schema.App | None:
with self.run('get-app-by-token', {'token': token}) as cur:
return cur.one(schema.App)
def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
params = {
'name': name,
'redirect_uri': redirect_uri,
'website': website,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'created': Date.new_utc().timestamp(),
'accessed': Date.new_utc().timestamp()
}
with self.insert('apps', params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to insert app: {name}')
return row return row
def put_app_login(self, user: schema.User) -> schema.App: def del_token(self, code: str) -> None:
params = { with self.run('del-token', {'code': code}):
'name': 'Web', pass
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob',
'website': None,
'user': user.username,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'auth_code': None,
'token': secrets.token_hex(20),
'created': Date.new_utc().timestamp(),
'accessed': Date.new_utc().timestamp()
}
with self.insert('apps', params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to create app for "{user.username}"')
return row
def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App:
data: dict[str, str | None] = {}
if user is not None:
data['user'] = user.username
if set_auth:
data['auth_code'] = secrets.token_hex(20)
else:
data['token'] = secrets.token_hex(20)
data['auth_code'] = None
params = {
'client_id': app.client_id,
'client_secret': app.client_secret
}
with self.update('app', data, **params) as cur: # type: ignore[arg-type]
if (row := cur.one(schema.App)) is None:
raise RuntimeError('Failed to update row')
return row
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
params = {
'id': client_id,
'secret': client_secret
}
if token is not None:
command = 'del-app-with-token'
params['token'] = token
else:
command = 'del-app'
with self.run(command, params) as cur:
if cur.row_count > 1:
raise RuntimeError('More than 1 row was deleted')
return cur.row_count == 0
def get_domain_ban(self, domain: str) -> schema.DomainBan | None: def get_domain_ban(self, domain: str) -> schema.DomainBan | None:

View file

@ -1,14 +1,14 @@
from __future__ import annotations from __future__ import annotations
from blib import Date import typing
from bsql import Column, Row, Tables from bsql import Column, Row, Tables
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy from datetime import datetime
from typing import TYPE_CHECKING, Any
from .config import ConfigData from .config import ConfigData
if TYPE_CHECKING: if typing.TYPE_CHECKING:
from .connection import Connection from .connection import Connection
@ -16,16 +16,6 @@ VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES = Tables() TABLES = Tables()
def deserialize_timestamp(value: Any) -> Date:
try:
return Date.parse(value)
except ValueError:
pass
return Date.fromisoformat(value)
@TABLES.add_row @TABLES.add_row
class Config(Row): class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False) key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
@ -37,107 +27,62 @@ class Config(Row):
class Instance(Row): class Instance(Row):
table_name: str = 'inboxes' table_name: str = 'inboxes'
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False) 'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True) actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False) inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text') followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text') software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean') accepted: Column[datetime] = Column('accepted', 'boolean')
created: Column[Date] = Column( created: Column[datetime] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class Whitelist(Row): class Whitelist(Row):
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column( created: Column[datetime] = Column('created', 'timestamp')
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class DomainBan(Row): class DomainBan(Row):
table_name: str = 'domain_bans' table_name: str = 'domain_bans'
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column( created: Column[datetime] = Column('created', 'timestamp')
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class SoftwareBan(Row): class SoftwareBan(Row):
table_name: str = 'software_bans' table_name: str = 'software_bans'
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column( created: Column[datetime] = Column('created', 'timestamp')
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class User(Row): class User(Row):
table_name: str = 'users' table_name: str = 'users'
username: Column[str] = Column( username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False) 'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text') handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column( created: Column[datetime] = Column('created', 'timestamp')
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class App(Row): class Token(Row):
table_name: str = 'apps' table_name: str = 'tokens'
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
client_id: Column[str] = Column( user: Column[str] = Column('user', 'text', nullable = False)
'client_id', 'text', primary_key = True, unique = True, nullable = False) created: Column[datetime] = Column('created', 'timestamp')
client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
name: Column[str] = Column('name', 'text')
website: Column[str] = Column('website', 'text')
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
data = deepcopy(self)
data.pop('user')
data.pop('auth_code')
if not include_token:
data.pop('token')
return data
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
@ -158,11 +103,5 @@ def migrate_20240206(conn: Connection) -> None:
@migration @migration
def migrate_20240310(conn: Connection) -> None: def migrate_20240310(conn: Connection) -> None:
conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close() conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN")
conn.execute('UPDATE "inboxes" SET accepted = 1').close() conn.execute("UPDATE inboxes SET accepted = 1")
@migration
def migrate_20240625(conn: Connection) -> None:
conn.create_tables()
conn.execute('DROP TABLE tokens').close()

View file

@ -1,5 +1,5 @@
-macro menu_item(name, path) -macro menu_item(name, path)
-if request.path == path or (path != "/" and request.path.startswith(path)) -if view.request.path == path or (path != "/" and view.request.path.startswith(path))
%a.button(href="{{path}}" active="true") -> =name %a.button(href="{{path}}" active="true") -> =name
-else -else
@ -10,12 +10,12 @@
%head %head
%title << {{config.name}}: {{page}} %title << {{config.name}}: {{page}}
%meta(charset="UTF-8") %meta(charset="UTF-8")
%meta(name="ort" content="width=device-width, initial-scale=1") %meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme") %link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}") %link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}")
%link(rel="manifest" href="/manifest.json?{{version}}") %link(rel="manifest" href="/manifest.json?{{version}}")
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer) %script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer)
-block head -block head
%body %body
@ -26,7 +26,7 @@
{{menu_item("Home", "/")}} {{menu_item("Home", "/")}}
-if request["user"] -if view.request["user"]
{{menu_item("Instances", "/admin/instances")}} {{menu_item("Instances", "/admin/instances")}}
{{menu_item("Whitelist", "/admin/whitelist")}} {{menu_item("Whitelist", "/admin/whitelist")}}
{{menu_item("Domain Bans", "/admin/domain_bans")}} {{menu_item("Domain Bans", "/admin/domain_bans")}}
@ -61,11 +61,11 @@
#footer.section #footer.section
.col1 .col1
-if not request["user"] -if not view.request["user"]
%a(href="/login") << Login %a(href="/login") << Login
-else -else
=request["user"]["username"] =view.request["user"]["username"]
( (
%a(href="/logout") << Logout %a(href="/logout") << Logout
) )

View file

@ -1,31 +0,0 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization
-if application.website
#title << Application "<a href="{{application.website}}" target="_new">{{application.name}}</a>" wants full API access
-else
#title << Application "{{application.name}}" wants full API access
#buttons
.spacer
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="true")
%input.button(type="submit" value="Allow")
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="false")
%input.button(type="submit" value="Deny")
.spacer

View file

@ -1,18 +0,0 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization Code
-if application.website
%p
Copy the following code into
%a(href="{{application.website}}" target="_main") -> %code -> =application.name
-else
%p
Copy the following code info
%code -> =application.name
%pre#code -> =application.auth_code

View file

@ -1,7 +0,0 @@
-extends "base.haml"
-set page="Error"
-block content
.section.error
.title << HTTP Error {{e.status}}
.body -> =e.body

View file

@ -12,6 +12,4 @@
%label(for="password") << Password %label(for="password") << Password
%input(id="password" name="password" placeholder="Password" type="password") %input(id="password" name="password" placeholder="Password" type="password")
%input#redir(type="hidden" name="redir" value="{{redir}}")
%input.submit(type="button" value="Login") %input.submit(type="button" value="Login")

View file

@ -483,15 +483,13 @@ function page_instance() {
function page_login() { function page_login() {
const fields = { const fields = {
username: document.querySelector("#username"), username: document.querySelector("#username"),
password: document.querySelector("#password"), password: document.querySelector("#password")
redir: document.querySelector("#redir") }
};
async function login(event) { async function login(event) {
const values = { const values = {
username: fields.username.value.trim(), username: fields.username.value.trim(),
password: fields.password.value.trim(), password: fields.password.value.trim()
redir: fields.redir.value.trim()
} }
if (values.username === "" | values.password === "") { if (values.username === "" | values.password === "") {
@ -500,14 +498,14 @@ function page_login() {
} }
try { try {
await request("POST", "v1/login", values); await request("POST", "v1/token", values);
} catch (error) { } catch (error) {
toast(error); toast(error);
return; return;
} }
document.location = values.redir; document.location = "/";
} }
@ -850,6 +848,9 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/instances")) { } else if (location.pathname.startsWith("/admin/instances")) {
page_instance(); page_instance();
} else if (location.pathname.startsWith("/admin/login")) {
page_login();
} else if (location.pathname.startsWith("/admin/software_bans")) { } else if (location.pathname.startsWith("/admin/software_bans")) {
page_software_ban(); page_software_ban();
@ -858,7 +859,4 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/whitelist")) { } else if (location.pathname.startsWith("/admin/whitelist")) {
page_whitelist(); page_whitelist();
} else if (location.pathname.startsWith("/login")) {
page_login();
} }

View file

@ -338,44 +338,6 @@ textarea {
} }
/* error */
#content.page-error {
text-align: center;
}
#content.page-error .title {
font-size: 24px;
font-weight: bold;
}
/* auth */
#content.page-app_authorization {
text-align: center;
}
#content.page-app_authorization #code {
background: var(--background);
border: 1px solid var(--border);
font-size: 18px;
margin: 0 auto;
width: max-content;
padding: 5px;
}
#content.page-app_authorization #title {
font-size: 24px;
}
#content.page-app_authorization #buttons {
display: grid;
grid-template-columns: auto max-content max-content auto;
grid-gap: var(--spacing);
justify-items: center;
margin: var(--spacing) 0;
}
@keyframes show_toast { @keyframes show_toast {
0% { 0% {
transform: translateX(100%); transform: translateX(100%);

View file

@ -212,21 +212,6 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
os._exit(0) os._exit(0)
@cli.command('db-maintenance')
@click.option('--fix-timestamps', '-t', is_flag = True,
help = 'Make sure timestamps in the database are float values')
@click.pass_context
def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
'Perform maintenance tasks on the database'
if fix_timestamps:
with ctx.obj.database.session(True) as conn:
conn.fix_timestamps()
with ctx.obj.database.session(False) as conn:
with conn.execute("VACUUM"):
pass
@cli.command('convert') @cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from') @click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@ -254,18 +239,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.set('domain', config['host']) ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save() ctx.obj.config.save()
# fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7
with get_database(ctx.obj.config) as db: with get_database(ctx.obj.config) as db:
with db.session(True) as conn: with db.session(True) as conn:
conn.put_config('private-key', database['private-key']) conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note']) conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled']) conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar( with click.progressbar( # type: ignore
database['relay-list'].values(), database['relay-list'].values(),
label = 'Inboxes'.ljust(15), label = 'Inboxes'.ljust(15),
width = 0 width = 0
) as inboxes: ) as inboxes:
for inbox in inboxes: for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}: if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay' actor = f'https://{inbox["domain"]}/relay'
@ -284,7 +269,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software'] software = inbox['software']
) )
with click.progressbar( with click.progressbar( # type: ignore
config['blocked_software'], config['blocked_software'],
label = 'Banned software'.ljust(15), label = 'Banned software'.ljust(15),
width = 0 width = 0
@ -296,7 +281,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None reason = 'relay' if software in RELAY_SOFTWARE else None
) )
with click.progressbar( with click.progressbar( # type: ignore
config['blocked_instances'], config['blocked_instances'],
label = 'Banned domains'.ljust(15), label = 'Banned domains'.ljust(15),
width = 0 width = 0
@ -305,7 +290,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software: for domain in banned_software:
conn.put_domain_ban(domain) conn.put_domain_ban(domain)
with click.progressbar( with click.progressbar( # type: ignore
config['whitelist'], config['whitelist'],
label = 'Whitelist'.ljust(15), label = 'Whitelist'.ljust(15),
width = 0 width = 0

View file

@ -62,27 +62,6 @@ SOFTWARE = (
'gotosocial' 'gotosocial'
) )
JSON_PATHS: tuple[str, ...] = (
'/api/v1',
'/actor',
'/inbox',
'/outbox',
'/following',
'/followers',
'/.well-known',
'/nodeinfo',
'/oauth/token',
'/oauth/revoke'
)
TOKEN_PATHS: tuple[str, ...] = (
'/logout',
'/admin',
'/api',
'/oauth/authorize',
'/oauth/revoke'
)
def boolean(value: Any) -> bool: def boolean(value: Any) -> bool:
if isinstance(value, str): if isinstance(value, str):
@ -134,17 +113,6 @@ def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path) return Path(str(pkgfiles('relay'))).joinpath(path)
class HttpError(Exception):
def __init__(self,
status: int,
body: str) -> None:
self.body: str = body
self.status: int = status
Exception.__init__(self, f"HTTP Error {status}: {body}")
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str: def default(self, o: Any) -> str:
if isinstance(o, datetime): if isinstance(o, datetime):
@ -274,9 +242,9 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self: def new_redir(cls: type[Self], path: str) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>' body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, status, {'Location': path}, ctype = 'html') return cls.new(body, 302, {'Location': path})
@property @property

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import textwrap import textwrap
from aiohttp.web import Request
from collections.abc import Callable from collections.abc import Callable
from hamlish_jinja import HamlishExtension from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
@ -14,6 +13,7 @@ from typing import TYPE_CHECKING, Any
from . import __version__ from . import __version__
from .misc import get_resource from .misc import get_resource
from .views.base import View
if TYPE_CHECKING: if TYPE_CHECKING:
from .application import Application from .application import Application
@ -43,12 +43,12 @@ class Template(Environment):
self.hamlish_mode = 'indented' self.hamlish_mode = 'indented'
def render(self, path: str, request: Request, **context: Any) -> str: def render(self, path: str, view: View | None = None, **context: Any) -> str:
with self.app.database.session(False) as conn: with self.app.database.session(False) as conn:
config = conn.get_config_all() config = conn.get_config_all()
new_context = { new_context = {
'request': request, 'view': view,
'domain': self.app.config.domain, 'domain': self.app.config.domain,
'version': __version__, 'version': __version__,
'config': config, 'config': config,

View file

@ -7,7 +7,7 @@ from .base import View, register_route
from .. import logger as logging from .. import logger as logging
from ..database import schema from ..database import schema
from ..misc import HttpError, Message, Response from ..misc import Message, Response
from ..processors import run_processor from ..processors import run_processor
@ -39,7 +39,8 @@ class ActorView(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
await self.get_post_data() if response := await self.get_post_data():
return response
with self.database.session() as conn: with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment] self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
@ -64,13 +65,13 @@ class ActorView(View):
return Response.new(status = 202) return Response.new(status = 202)
async def get_post_data(self) -> None: async def get_post_data(self) -> Response | None:
try: try:
self.signature = aputils.Signature.parse(self.request.headers['signature']) self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError: except KeyError:
logging.verbose('Missing signature header') logging.verbose('Missing signature header')
raise HttpError(400, 'missing signature header') return Response.new_error(400, 'missing signature header', 'json')
try: try:
message: Message | None = await self.request.json(loads = Message.parse) message: Message | None = await self.request.json(loads = Message.parse)
@ -78,17 +79,17 @@ class ActorView(View):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
logging.verbose('Failed to parse inbox message') logging.verbose('Failed to parse inbox message')
raise HttpError(400, 'failed to parse message') return Response.new_error(400, 'failed to parse message', 'json')
if message is None: if message is None:
logging.verbose('empty message') logging.verbose('empty message')
raise HttpError(400, 'missing message') return Response.new_error(400, 'missing message', 'json')
self.message = message self.message = message
if 'actor' not in self.message: if 'actor' not in self.message:
logging.verbose('actor not in message') logging.verbose('actor not in message')
raise HttpError(400, 'no actor in message') return Response.new_error(400, 'no actor in message', 'json')
try: try:
self.actor = await self.client.get(self.signature.keyid, True, Message) self.actor = await self.client.get(self.signature.keyid, True, Message)
@ -97,24 +98,26 @@ class ActorView(View):
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete': if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '') return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
raise HttpError(400, 'failed to fetch actor') return Response.new_error(400, 'failed to fetch actor', 'json')
try: try:
self.signer = self.actor.signer self.signer = self.actor.signer
except KeyError: except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid) logging.verbose('Actor missing public key: %s', self.signature.keyid)
raise HttpError(400, 'actor missing public key') return Response.new_error(400, 'actor missing public key', 'json')
try: try:
await self.signer.validate_request_async(self.request) await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e: except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
raise HttpError(401, str(e)) return Response.new_error(401, str(e), 'json')
return None
@register_route('/outbox') @register_route('/outbox')

View file

@ -2,15 +2,15 @@ import traceback
from aiohttp.web import Request, middleware from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from blib import convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
from .. import __version__ from .. import __version__
from ..database import ConfigData, schema from ..database import ConfigData
from ..misc import HttpError, Message, Response, boolean from ..misc import Message, Response, boolean, get_app
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -22,8 +22,6 @@ ALLOWED_HEADERS: set[str] = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'), ('GET', '/api/v1/relay'),
('POST', '/api/v1/app'),
('POST', '/api/v1/login'),
('POST', '/api/v1/token') ('POST', '/api/v1/token')
) )
@ -39,181 +37,64 @@ def check_api_path(method: str, path: str) -> bool:
async def handle_api_path( async def handle_api_path(
request: Request, request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response: handler: Callable[[Request], Awaitable[Response]]) -> Response:
try:
if (token := request.cookies.get('user-token')):
request['token'] = token
if not request.path.startswith('/api'): else:
return await handler(request) request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
with get_app().database.session() as conn:
request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError):
request['token'] = None
request['user'] = None
if request.method != "OPTIONS" and check_api_path(request.method, request.path): if request.method != "OPTIONS" and check_api_path(request.method, request.path):
if request['token'] is None: if not request['token']:
raise HttpError(401, 'Missing token') return Response.new_error(401, 'Missing token', 'json')
if request['user'] is None: if not request['user']:
raise HttpError(401, 'Invalid token') return Response.new_error(401, 'Invalid token', 'json')
response = await handler(request) response = await handler(request)
if request.path.startswith('/api'):
response.headers['Access-Control-Allow-Origin'] = '*' response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS) response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
return response return response
@register_route('/oauth/authorize') @register_route('/api/v1/token')
class OauthAuthorize(View):
async def get(self, request: Request) -> Response:
data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], [])
if data['response_type'] != 'code':
raise HttpError(400, 'Response type is not "code"')
with self.database.session(True) as conn:
with conn.select('app', client_id = data['client_id']) as cur:
if (app := cur.one(schema.App)) is None:
raise HttpError(404, 'Could not find app')
if app.token is not None or app.auth_code is not None:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
if data['redirect_uri'] != app.redirect_uri:
raise HttpError(400, 'redirect_uri does not match application')
context = {'application': app}
html = self.template.render('page/authorize_new.haml', self.request, **context)
return Response.new(html, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(
['client_id', 'client_secret', 'redirect_uri', 'response'], []
)
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
return Response.new_error(404, 'Could not find app', 'json')
if convert_to_boolean(data['response']):
if app.auth_code is None:
app = conn.update_app(app, request['user'], True)
if app.redirect_uri == DEFAULT_REDIRECT:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}')
if not conn.del_app(app.client_id, app.client_secret):
raise HttpError(404, 'App not found')
return Response.new_redir('/')
@register_route('/oauth/token')
class OauthToken(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(
['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], []
)
if data['grant_type'] != 'authorization_code':
raise HttpError(400, 'Invalid grant type')
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application not found')
if app.auth_code != data['code']:
raise HttpError(400, 'Invalid authentication code')
if app.redirect_uri != data['redirect_uri']:
raise HttpError(400, 'Invalid redirect uri')
app = conn.update_app(app, request['user'], False)
return Response.new(app.get_api_data(True), ctype = 'json')
@register_route('/oauth/revoke')
class OauthRevoke(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret', 'token'], [])
with self.database.session(True) as conn:
if (app := conn.get_app(**data)) is None:
raise HttpError(404, 'Could not find token')
if app.user != request['token'].username:
raise HttpError(403, 'Invalid token')
if not conn.del_app(**data):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/app')
class App(View):
async def get(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(False) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application cannot be found')
return Response.new(app.get_api_data(), ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name', 'redirect_uri'], ['website'])
with self.database.session(True) as conn:
app = conn.put_app(
name = data['name'],
redirect_uri = data['redirect_uri'],
website = data.get('website')
)
return Response.new(app.get_api_data(), ctype = 'json')
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(True) as conn:
if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/login')
class Login(View): class Login(View):
async def get(self, request: Request) -> Response:
return Response.new({'message': 'Token valid'}, ctype = 'json')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], []) data = await self.get_api_data(['username', 'password'], [])
if isinstance(data, Response):
return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])): if not (user := conn.get_user(data['username'])):
raise HttpError(401, 'User not found') return Response.new_error(401, 'User not found', 'json')
try: try:
conn.hasher.verify(user['hash'], data['password']) conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError: except VerifyMismatchError:
raise HttpError(401, 'Invalid password') return Response.new_error(401, 'Invalid password', 'json')
app = conn.put_app_login(user) token = conn.put_token(data['username'])
resp = Response.new({'token': app.token}, ctype = 'json') resp = Response.new({'token': token.code}, ctype = 'json')
resp.set_cookie( resp.set_cookie(
'user-token', 'user-token',
app.token, # type: ignore[arg-type] token.code,
max_age = 60 * 60 * 24 * 365, max_age = 60 * 60 * 24 * 365,
domain = self.config.domain, domain = self.config.domain,
path = '/', path = '/',
@ -225,6 +106,13 @@ class Login(View):
return resp return resp
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
conn.del_token(request['token'])
return Response.new({'message': 'Token revoked'}, ctype = 'json')
@register_route('/api/v1/relay') @register_route('/api/v1/relay')
class RelayInfo(View): class RelayInfo(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
@ -267,10 +155,14 @@ class Config(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['key', 'value'], []) data = await self.get_api_data(['key', 'value'], [])
if isinstance(data, Response):
return data
data['key'] = data['key'].replace('-', '_') data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in ConfigData.USER_KEYS():
raise HttpError(400, 'Invalid key') return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn: with self.database.session() as conn:
conn.put_config(data['key'], data['value']) conn.put_config(data['key'], data['value'])
@ -281,8 +173,11 @@ class Config(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['key'], []) data = await self.get_api_data(['key'], [])
if isinstance(data, Response):
return data
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in ConfigData.USER_KEYS():
raise HttpError(400, 'Invalid key') return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn: with self.database.session() as conn:
conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
@ -301,11 +196,15 @@ class Inbox(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = urlparse(data["actor"]).netloc data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_inbox(data['domain']) is not None: if conn.get_inbox(data['domain']) is not None:
raise HttpError(404, 'Instance already in database') return Response.new_error(404, 'Instance already in database', 'json')
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
@ -315,7 +214,7 @@ class Inbox(View):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
raise HttpError(500, 'Failed to fetch actor') return Response.new_error(500, 'Failed to fetch actor', 'json')
data['inbox'] = actor_data.shared_inbox data['inbox'] = actor_data.shared_inbox
@ -341,10 +240,14 @@ class Inbox(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if (instance := conn.get_inbox(data['domain'])) is None: if (instance := conn.get_inbox(data['domain'])) is None:
raise HttpError(404, 'Instance with domain not found') return Response.new_error(404, 'Instance with domain not found', 'json')
instance = conn.put_inbox( instance = conn.put_inbox(
instance.domain, instance.domain,
@ -359,10 +262,14 @@ class Inbox(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']): if not conn.get_inbox(data['domain']):
raise HttpError(404, 'Instance with domain not found') return Response.new_error(404, 'Instance with domain not found', 'json')
conn.del_inbox(data['domain']) conn.del_inbox(data['domain'])
@ -379,21 +286,26 @@ class RequestView(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain', 'accept'], []) data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
if isinstance(data, Response):
return data
data['accept'] = boolean(data['accept'])
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
try: try:
with self.database.session(True) as conn: with self.database.session(True) as conn:
instance = conn.put_request_response(data['domain'], boolean(data['accept'])) instance = conn.put_request_response(data['domain'], data['accept'])
except KeyError: except KeyError:
raise HttpError(404, 'Request not found') return Response.new_error(404, 'Request not found', 'json')
message = Message.new_response( message = Message.new_response(
host = self.config.domain, host = self.config.domain,
actor = instance.actor, actor = instance.actor,
followid = instance.followid, followid = instance.followid,
accept = boolean(data['accept']) accept = data['accept']
) )
self.app.push_message(instance.inbox, message, instance) self.app.push_message(instance.inbox, message, instance)
@ -421,11 +333,15 @@ class DomainBan(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_ban(data['domain']) is not None: if conn.get_domain_ban(data['domain']) is not None:
raise HttpError(400, 'Domain already banned') return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_domain_ban( ban = conn.put_domain_ban(
domain = data['domain'], domain = data['domain'],
@ -440,13 +356,16 @@ class DomainBan(View):
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
if not any([data.get('note'), data.get('reason')]): if not any([data.get('note'), data.get('reason')]):
raise HttpError(400, 'Must include note and/or reason parameters') return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None: if conn.get_domain_ban(data['domain']) is None:
raise HttpError(404, 'Domain not banned') return Response.new_error(404, 'Domain not banned', 'json')
ban = conn.update_domain_ban( ban = conn.update_domain_ban(
domain = data['domain'], domain = data['domain'],
@ -460,10 +379,14 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None: if conn.get_domain_ban(data['domain']) is None:
raise HttpError(404, 'Domain not banned') return Response.new_error(404, 'Domain not banned', 'json')
conn.del_domain_ban(data['domain']) conn.del_domain_ban(data['domain'])
@ -482,9 +405,12 @@ class SoftwareBan(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason']) data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is not None: if conn.get_software_ban(data['name']) is not None:
raise HttpError(400, 'Domain already banned') return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_software_ban( ban = conn.put_software_ban(
name = data['name'], name = data['name'],
@ -498,12 +424,15 @@ class SoftwareBan(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason']) data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
if not any([data.get('note'), data.get('reason')]): if not any([data.get('note'), data.get('reason')]):
raise HttpError(400, 'Must include note and/or reason parameters') return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None: if conn.get_software_ban(data['name']) is None:
raise HttpError(404, 'Software not banned') return Response.new_error(404, 'Software not banned', 'json')
ban = conn.update_software_ban( ban = conn.update_software_ban(
name = data['name'], name = data['name'],
@ -517,9 +446,12 @@ class SoftwareBan(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['name'], []) data = await self.get_api_data(['name'], [])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None: if conn.get_software_ban(data['name']) is None:
raise HttpError(404, 'Software not banned') return Response.new_error(404, 'Software not banned', 'json')
conn.del_software_ban(data['name']) conn.del_software_ban(data['name'])
@ -542,9 +474,12 @@ class User(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], ['handle']) data = await self.get_api_data(['username', 'password'], ['handle'])
if isinstance(data, Response):
return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_user(data['username']) is not None: if conn.get_user(data['username']) is not None:
raise HttpError(404, 'User already exists') return Response.new_error(404, 'User already exists', 'json')
user = conn.put_user( user = conn.put_user(
username = data['username'], username = data['username'],
@ -559,6 +494,9 @@ class User(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', 'handle']) data = await self.get_api_data(['username'], ['password', 'handle'])
if isinstance(data, Response):
return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
user = conn.put_user( user = conn.put_user(
username = data['username'], username = data['username'],
@ -573,9 +511,12 @@ class User(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['username'], []) data = await self.get_api_data(['username'], [])
if isinstance(data, Response):
return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
if conn.get_user(data['username']) is None: if conn.get_user(data['username']) is None:
raise HttpError(404, 'User does not exist') return Response.new_error(404, 'User does not exist', 'json')
conn.del_user(data['username']) conn.del_user(data['username'])
@ -594,11 +535,14 @@ class Whitelist(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
domain = data['domain'].encode('idna').decode() domain = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is not None: if conn.get_domain_whitelist(domain) is not None:
raise HttpError(400, 'Domain already added to whitelist') return Response.new_error(400, 'Domain already added to whitelist', 'json')
item = conn.put_domain_whitelist(domain) item = conn.put_domain_whitelist(domain)
@ -608,11 +552,14 @@ class Whitelist(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
domain = data['domain'].encode('idna').decode() domain = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is None: if conn.get_domain_whitelist(domain) is None:
raise HttpError(404, 'Domain not in whitelist') return Response.new_error(404, 'Domain not in whitelist', 'json')
conn.del_domain_whitelist(domain) conn.del_domain_whitelist(domain)

View file

@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request from aiohttp.web import HTTPMethodNotAllowed, Request
from base64 import b64encode
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property from functools import cached_property
@ -13,7 +15,7 @@ from ..cache import Cache
from ..config import Config from ..config import Config
from ..database import Connection from ..database import Connection
from ..http_client import HttpClient from ..http_client import HttpClient
from ..misc import HttpError, Response, get_app from ..misc import Response, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self from typing import Self
@ -41,10 +43,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
class View(AbstractView): class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]: def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS: if self.request.method not in METHODS:
raise HttpError(405, f'"{self.request.method}" method not allowed') raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
if not (handler := self.handlers.get(self.request.method)): if not (handler := self.handlers.get(self.request.method)):
raise HttpError(405, f'"{self.request.method}" method not allowed') raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
return self._run_handler(handler).__await__() return self._run_handler(handler).__await__()
@ -56,6 +58,7 @@ class View(AbstractView):
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs) return await handler(self.request, **self.request.match_info, **kwargs)
@ -114,18 +117,17 @@ class View(AbstractView):
async def get_api_data(self, async def get_api_data(self,
required: list[str], required: list[str],
optional: list[str]) -> dict[str, str]: optional: list[str]) -> dict[str, str] | Response:
if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}: if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post()) post_data = convert_data(await self.request.post())
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json': elif self.request.content_type == 'application/json':
try: try:
post_data = convert_data(await self.request.json()) post_data = convert_data(await self.request.json())
except JSONDecodeError: except JSONDecodeError:
raise HttpError(400, 'Invalid JSON data') return Response.new_error(400, 'Invalid JSON data', 'json')
else: else:
post_data = convert_data(self.request.query) post_data = convert_data(self.request.query)
@ -137,9 +139,9 @@ class View(AbstractView):
data[key] = post_data[key] data[key] = post_data[key]
except KeyError as e: except KeyError as e:
raise HttpError(400, f'Missing {str(e)} pararmeter') return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
for key in optional: for key in optional:
data[key] = post_data.get(key) # type: ignore[assignment] data[key] = post_data.get(key, '')
return data return data

View file

@ -1,13 +1,18 @@
from aiohttp import web from aiohttp import web
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
from urllib.parse import unquote
from .base import View, register_route from .base import View, register_route
from ..database import THEMES from ..database import THEMES
from ..logger import LogLevel from ..logger import LogLevel
from ..misc import TOKEN_PATHS, Response from ..misc import Response, get_app
UNAUTH_ROUTES = {
'/',
'/login'
}
@web.middleware @web.middleware
@ -15,24 +20,27 @@ async def handle_frontend_path(
request: web.Request, request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response: handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
if request['user'] is not None and request.path == '/login': app = get_app()
return Response.new_redir('/')
if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None: if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
if request.path == '/logout': request['token'] = request.cookies.get('user-token')
return Response.new_redir('/') request['user'] = None
response = Response.new_redir(f'/login?redir={request.path}') if request['token']:
with app.database.session(False) as conn:
request['user'] = conn.get_user_by_token(request['token'])
if request['token'] is not None: if request['user'] and request.path == '/login':
return Response.new('', 302, {'Location': '/'})
if not request['user'] and request.path.startswith('/admin'):
response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
response.del_cookie('user-token') response.del_cookie('user-token')
return response return response
response = await handler(request) response = await handler(request)
if not request.path.startswith('/api'): if not request.path.startswith('/api') and not request['user'] and request['token']:
if request['user'] is None and request['token'] is not None:
response.del_cookie('user-token') response.del_cookie('user-token')
return response return response
@ -46,15 +54,14 @@ class HomeView(View):
'instances': tuple(conn.get_inboxes()) 'instances': tuple(conn.get_inboxes())
} }
data = self.template.render('page/home.haml', self.request, **context) data = self.template.render('page/home.haml', self, **context)
return Response.new(data, ctype='html') return Response.new(data, ctype='html')
@register_route('/login') @register_route('/login')
class Login(View): class Login(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
redir = unquote(request.query.get('redir', '/')) data = self.template.render('page/login.haml', self)
data = self.template.render('page/login.haml', self.request, redir = redir)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -62,7 +69,7 @@ class Login(View):
class Logout(View): class Logout(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn: with self.database.session(True) as conn:
conn.del_app(request['token'].client_id, request['token'].client_secret) conn.del_token(request['token'])
resp = Response.new_redir('/') resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = self.config.domain, path = '/') resp.del_cookie('user-token', domain = self.config.domain, path = '/')
@ -72,7 +79,7 @@ class Logout(View):
@register_route('/admin') @register_route('/admin')
class Admin(View): class Admin(View):
async def get(self, request: web.Request) -> Response: async def get(self, request: web.Request) -> Response:
return Response.new_redir(f'/login?redir={request.path}', 301) return Response.new('', 302, {'Location': '/admin/instances'})
@register_route('/admin/instances') @register_route('/admin/instances')
@ -94,7 +101,7 @@ class AdminInstances(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-instances.haml', self.request, **context) data = self.template.render('page/admin-instances.haml', self, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -116,7 +123,7 @@ class AdminWhitelist(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-whitelist.haml', self.request, **context) data = self.template.render('page/admin-whitelist.haml', self, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -138,7 +145,7 @@ class AdminDomainBans(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-domain_bans.haml', self.request, **context) data = self.template.render('page/admin-domain_bans.haml', self, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -160,7 +167,7 @@ class AdminSoftwareBans(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-software_bans.haml', self.request, **context) data = self.template.render('page/admin-software_bans.haml', self, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -182,7 +189,7 @@ class AdminUsers(View):
if message: if message:
context['message'] = message context['message'] = message
data = self.template.render('page/admin-users.haml', self.request, **context) data = self.template.render('page/admin-users.haml', self, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -206,7 +213,7 @@ class AdminConfig(View):
} }
} }
data = self.template.render('page/admin-config.haml', self.request, **context) data = self.template.render('page/admin-config.haml', self, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
@ -244,5 +251,5 @@ class ThemeCss(View):
except KeyError: except KeyError:
return Response.new('Invalid theme', 404) return Response.new('Invalid theme', 404)
data = self.template.render('variables.css', self.request, **context) data = self.template.render('variables.css', self, **context)
return Response.new(data, ctype = 'css') return Response.new(data, ctype = 'css')