various changes

* Add oauth login support
* Add `HttpError` class
* Add custom error handling
* Use `blib.Date` class for (de)serializing db timestamp values
* Add `db-maintenance` command
* Rework middleware route checking
* Fix fetching post data in api endpoints
This commit is contained in:
Izalia Mae 2024-07-04 20:36:04 -04:00
parent b22b5bbefa
commit f98ca54ab7
19 changed files with 748 additions and 231 deletions

View file

@ -4,11 +4,14 @@ import asyncio
import multiprocessing
import signal
import time
import traceback
from Crypto.Random import get_random_bytes
from aiohttp import web
from aiohttp.web import StaticResource
from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger
from aputils.signer import Signer
from base64 import b64encode
from bsql import Database
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
@ -23,7 +26,8 @@ from .config import Config
from .database import Connection, get_database
from .database.schema import Instance
from .http_client import HttpClient
from .misc import Message, Response, check_open_port, get_resource
from .misc import HttpError, Message, Response, check_open_port, get_resource
from .misc import JSON_PATHS, TOKEN_PATHS
from .template import Template
from .views import VIEWS
from .views.api import handle_api_path
@ -53,9 +57,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self,
middlewares = [
handle_api_path, # type: ignore[list-item]
handle_response_headers, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item]
handle_response_headers # type: ignore[list-item]
handle_api_path # type: ignore[list-item]
]
)
@ -282,19 +286,70 @@ class CacheCleanupThread(Thread):
self.running.clear()
def format_error(request: web.Request, error: HttpError) -> Response:
app: Application = request.app # type: ignore[assignment]
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.body}, error.status, ctype = 'json')
else:
body = app.template.render('page/error.haml', request, e = error)
return Response.new(body, error.status, ctype = 'html')
@web.middleware
async def handle_response_headers(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
resp = await handler(request)
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
request['token'] = None
request['user'] = None
app: Application = request.app # type: ignore[assignment]
if request.path == "/" or request.path.startswith(TOKEN_PATHS):
with app.database.session() as conn:
if (token := request.headers.get('Authorization')) is not None:
token = token.replace('Bearer', '').strip()
request['token'] = conn.get_app_by_token(token)
request['user'] = conn.get_user_by_app_token(token)
elif (token := request.cookies.get('user-token')) is not None:
request['token'] = conn.get_token(token)
request['user'] = conn.get_user_by_token(token)
try:
resp = await handler(request)
except HttpError as e:
resp = format_error(request, e)
except HTTPException as ae:
if ae.status == 404:
try:
text = (ae.text or "").split(":")[1].strip()
except IndexError:
text = ae.text or ""
resp = format_error(request, HttpError(ae.status, text))
else:
raise
except Exception as e:
resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}'))
traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay'
# Still have to figure out how csp headers work
if resp.content_type == 'text/html' and not request.path.startswith("/api"):
if resp.content_type == 'text/html':
resp.headers['Content-Security-Policy'] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js')):
if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')):
# cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'

View file

@ -56,6 +56,14 @@ WHERE username = (
);
-- name: get-user-by-app-token
SELECT * FROM users
WHERE username = (
SELECT user FROM app
WHERE code = :code
);
-- name: put-user
INSERT INTO users (username, hash, handle, created)
VALUES (:username, :hash, :handle, :created)
@ -67,6 +75,30 @@ DELETE FROM users
WHERE username = :value or handle = :value;
-- name: get-app
SELECT * FROM app
WHERE client_id = :id and client_secret = :secret;
-- name: get-app-token
SELECT * FROM app
WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: get-app-by-token
SELECT * FROM app
WHERE token = :token;
-- name: del-app
DELETE FROM users
WHERE client_id = :id and client_secret = :secret;
-- name: del-app-token
DELETE FROM users
WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: get-token
SELECT * FROM tokens
WHERE code = :code;

View file

@ -11,11 +11,7 @@ from .. import logger as logging
from ..misc import boolean
if TYPE_CHECKING:
try:
from typing import Self
except ImportError:
from typing_extensions import Self
from typing import Self
THEMES = {
@ -77,7 +73,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
@dataclass()
class ConfigData:
schema_version: int = 20240310
schema_version: int = 20240625
private_key: str = ''
approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO
@ -115,11 +111,11 @@ class ConfigData:
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore
return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value]
@classmethod
def FIELD(cls: type[Self], key: str) -> Field[Any]:
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
for field in fields(cls):
if field.name == key.replace('-', '_'):
return field

View file

@ -1,6 +1,9 @@
from __future__ import annotations
import secrets
from argon2 import PasswordHasher
from blib import Date
from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator
from datetime import datetime, timezone
@ -49,6 +52,40 @@ class Connection(SqlConnection):
yield instance
def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
self.update('apps', data, client_id = app.client_id)
for item in self.select('cache'):
data = {'updated': Date.parse(item['updated']).timestamp()}
self.update('cache', data, id = item['id'])
for dban in self.select('domain_bans').all(schema.DomainBan):
data = {'created': dban.created.timestamp()}
self.update('domain_bans', data, domain = dban.domain)
for instance in self.select('inboxes').all(schema.Instance):
data = {'created': instance.created.timestamp()}
self.update('inboxes', data, domain = instance.domain)
for sban in self.select('software_bans').all(schema.SoftwareBan):
data = {'created': sban.created.timestamp()}
self.update('software_bans', data, name = sban.name)
for token in self.select('tokens').all(schema.Token):
data = {'created': token.created.timestamp(), 'accessed': token.accessed.timestamp()}
self.update('tokens', data, code = token.code)
for user in self.select('users').all(schema.User):
data = {'created': user.created.timestamp()}
self.update('users', data, username = user.username)
for wlist in self.select('whitelist').all(schema.Whitelist):
data = {'created': wlist.created.timestamp()}
self.update('whitelist', data, domain = wlist.domain)
def get_config(self, key: str) -> Any:
key = key.replace('_', '-')
@ -198,6 +235,11 @@ class Connection(SqlConnection):
return cur.one(schema.User)
def get_user_by_app_token(self, code: str) -> schema.User | None:
with self.run('get-user-by-app-token', {'code': code}) as cur:
return cur.one(schema.User)
def get_users(self) -> Iterator[schema.User]:
return self.execute("SELECT * FROM users").all(schema.User)
@ -249,13 +291,102 @@ class Connection(SqlConnection):
pass
def get_app(self,
client_id: str,
client_secret: str,
token: str | None = None) -> schema.App | None:
params = {
'id': client_id,
'secret': client_secret
}
if token is not None:
command = 'get-app-with-token'
params['token'] = token
else:
command = 'get-app'
with self.run(command, params) as cur:
return cur.one(schema.App)
def get_app_by_token(self, token: str) -> schema.App | None:
with self.run('get-app-by-token', {'token': token}) as cur:
return cur.one(schema.App)
def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
params = {
'name': name,
'redirect_uri': redirect_uri,
'website': website,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'created': Date.new_utc().timestamp(),
'accessed': Date.new_utc().timestamp()
}
with self.insert('app', params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to insert app: {name}')
return row
def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App:
data: dict[str, str | None] = {}
if user is not None:
data['user'] = user.username
if set_auth:
data['auth_code'] = secrets.token_hex(20)
else:
data['token'] = secrets.token_hex(20)
data['auth_code'] = None
params = {
'client_id': app.client_id,
'client_secret': app.client_secret
}
with self.update('app', data, **params) as cur: # type: ignore[arg-type]
if (row := cur.one(schema.App)) is None:
raise RuntimeError('Failed to update row')
return row
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
params = {
'id': client_id,
'secret': client_secret
}
if token is not None:
command = 'del-app-token'
params['token'] = token
else:
command = 'del-app'
with self.run(command, params) as cur:
if cur.row_count > 1:
raise RuntimeError('More than 1 row was deleted')
return cur.row_count == 0
def get_token(self, code: str) -> schema.Token | None:
with self.run('get-token', {'code': code}) as cur:
return cur.one(schema.Token)
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
if username is not None:
if username is None:
return self.select('tokens').all(schema.Token)
return self.select('tokens', username = username).all(schema.Token)

View file

@ -1,14 +1,14 @@
from __future__ import annotations
import typing
from blib import Date
from bsql import Column, Row, Tables
from collections.abc import Callable
from datetime import datetime
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from .config import ConfigData
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from .connection import Connection
@ -16,6 +16,16 @@ VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES = Tables()
def deserialize_timestamp(value: Any) -> Date:
try:
return Date.parse(value)
except ValueError:
pass
return Date.fromisoformat(value)
@TABLES.add_row
class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
@ -27,62 +37,125 @@ class Config(Row):
class Instance(Row):
table_name: str = 'inboxes'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
accepted: Column[datetime] = Column('accepted', 'boolean')
created: Column[datetime] = Column('created', 'timestamp', nullable = False)
accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[datetime] = Column('created', 'timestamp')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row
class DomainBan(Row):
table_name: str = 'domain_bans'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[datetime] = Column('created', 'timestamp')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row
class SoftwareBan(Row):
table_name: str = 'software_bans'
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[datetime] = Column('created', 'timestamp')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row
class User(Row):
table_name: str = 'users'
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
created: Column[datetime] = Column('created', 'timestamp')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row
class Token(Row):
table_name: str = 'tokens'
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
user: Column[str] = Column('user', 'text', nullable = False)
created: Column[datetime] = Column('created', 'timestamp')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row
class App(Row):
table_name: str = 'apps'
client_id: Column[str] = Column(
'client_id', 'text', primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
name: Column[str] = Column('name', 'text')
website: Column[str] = Column('website', 'text')
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
data = deepcopy(self)
data.pop('auth_code')
data.pop('created')
data.pop('accessed')
if not include_token:
data.pop('token')
return data
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
@ -103,5 +176,15 @@ def migrate_20240206(conn: Connection) -> None:
@migration
def migrate_20240310(conn: Connection) -> None:
conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN")
conn.execute("UPDATE inboxes SET accepted = 1")
conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN')
conn.execute('UPDATE "inboxes" SET accepted = 1')
@migration
def migrate_20240625(conn: Connection) -> None:
conn.execute('ALTER TABLE "tokens" ADD "accessed" timestamp')
for token in conn.get_tokens():
conn.update('tokens', {'accessed': token.created}, code = token.code).one()
conn.create_tables()

View file

@ -1,5 +1,5 @@
-macro menu_item(name, path)
-if view.request.path == path or (path != "/" and view.request.path.startswith(path))
-if request.path == path or (path != "/" and request.path.startswith(path))
%a.button(href="{{path}}" active="true") -> =name
-else
@ -10,12 +10,12 @@
%head
%title << {{config.name}}: {{page}}
%meta(charset="UTF-8")
%meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}")
%meta(name="ort" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}")
%link(rel="manifest" href="/manifest.json?{{version}}")
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer)
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer)
-block head
%body
@ -26,7 +26,7 @@
{{menu_item("Home", "/")}}
-if view.request["user"]
-if request["user"]
{{menu_item("Instances", "/admin/instances")}}
{{menu_item("Whitelist", "/admin/whitelist")}}
{{menu_item("Domain Bans", "/admin/domain_bans")}}
@ -61,11 +61,11 @@
#footer.section
.col1
-if not view.request["user"]
-if not request["user"]
%a(href="/login") << Login
-else
=view.request["user"]["username"]
=request["user"]["username"]
(
%a(href="/logout") << Logout
)

View file

@ -0,0 +1,31 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization
-if application.website
#title << Application "<a href="{{application.website}}" target="_new">{{application.name}}</a>" wants full API access
-else
#title << Application "{{application.name}}" wants full API access
#buttons
.spacer
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="true")
%input.button(type="submit" value="Allow")
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="false")
%input.button(type="submit" value="Deny")
.spacer

View file

@ -0,0 +1,18 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization Code
-if application.website
%p
Copy the following code into
%a(href="{{application.website}}" target="_main") -> %code -> =application.name
-else
%p
Copy the following code info
%code -> =application.name
%pre#code -> =application.auth_code

View file

@ -0,0 +1,7 @@
-extends "base.haml"
-set page="Error"
-block content
.section.error
.title << HTTP Error {{e.status}}
.body -> =e.body

View file

@ -12,4 +12,6 @@
%label(for="password") << Password
%input(id="password" name="password" placeholder="Password" type="password")
%input#redir(type="hidden" name="redir" value="{{redir}}")
%input.submit(type="button" value="Login")

View file

@ -483,13 +483,15 @@ function page_instance() {
function page_login() {
const fields = {
username: document.querySelector("#username"),
password: document.querySelector("#password")
}
password: document.querySelector("#password"),
redir: document.querySelector("#redir")
};
async function login(event) {
const values = {
username: fields.username.value.trim(),
password: fields.password.value.trim()
password: fields.password.value.trim(),
redir: fields.redir.value.trim()
}
if (values.username === "" | values.password === "") {
@ -498,14 +500,14 @@ function page_login() {
}
try {
await request("POST", "v1/token", values);
await request("POST", "v1/login", values);
} catch (error) {
toast(error);
return;
}
document.location = "/";
document.location = values.redir;
}
@ -848,9 +850,6 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/instances")) {
page_instance();
} else if (location.pathname.startsWith("/admin/login")) {
page_login();
} else if (location.pathname.startsWith("/admin/software_bans")) {
page_software_ban();
@ -859,4 +858,7 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/whitelist")) {
page_whitelist();
} else if (location.pathname.startsWith("/login")) {
page_login();
}

View file

@ -338,6 +338,44 @@ textarea {
}
/* error */
#content.page-error {
text-align: center;
}
#content.page-error .title {
font-size: 24px;
font-weight: bold;
}
/* auth */
#content.page-app_authorization {
text-align: center;
}
#content.page-app_authorization #code {
background: var(--background);
border: 1px solid var(--border);
font-size: 18px;
margin: 0 auto;
width: max-content;
padding: 5px;
}
#content.page-app_authorization #title {
font-size: 24px;
}
#content.page-app_authorization #buttons {
display: grid;
grid-template-columns: auto max-content max-content auto;
grid-gap: var(--spacing);
justify-items: center;
margin: var(--spacing) 0;
}
@keyframes show_toast {
0% {
transform: translateX(100%);

View file

@ -212,6 +212,21 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
os._exit(0)
@cli.command('db-maintenance')
@click.option('--fix-timestamps', '-t', is_flag = True,
help = 'Make sure timestamps in the database are float values')
@click.pass_context
def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
'Perform maintenance tasks on the database'
if fix_timestamps:
with ctx.obj.database.session(True) as conn:
conn.fix_timestamps()
with ctx.obj.database.session(False) as conn:
with conn.execute("VACUUM"):
pass
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@ -239,18 +254,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save()
# fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7
with get_database(ctx.obj.config) as db:
with db.session(True) as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar( # type: ignore
with click.progressbar(
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay'
@ -269,7 +284,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software']
)
with click.progressbar( # type: ignore
with click.progressbar(
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
@ -281,7 +296,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None
)
with click.progressbar( # type: ignore
with click.progressbar(
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
@ -290,7 +305,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software:
conn.put_domain_ban(domain)
with click.progressbar( # type: ignore
with click.progressbar(
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0

View file

@ -62,6 +62,28 @@ SOFTWARE = (
'gotosocial'
)
JSON_PATHS: tuple[str, ...] = (
'/api/v1',
'/actor',
'/inbox',
'/outbox',
'/following',
'/followers',
'/.well-known',
'/nodeinfo',
'/oauth/token',
'/oauth/revoke'
)
TOKEN_PATHS: tuple[str, ...] = (
'/api',
'/login',
'/logout',
'/oauth/authorize',
'/oauth/revoke',
'/admin'
)
def boolean(value: Any) -> bool:
if isinstance(value, str):
@ -113,6 +135,17 @@ def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path)
class HttpError(Exception):
def __init__(self,
status: int,
body: str) -> None:
self.body: str = body
self.status: int = status
Exception.__init__(self, f"HTTP Error {status}: {body}")
class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str:
if isinstance(o, datetime):
@ -242,9 +275,9 @@ class Response(AiohttpResponse):
@classmethod
def new_redir(cls: type[Self], path: str) -> Self:
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, 302, {'Location': path})
return cls.new(body, status, {'Location': path}, ctype = 'html')
@property

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import textwrap
from aiohttp.web import Request
from collections.abc import Callable
from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader
@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any
from . import __version__
from .misc import get_resource
from .views.base import View
if TYPE_CHECKING:
from .application import Application
@ -43,12 +43,12 @@ class Template(Environment):
self.hamlish_mode = 'indented'
def render(self, path: str, view: View | None = None, **context: Any) -> str:
def render(self, path: str, request: Request, **context: Any) -> str:
with self.app.database.session(False) as conn:
config = conn.get_config_all()
new_context = {
'view': view,
'request': request,
'domain': self.app.config.domain,
'version': __version__,
'config': config,

View file

@ -7,7 +7,7 @@ from .base import View, register_route
from .. import logger as logging
from ..database import schema
from ..misc import Message, Response
from ..misc import HttpError, Message, Response
from ..processors import run_processor
@ -39,8 +39,7 @@ class ActorView(View):
async def post(self, request: Request) -> Response:
if response := await self.get_post_data():
return response
await self.get_post_data()
with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
@ -65,13 +64,13 @@ class ActorView(View):
return Response.new(status = 202)
async def get_post_data(self) -> Response | None:
async def get_post_data(self) -> None:
try:
self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
raise HttpError(400, 'missing signature header')
try:
message: Message | None = await self.request.json(loads = Message.parse)
@ -79,17 +78,17 @@ class ActorView(View):
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
raise HttpError(400, 'failed to parse message')
if message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
raise HttpError(400, 'missing message')
self.message = message
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
raise HttpError(400, 'no actor in message')
try:
self.actor = await self.client.get(self.signature.keyid, True, Message)
@ -98,26 +97,24 @@ class ActorView(View):
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
raise HttpError(202, '')
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
raise HttpError(400, 'failed to fetch actor')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
raise HttpError(400, 'actor missing public key')
try:
await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
return None
raise HttpError(401, str(e))
@register_route('/outbox')

View file

@ -1,16 +1,17 @@
import secrets
import traceback
from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError
from blib import convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from urllib.parse import urlparse
from .base import View, register_route
from .. import __version__
from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app
from ..database import ConfigData, schema
from ..misc import HttpError, Message, Response, boolean
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -22,6 +23,8 @@ ALLOWED_HEADERS: set[str] = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'),
('POST', '/api/v1/app'),
('POST', '/api/v1/login'),
('POST', '/api/v1/token')
)
@ -37,57 +40,174 @@ def check_api_path(method: str, path: str) -> bool:
async def handle_api_path(
request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response:
try:
if (token := request.cookies.get('user-token')):
request['token'] = token
else:
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
with get_app().database.session() as conn:
request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError):
request['token'] = None
request['user'] = None
if not request.path.startswith('/api'):
return await handler(request)
if request.method != "OPTIONS" and check_api_path(request.method, request.path):
if not request['token']:
return Response.new_error(401, 'Missing token', 'json')
if request['token'] is None:
raise HttpError(401, 'Missing token')
if not request['user']:
return Response.new_error(401, 'Invalid token', 'json')
if request['user'] is None:
raise HttpError(401, 'Invalid token')
response = await handler(request)
if request.path.startswith('/api'):
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
return response
@register_route('/api/v1/token')
class Login(View):
@register_route('/oauth/authorize')
class OauthAuthorize(View):
async def get(self, request: Request) -> Response:
return Response.new({'message': 'Token valid'}, ctype = 'json')
data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], [])
if data['response_type'] != 'code':
raise HttpError(400, 'Response type is not "code"')
with self.database.session(True) as conn:
with conn.select('app', client_id = data['client_id']) as cur:
if (app := cur.one(schema.App)) is None:
raise HttpError(404, 'Could not find app')
if app.token is not None or app.auth_code is not None:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
if data['redirect_uri'] != app.redirect_uri:
raise HttpError(400, 'redirect_uri does not match application')
context = {'application': app}
html = self.template.render('page/authorize_new.haml', self.request, **context)
return Response.new(html, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], [])
data = await self.get_api_data(
['client_id', 'client_secret', 'redirect_uri', 'response'], []
)
if isinstance(data, Response):
return data
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
return Response.new_error(404, 'Could not find app', 'json')
if convert_to_boolean(data['response']):
if app.auth_code is None:
app = conn.update_app(app, request['user'], True)
if app.redirect_uri == DEFAULT_REDIRECT:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}')
if not conn.del_app(app.client_id, app.client_secret):
raise HttpError(404, 'App not found')
return Response.new_redir('/')
@register_route('/oauth/token')
class OauthToken(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(
['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], []
)
if data['grant_type'] != 'authorization_code':
raise HttpError(400, 'Invalid grant type')
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application not found')
if app.auth_code != data['code']:
raise HttpError(400, 'Invalid authentication code')
if app.redirect_uri != data['redirect_uri']:
raise HttpError(400, 'Invalid redirect uri')
app = conn.update_app(app, request['user'], False)
return Response.new(app.get_api_data(True), ctype = 'json')
@register_route('/oauth/revoke')
class OauthRevoke(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret', 'token'], [])
with self.database.session(True) as conn:
if (app := conn.get_app(**data)) is None:
raise HttpError(404, 'Could not find token')
if app.user != request['token'].username:
raise HttpError(403, 'Invalid token')
if not conn.del_app(**data):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/app')
class App(View):
async def get(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(False) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application cannot be found')
return Response.new(app.get_api_data(), ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name', 'redirect_uri'], ['website'])
with self.database.session(True) as conn:
app = conn.put_app(
name = data['name'],
redirect_uri = data['redirect_uri'],
website = data.get('website')
)
return Response.new(app.get_api_data(), ctype = 'json')
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(True) as conn:
if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/login')
class Login(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], [])
with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])):
return Response.new_error(401, 'User not found', 'json')
raise HttpError(401, 'User not found')
try:
conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError:
return Response.new_error(401, 'Invalid password', 'json')
raise HttpError(401, 'Invalid password')
token = conn.put_token(data['username'])
@ -106,11 +226,36 @@ class Login(View):
return resp
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
conn.del_token(request['token'])
return Response.new({'message': 'Token revoked'}, ctype = 'json')
async def post2(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], [])
with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])):
raise HttpError(401, 'User not found')
try:
conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError:
raise HttpError(401, 'Invalid password')
app = conn.put_app(
data['app_name'],
DEFAULT_REDIRECT,
data.get('website')
)
params = {
'code': secrets.token_hex(20),
'user': user.username
}
with conn.update('app', params, client_id = app.client_id) as cur:
if (row := cur.one(schema.App)) is None:
raise HttpError(500, 'Failed to create app')
return Response.new(row.get_api_data(True), ctype = 'json')
@register_route('/api/v1/relay')
@ -155,14 +300,10 @@ class Config(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['key', 'value'], [])
if isinstance(data, Response):
return data
data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json')
raise HttpError(400, 'Invalid key')
with self.database.session() as conn:
conn.put_config(data['key'], data['value'])
@ -173,11 +314,8 @@ class Config(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['key'], [])
if isinstance(data, Response):
return data
if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json')
raise HttpError(400, 'Invalid key')
with self.database.session() as conn:
conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
@ -196,15 +334,11 @@ class Inbox(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn:
if conn.get_inbox(data['domain']) is not None:
return Response.new_error(404, 'Instance already in database', 'json')
raise HttpError(404, 'Instance already in database')
data['domain'] = data['domain'].encode('idna').decode()
@ -214,7 +348,7 @@ class Inbox(View):
except Exception:
traceback.print_exc()
return Response.new_error(500, 'Failed to fetch actor', 'json')
raise HttpError(500, 'Failed to fetch actor')
data['inbox'] = actor_data.shared_inbox
@ -240,14 +374,10 @@ class Inbox(View):
async def patch(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
if (instance := conn.get_inbox(data['domain'])) is None:
return Response.new_error(404, 'Instance with domain not found', 'json')
raise HttpError(404, 'Instance with domain not found')
instance = conn.put_inbox(
instance.domain,
@ -262,14 +392,10 @@ class Inbox(View):
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance with domain not found', 'json')
raise HttpError(404, 'Instance with domain not found')
conn.del_inbox(data['domain'])
@ -286,26 +412,21 @@ class RequestView(View):
async def post(self, request: Request) -> Response:
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
if isinstance(data, Response):
return data
data['accept'] = boolean(data['accept'])
data = await self.get_api_data(['domain', 'accept'], [])
data['domain'] = data['domain'].encode('idna').decode()
try:
with self.database.session(True) as conn:
instance = conn.put_request_response(data['domain'], data['accept'])
instance = conn.put_request_response(data['domain'], boolean(data['accept']))
except KeyError:
return Response.new_error(404, 'Request not found', 'json')
raise HttpError(404, 'Request not found')
message = Message.new_response(
host = self.config.domain,
actor = instance.actor,
followid = instance.followid,
accept = data['accept']
accept = boolean(data['accept'])
)
self.app.push_message(instance.inbox, message, instance)
@ -333,15 +454,11 @@ class DomainBan(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_ban(data['domain']) is not None:
return Response.new_error(400, 'Domain already banned', 'json')
raise HttpError(400, 'Domain already banned')
ban = conn.put_domain_ban(
domain = data['domain'],
@ -356,16 +473,13 @@ class DomainBan(View):
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
raise HttpError(400, 'Must include note and/or reason parameters')
data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json')
raise HttpError(404, 'Domain not banned')
ban = conn.update_domain_ban(
domain = data['domain'],
@ -379,14 +493,10 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json')
raise HttpError(404, 'Domain not banned')
conn.del_domain_ban(data['domain'])
@ -405,12 +515,9 @@ class SoftwareBan(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is not None:
return Response.new_error(400, 'Domain already banned', 'json')
raise HttpError(400, 'Domain already banned')
ban = conn.put_software_ban(
name = data['name'],
@ -424,15 +531,12 @@ class SoftwareBan(View):
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
raise HttpError(400, 'Must include note and/or reason parameters')
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json')
raise HttpError(404, 'Software not banned')
ban = conn.update_software_ban(
name = data['name'],
@ -446,12 +550,9 @@ class SoftwareBan(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['name'], [])
if isinstance(data, Response):
return data
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json')
raise HttpError(404, 'Software not banned')
conn.del_software_ban(data['name'])
@ -474,12 +575,9 @@ class User(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], ['handle'])
if isinstance(data, Response):
return data
with self.database.session() as conn:
if conn.get_user(data['username']) is not None:
return Response.new_error(404, 'User already exists', 'json')
raise HttpError(404, 'User already exists')
user = conn.put_user(
username = data['username'],
@ -494,9 +592,6 @@ class User(View):
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', 'handle'])
if isinstance(data, Response):
return data
with self.database.session(True) as conn:
user = conn.put_user(
username = data['username'],
@ -511,12 +606,9 @@ class User(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['username'], [])
if isinstance(data, Response):
return data
with self.database.session(True) as conn:
if conn.get_user(data['username']) is None:
return Response.new_error(404, 'User does not exist', 'json')
raise HttpError(404, 'User does not exist')
conn.del_user(data['username'])
@ -535,14 +627,11 @@ class Whitelist(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is not None:
return Response.new_error(400, 'Domain already added to whitelist', 'json')
raise HttpError(400, 'Domain already added to whitelist')
item = conn.put_domain_whitelist(domain)
@ -552,14 +641,11 @@ class Whitelist(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is None:
return Response.new_error(404, 'Domain not in whitelist', 'json')
raise HttpError(404, 'Domain not in whitelist')
conn.del_domain_whitelist(domain)

View file

@ -1,10 +1,8 @@
from __future__ import annotations
from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import HTTPMethodNotAllowed, Request
from base64 import b64encode
from aiohttp.web import Request
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property
@ -15,7 +13,7 @@ from ..cache import Cache
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
from ..misc import Response, get_app
from ..misc import HttpError, Response, get_app
if TYPE_CHECKING:
from typing import Self
@ -43,10 +41,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
raise HttpError(405, f'"{self.request.method}" method not allowed')
if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
raise HttpError(405, f'"{self.request.method}" method not allowed')
return self._run_handler(handler).__await__()
@ -58,7 +56,6 @@ class View(AbstractView):
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs)
@ -117,17 +114,18 @@ class View(AbstractView):
async def get_api_data(self,
required: list[str],
optional: list[str]) -> dict[str, str] | Response:
optional: list[str]) -> dict[str, str]:
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post())
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json':
try:
post_data = convert_data(await self.request.json())
except JSONDecodeError:
return Response.new_error(400, 'Invalid JSON data', 'json')
raise HttpError(400, 'Invalid JSON data')
else:
post_data = convert_data(self.request.query)
@ -139,9 +137,9 @@ class View(AbstractView):
data[key] = post_data[key]
except KeyError as e:
return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
raise HttpError(400, f'Missing {str(e)} pararmeter')
for key in optional:
data[key] = post_data.get(key, '')
data[key] = post_data.get(key) # type: ignore[assignment]
return data

View file

@ -1,18 +1,13 @@
from aiohttp import web
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import unquote
from .base import View, register_route
from ..database import THEMES
from ..logger import LogLevel
from ..misc import Response, get_app
UNAUTH_ROUTES = {
'/',
'/login'
}
from ..misc import TOKEN_PATHS, Response
@web.middleware
@ -20,28 +15,25 @@ async def handle_frontend_path(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
app = get_app()
if request['user'] is not None and request.path == '/login':
return Response.new_redir('/')
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
request['token'] = request.cookies.get('user-token')
request['user'] = None
if request.path.startswith(TOKEN_PATHS) and request['user'] is None:
if request.path == '/logout':
return Response.new_redir('/')
if request['token']:
with app.database.session(False) as conn:
request['user'] = conn.get_user_by_token(request['token'])
response = Response.new_redir(f'/login?redir={request.path}')
if request['user'] and request.path == '/login':
return Response.new('', 302, {'Location': '/'})
if not request['user'] and request.path.startswith('/admin'):
response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
if request['token'] is not None:
response.del_cookie('user-token')
return response
return response
response = await handler(request)
if not request.path.startswith('/api') and not request['user'] and request['token']:
response.del_cookie('user-token')
if not request.path.startswith('/api'):
if request['user'] is None and request['token'] is not None:
response.del_cookie('user-token')
return response
@ -54,14 +46,15 @@ class HomeView(View):
'instances': tuple(conn.get_inboxes())
}
data = self.template.render('page/home.haml', self, **context)
data = self.template.render('page/home.haml', self.request, **context)
return Response.new(data, ctype='html')
@register_route('/login')
class Login(View):
async def get(self, request: web.Request) -> Response:
data = self.template.render('page/login.haml', self)
redir = unquote(request.query.get('redir', '/'))
data = self.template.render('page/login.haml', self.request, redir = redir)
return Response.new(data, ctype = 'html')
@ -69,7 +62,7 @@ class Login(View):
class Logout(View):
async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn:
conn.del_token(request['token'])
conn.del_token(request['token'].code)
resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = self.config.domain, path = '/')
@ -79,7 +72,7 @@ class Logout(View):
@register_route('/admin')
class Admin(View):
async def get(self, request: web.Request) -> Response:
return Response.new('', 302, {'Location': '/admin/instances'})
return Response.new_redir(f'/login?redir={request.path}', 301)
@register_route('/admin/instances')
@ -101,7 +94,7 @@ class AdminInstances(View):
if message:
context['message'] = message
data = self.template.render('page/admin-instances.haml', self, **context)
data = self.template.render('page/admin-instances.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@ -123,7 +116,7 @@ class AdminWhitelist(View):
if message:
context['message'] = message
data = self.template.render('page/admin-whitelist.haml', self, **context)
data = self.template.render('page/admin-whitelist.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@ -145,7 +138,7 @@ class AdminDomainBans(View):
if message:
context['message'] = message
data = self.template.render('page/admin-domain_bans.haml', self, **context)
data = self.template.render('page/admin-domain_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@ -167,7 +160,7 @@ class AdminSoftwareBans(View):
if message:
context['message'] = message
data = self.template.render('page/admin-software_bans.haml', self, **context)
data = self.template.render('page/admin-software_bans.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@ -189,7 +182,7 @@ class AdminUsers(View):
if message:
context['message'] = message
data = self.template.render('page/admin-users.haml', self, **context)
data = self.template.render('page/admin-users.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@ -213,7 +206,7 @@ class AdminConfig(View):
}
}
data = self.template.render('page/admin-config.haml', self, **context)
data = self.template.render('page/admin-config.haml', self.request, **context)
return Response.new(data, ctype = 'html')
@ -251,5 +244,5 @@ class ThemeCss(View):
except KeyError:
return Response.new('Invalid theme', 404)
data = self.template.render('variables.css', self, **context)
data = self.template.render('variables.css', self.request, **context)
return Response.new(data, ctype = 'css')