Compare commits

..

No commits in common. "773922e2630ad355340f5497ac5a86f899817985" and "b22b5bbefaa1b6cf13deaeb65396b135dc3fb192" have entirely different histories.

19 changed files with 283 additions and 701 deletions

View file

@ -4,14 +4,11 @@ import asyncio
import multiprocessing
import signal
import time
import traceback
from Crypto.Random import get_random_bytes
from aiohttp import web
from aiohttp.web import HTTPException, StaticResource
from aiohttp.web import StaticResource
from aiohttp_swagger import setup_swagger
from aputils.signer import Signer
from base64 import b64encode
from bsql import Database
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
@ -26,8 +23,7 @@ from .config import Config
from .database import Connection, get_database
from .database.schema import Instance
from .http_client import HttpClient
from .misc import HttpError, Message, Response, check_open_port, get_resource
from .misc import JSON_PATHS, TOKEN_PATHS
from .misc import Message, Response, check_open_port, get_resource
from .template import Template
from .views import VIEWS
from .views.api import handle_api_path
@ -57,9 +53,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self,
middlewares = [
handle_response_headers, # type: ignore[list-item]
handle_api_path, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item]
handle_api_path # type: ignore[list-item]
handle_response_headers # type: ignore[list-item]
]
)
@ -286,76 +282,19 @@ class CacheCleanupThread(Thread):
self.running.clear()
def format_error(request: web.Request, error: HttpError) -> Response:
app: Application = request.app # type: ignore[assignment]
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.body}, error.status, ctype = 'json')
else:
body = app.template.render('page/error.haml', request, e = error)
return Response.new(body, error.status, ctype = 'html')
@web.middleware
async def handle_response_headers(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
request['token'] = None
request['user'] = None
app: Application = request.app # type: ignore[assignment]
if request.path == "/" or request.path.startswith(TOKEN_PATHS):
with app.database.session() as conn:
tokens = (
request.headers.get('Authorization', '').replace('Bearer', '').strip(),
request.cookies.get('user-token')
)
for token in tokens:
if not token:
continue
request['token'] = conn.get_app_by_token(token)
if request['token'] is not None:
request['user'] = conn.get_user(request['token'].user)
break
try:
resp = await handler(request)
except HttpError as e:
resp = format_error(request, e)
except HTTPException as ae:
if ae.status == 404:
try:
text = (ae.text or "").split(":")[1].strip()
except IndexError:
text = ae.text or ""
resp = format_error(request, HttpError(ae.status, text))
else:
raise
except Exception as e:
resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}'))
traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay'
# Still have to figure out how csp headers work
if resp.content_type == 'text/html':
if resp.content_type == 'text/html' and not request.path.startswith("/api"):
resp.headers['Content-Security-Policy'] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js', '.woff2')):
if not request.app['dev'] and request.path.endswith(('.css', '.js')):
# cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'

View file

@ -51,8 +51,8 @@ WHERE username = :value or handle = :value;
-- name: get-user-by-token
SELECT * FROM users
WHERE username = (
SELECT user FROM app
WHERE token = :token
SELECT user FROM tokens
WHERE code = :code
);
@ -67,28 +67,25 @@ DELETE FROM users
WHERE username = :value or handle = :value;
-- name: get-app
SELECT * FROM app
WHERE client_id = :id and client_secret = :secret;
-- name: get-token
SELECT * FROM tokens
WHERE code = :code;
-- name: get-app-with-token
SELECT * FROM app
WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: put-token
INSERT INTO tokens (code, user, created)
VALUES (:code, :user, :created)
RETURNING *;
-- name: get-app-by-token
SELECT * FROM apps
WHERE token = :token;
-- name: del-app
DELETE FROM apps
WHERE client_id = :id and client_secret = :secret;
-- name: del-token
DELETE FROM tokens
WHERE code = :code;
-- name: del-app-with-token
DELETE FROM apps
WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: del-token-user
DELETE FROM tokens
WHERE user = :username;
-- name: get-software-ban

View file

@ -11,8 +11,12 @@ from .. import logger as logging
from ..misc import boolean
if TYPE_CHECKING:
try:
from typing import Self
except ImportError:
from typing_extensions import Self
THEMES = {
'default': {
@ -73,7 +77,7 @@ CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
@dataclass()
class ConfigData:
schema_version: int = 20240625
schema_version: int = 20240310
private_key: str = ''
approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO
@ -111,11 +115,11 @@ class ConfigData:
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore[return-value]
return cls.FIELD(key.replace('-', '_')).default # type: ignore
@classmethod
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
def FIELD(cls: type[Self], key: str) -> Field[Any]:
for field in fields(cls):
if field.name == key.replace('-', '_'):
return field

View file

@ -1,14 +1,12 @@
from __future__ import annotations
import secrets
from argon2 import PasswordHasher
from blib import Date
from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from uuid import uuid4
from . import schema
from .config import (
@ -51,36 +49,6 @@ class Connection(SqlConnection):
yield instance
def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}
self.update('apps', data, client_id = app.client_id)
for item in self.select('cache'):
data = {'updated': Date.parse(item['updated']).timestamp()}
self.update('cache', data, id = item['id'])
for dban in self.select('domain_bans').all(schema.DomainBan):
data = {'created': dban.created.timestamp()}
self.update('domain_bans', data, domain = dban.domain)
for instance in self.select('inboxes').all(schema.Instance):
data = {'created': instance.created.timestamp()}
self.update('inboxes', data, domain = instance.domain)
for sban in self.select('software_bans').all(schema.SoftwareBan):
data = {'created': sban.created.timestamp()}
self.update('software_bans', data, name = sban.name)
for user in self.select('users').all(schema.User):
data = {'created': user.created.timestamp()}
self.update('users', data, username = user.username)
for wlist in self.select('whitelist').all(schema.Whitelist):
data = {'created': wlist.created.timestamp()}
self.update('whitelist', data, domain = wlist.domain)
def get_config(self, key: str) -> Any:
key = key.replace('_', '-')
@ -225,8 +193,8 @@ class Connection(SqlConnection):
return cur.one(schema.User)
def get_user_by_token(self, token: str) -> schema.User | None:
with self.run('get-user-by-token', {'token': token}) as cur:
def get_user_by_token(self, code: str) -> schema.User | None:
with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one(schema.User)
@ -281,114 +249,35 @@ class Connection(SqlConnection):
pass
def get_app(self,
client_id: str,
client_secret: str,
token: str | None = None) -> schema.App | None:
def get_token(self, code: str) -> schema.Token | None:
with self.run('get-token', {'code': code}) as cur:
return cur.one(schema.Token)
params = {
'id': client_id,
'secret': client_secret
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
if username is not None:
return self.select('tokens').all(schema.Token)
return self.select('tokens', username = username).all(schema.Token)
def put_token(self, username: str) -> schema.Token:
data = {
'code': uuid4().hex,
'user': username,
'created': datetime.now(tz = timezone.utc)
}
if token is not None:
command = 'get-app-with-token'
params['token'] = token
else:
command = 'get-app'
with self.run(command, params) as cur:
return cur.one(schema.App)
def get_app_by_token(self, token: str) -> schema.App | None:
with self.run('get-app-by-token', {'token': token}) as cur:
return cur.one(schema.App)
def put_app(self, name: str, redirect_uri: str, website: str | None = None) -> schema.App:
params = {
'name': name,
'redirect_uri': redirect_uri,
'website': website,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'created': Date.new_utc().timestamp(),
'accessed': Date.new_utc().timestamp()
}
with self.insert('apps', params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to insert app: {name}')
with self.run('put-token', data) as cur:
if (row := cur.one(schema.Token)) is None:
raise RuntimeError(f"Failed to insert token for user: {username}")
return row
def put_app_login(self, user: schema.User) -> schema.App:
params = {
'name': 'Web',
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob',
'website': None,
'user': user.username,
'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20),
'auth_code': None,
'token': secrets.token_hex(20),
'created': Date.new_utc().timestamp(),
'accessed': Date.new_utc().timestamp()
}
with self.insert('apps', params) as cur:
if (row := cur.one(schema.App)) is None:
raise RuntimeError(f'Failed to create app for "{user.username}"')
return row
def update_app(self, app: schema.App, user: schema.User | None, set_auth: bool) -> schema.App:
data: dict[str, str | None] = {}
if user is not None:
data['user'] = user.username
if set_auth:
data['auth_code'] = secrets.token_hex(20)
else:
data['token'] = secrets.token_hex(20)
data['auth_code'] = None
params = {
'client_id': app.client_id,
'client_secret': app.client_secret
}
with self.update('app', data, **params) as cur: # type: ignore[arg-type]
if (row := cur.one(schema.App)) is None:
raise RuntimeError('Failed to update row')
return row
def del_app(self, client_id: str, client_secret: str, token: str | None = None) -> bool:
params = {
'id': client_id,
'secret': client_secret
}
if token is not None:
command = 'del-app-with-token'
params['token'] = token
else:
command = 'del-app'
with self.run(command, params) as cur:
if cur.row_count > 1:
raise RuntimeError('More than 1 row was deleted')
return cur.row_count == 0
def del_token(self, code: str) -> None:
with self.run('del-token', {'code': code}):
pass
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:

View file

@ -1,14 +1,14 @@
from __future__ import annotations
from blib import Date
import typing
from bsql import Column, Row, Tables
from collections.abc import Callable
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from datetime import datetime
from .config import ConfigData
if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from .connection import Connection
@ -16,16 +16,6 @@ VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES = Tables()
def deserialize_timestamp(value: Any) -> Date:
try:
return Date.parse(value)
except ValueError:
pass
return Date.fromisoformat(value)
@TABLES.add_row
class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
@ -37,107 +27,62 @@ class Config(Row):
class Instance(Row):
table_name: str = 'inboxes'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accepted: Column[datetime] = Column('accepted', 'boolean')
created: Column[datetime] = Column('created', 'timestamp', nullable = False)
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class DomainBan(Row):
table_name: str = 'domain_bans'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class SoftwareBan(Row):
table_name: str = 'software_bans'
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class User(Row):
table_name: str = 'users'
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class App(Row):
table_name: str = 'apps'
class Token(Row):
table_name: str = 'tokens'
client_id: Column[str] = Column(
'client_id', 'text', primary_key = True, unique = True, nullable = False)
client_secret: Column[str] = Column('client_secret', 'text', nullable = False)
name: Column[str] = Column('name', 'text')
website: Column[str] = Column('website', 'text')
redirect_uri: Column[str] = Column('redirect_uri', 'text', nullable = False)
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
data = deepcopy(self)
data.pop('user')
data.pop('auth_code')
if not include_token:
data.pop('token')
return data
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
user: Column[str] = Column('user', 'text', nullable = False)
created: Column[datetime] = Column('created', 'timestamp')
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
@ -158,11 +103,5 @@ def migrate_20240206(conn: Connection) -> None:
@migration
def migrate_20240310(conn: Connection) -> None:
conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close()
conn.execute('UPDATE "inboxes" SET accepted = 1').close()
@migration
def migrate_20240625(conn: Connection) -> None:
conn.create_tables()
conn.execute('DROP TABLE tokens').close()
conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN")
conn.execute("UPDATE inboxes SET accepted = 1")

View file

@ -1,5 +1,5 @@
-macro menu_item(name, path)
-if request.path == path or (path != "/" and request.path.startswith(path))
-if view.request.path == path or (path != "/" and view.request.path.startswith(path))
%a.button(href="{{path}}" active="true") -> =name
-else
@ -10,12 +10,12 @@
%head
%title << {{config.name}}: {{page}}
%meta(charset="UTF-8")
%meta(name="ort" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{request['hash']}}")
%meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css?{{version}}" nonce="{{view.request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css?{{version}}" nonce="{{view.request['hash']}}")
%link(rel="stylesheet" type="text/css" href="/static/bootstrap-icons.css?{{version}}" nonce="{{view.request['hash']}}")
%link(rel="manifest" href="/manifest.json?{{version}}")
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{request['hash']}}" defer)
%script(type="application/javascript" src="/static/functions.js?{{version}}" nonce="{{view.request['hash']}}" defer)
-block head
%body
@ -26,7 +26,7 @@
{{menu_item("Home", "/")}}
-if request["user"]
-if view.request["user"]
{{menu_item("Instances", "/admin/instances")}}
{{menu_item("Whitelist", "/admin/whitelist")}}
{{menu_item("Domain Bans", "/admin/domain_bans")}}
@ -61,11 +61,11 @@
#footer.section
.col1
-if not request["user"]
-if not view.request["user"]
%a(href="/login") << Login
-else
=request["user"]["username"]
=view.request["user"]["username"]
(
%a(href="/logout") << Logout
)

View file

@ -1,31 +0,0 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization
-if application.website
#title << Application "<a href="{{application.website}}" target="_new">{{application.name}}</a>" wants full API access
-else
#title << Application "{{application.name}}" wants full API access
#buttons
.spacer
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="true")
%input.button(type="submit" value="Allow")
%form(action="/oauth/authorize" method="POST")
%input(type="hidden" name="client_id" value="{{application.client_id}}")
%input(type="hidden" name="client_secret" value="{{application.client_secret}}")
%input(type="hidden" name="redirect_uri" value="{{application.redirect_uri}}")
%input(type="hidden" name="response" value="false")
%input.button(type="submit" value="Deny")
.spacer

View file

@ -1,18 +0,0 @@
-extends "base.haml"
-set page="App Authorization"
-block content
%fieldset.section
%legend << App Authorization Code
-if application.website
%p
Copy the following code into
%a(href="{{application.website}}" target="_main") -> %code -> =application.name
-else
%p
Copy the following code info
%code -> =application.name
%pre#code -> =application.auth_code

View file

@ -1,7 +0,0 @@
-extends "base.haml"
-set page="Error"
-block content
.section.error
.title << HTTP Error {{e.status}}
.body -> =e.body

View file

@ -12,6 +12,4 @@
%label(for="password") << Password
%input(id="password" name="password" placeholder="Password" type="password")
%input#redir(type="hidden" name="redir" value="{{redir}}")
%input.submit(type="button" value="Login")

View file

@ -483,15 +483,13 @@ function page_instance() {
function page_login() {
const fields = {
username: document.querySelector("#username"),
password: document.querySelector("#password"),
redir: document.querySelector("#redir")
};
password: document.querySelector("#password")
}
async function login(event) {
const values = {
username: fields.username.value.trim(),
password: fields.password.value.trim(),
redir: fields.redir.value.trim()
password: fields.password.value.trim()
}
if (values.username === "" | values.password === "") {
@ -500,14 +498,14 @@ function page_login() {
}
try {
await request("POST", "v1/login", values);
await request("POST", "v1/token", values);
} catch (error) {
toast(error);
return;
}
document.location = values.redir;
document.location = "/";
}
@ -850,6 +848,9 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/instances")) {
page_instance();
} else if (location.pathname.startsWith("/admin/login")) {
page_login();
} else if (location.pathname.startsWith("/admin/software_bans")) {
page_software_ban();
@ -858,7 +859,4 @@ if (location.pathname.startsWith("/admin/config")) {
} else if (location.pathname.startsWith("/admin/whitelist")) {
page_whitelist();
} else if (location.pathname.startsWith("/login")) {
page_login();
}

View file

@ -338,44 +338,6 @@ textarea {
}
/* error */
#content.page-error {
text-align: center;
}
#content.page-error .title {
font-size: 24px;
font-weight: bold;
}
/* auth */
#content.page-app_authorization {
text-align: center;
}
#content.page-app_authorization #code {
background: var(--background);
border: 1px solid var(--border);
font-size: 18px;
margin: 0 auto;
width: max-content;
padding: 5px;
}
#content.page-app_authorization #title {
font-size: 24px;
}
#content.page-app_authorization #buttons {
display: grid;
grid-template-columns: auto max-content max-content auto;
grid-gap: var(--spacing);
justify-items: center;
margin: var(--spacing) 0;
}
@keyframes show_toast {
0% {
transform: translateX(100%);

View file

@ -212,21 +212,6 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
os._exit(0)
@cli.command('db-maintenance')
@click.option('--fix-timestamps', '-t', is_flag = True,
help = 'Make sure timestamps in the database are float values')
@click.pass_context
def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
'Perform maintenance tasks on the database'
if fix_timestamps:
with ctx.obj.database.session(True) as conn:
conn.fix_timestamps()
with ctx.obj.database.session(False) as conn:
with conn.execute("VACUUM"):
pass
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@ -254,18 +239,18 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save()
# fix: mypy complains about the types returned by click.progressbar when updating click to 8.1.7
with get_database(ctx.obj.config) as db:
with db.session(True) as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar(
with click.progressbar( # type: ignore
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay'
@ -284,7 +269,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software']
)
with click.progressbar(
with click.progressbar( # type: ignore
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
@ -296,7 +281,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None
)
with click.progressbar(
with click.progressbar( # type: ignore
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
@ -305,7 +290,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software:
conn.put_domain_ban(domain)
with click.progressbar(
with click.progressbar( # type: ignore
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0

View file

@ -62,27 +62,6 @@ SOFTWARE = (
'gotosocial'
)
JSON_PATHS: tuple[str, ...] = (
'/api/v1',
'/actor',
'/inbox',
'/outbox',
'/following',
'/followers',
'/.well-known',
'/nodeinfo',
'/oauth/token',
'/oauth/revoke'
)
TOKEN_PATHS: tuple[str, ...] = (
'/logout',
'/admin',
'/api',
'/oauth/authorize',
'/oauth/revoke'
)
def boolean(value: Any) -> bool:
if isinstance(value, str):
@ -134,17 +113,6 @@ def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path)
class HttpError(Exception):
def __init__(self,
status: int,
body: str) -> None:
self.body: str = body
self.status: int = status
Exception.__init__(self, f"HTTP Error {status}: {body}")
class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str:
if isinstance(o, datetime):
@ -274,9 +242,9 @@ class Response(AiohttpResponse):
@classmethod
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
def new_redir(cls: type[Self], path: str) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, status, {'Location': path}, ctype = 'html')
return cls.new(body, 302, {'Location': path})
@property

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import textwrap
from aiohttp.web import Request
from collections.abc import Callable
from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader
@ -14,6 +13,7 @@ from typing import TYPE_CHECKING, Any
from . import __version__
from .misc import get_resource
from .views.base import View
if TYPE_CHECKING:
from .application import Application
@ -43,12 +43,12 @@ class Template(Environment):
self.hamlish_mode = 'indented'
def render(self, path: str, request: Request, **context: Any) -> str:
def render(self, path: str, view: View | None = None, **context: Any) -> str:
with self.app.database.session(False) as conn:
config = conn.get_config_all()
new_context = {
'request': request,
'view': view,
'domain': self.app.config.domain,
'version': __version__,
'config': config,

View file

@ -7,7 +7,7 @@ from .base import View, register_route
from .. import logger as logging
from ..database import schema
from ..misc import HttpError, Message, Response
from ..misc import Message, Response
from ..processors import run_processor
@ -39,7 +39,8 @@ class ActorView(View):
async def post(self, request: Request) -> Response:
await self.get_post_data()
if response := await self.get_post_data():
return response
with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
@ -64,13 +65,13 @@ class ActorView(View):
return Response.new(status = 202)
async def get_post_data(self) -> None:
async def get_post_data(self) -> Response | None:
try:
self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
raise HttpError(400, 'missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
message: Message | None = await self.request.json(loads = Message.parse)
@ -78,17 +79,17 @@ class ActorView(View):
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
raise HttpError(400, 'failed to parse message')
return Response.new_error(400, 'failed to parse message', 'json')
if message is None:
logging.verbose('empty message')
raise HttpError(400, 'missing message')
return Response.new_error(400, 'missing message', 'json')
self.message = message
if 'actor' not in self.message:
logging.verbose('actor not in message')
raise HttpError(400, 'no actor in message')
return Response.new_error(400, 'no actor in message', 'json')
try:
self.actor = await self.client.get(self.signature.keyid, True, Message)
@ -97,24 +98,26 @@ class ActorView(View):
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
raise HttpError(400, 'failed to fetch actor')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
raise HttpError(400, 'actor missing public key')
return Response.new_error(400, 'actor missing public key', 'json')
try:
await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
raise HttpError(401, str(e))
return Response.new_error(401, str(e), 'json')
return None
@register_route('/outbox')

View file

@ -2,15 +2,15 @@ import traceback
from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError
from blib import convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from urllib.parse import urlparse
from .base import View, register_route
from .. import __version__
from ..database import ConfigData, schema
from ..misc import HttpError, Message, Response, boolean
from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -22,8 +22,6 @@ ALLOWED_HEADERS: set[str] = {
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'),
('POST', '/api/v1/app'),
('POST', '/api/v1/login'),
('POST', '/api/v1/token')
)
@ -39,181 +37,64 @@ def check_api_path(method: str, path: str) -> bool:
async def handle_api_path(
request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response:
try:
if (token := request.cookies.get('user-token')):
request['token'] = token
if not request.path.startswith('/api'):
return await handler(request)
else:
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
with get_app().database.session() as conn:
request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError):
request['token'] = None
request['user'] = None
if request.method != "OPTIONS" and check_api_path(request.method, request.path):
if request['token'] is None:
raise HttpError(401, 'Missing token')
if not request['token']:
return Response.new_error(401, 'Missing token', 'json')
if request['user'] is None:
raise HttpError(401, 'Invalid token')
if not request['user']:
return Response.new_error(401, 'Invalid token', 'json')
response = await handler(request)
if request.path.startswith('/api'):
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
return response
@register_route('/oauth/authorize')
class OauthAuthorize(View):
async def get(self, request: Request) -> Response:
data = await self.get_api_data(['response_type', 'client_id', 'redirect_uri'], [])
if data['response_type'] != 'code':
raise HttpError(400, 'Response type is not "code"')
with self.database.session(True) as conn:
with conn.select('app', client_id = data['client_id']) as cur:
if (app := cur.one(schema.App)) is None:
raise HttpError(404, 'Could not find app')
if app.token is not None or app.auth_code is not None:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
if data['redirect_uri'] != app.redirect_uri:
raise HttpError(400, 'redirect_uri does not match application')
context = {'application': app}
html = self.template.render('page/authorize_new.haml', self.request, **context)
return Response.new(html, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(
['client_id', 'client_secret', 'redirect_uri', 'response'], []
)
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
return Response.new_error(404, 'Could not find app', 'json')
if convert_to_boolean(data['response']):
if app.auth_code is None:
app = conn.update_app(app, request['user'], True)
if app.redirect_uri == DEFAULT_REDIRECT:
context = {'application': app}
html = self.template.render(
'page/authorize_show.haml', self.request, **context
)
return Response.new(html, ctype = 'html')
return Response.new_redir(f'{app.redirect_uri}?code={app.auth_code}')
if not conn.del_app(app.client_id, app.client_secret):
raise HttpError(404, 'App not found')
return Response.new_redir('/')
@register_route('/oauth/token')
class OauthToken(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(
['grant_type', 'code', 'client_id', 'client_secret', 'redirect_uri'], []
)
if data['grant_type'] != 'authorization_code':
raise HttpError(400, 'Invalid grant type')
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application not found')
if app.auth_code != data['code']:
raise HttpError(400, 'Invalid authentication code')
if app.redirect_uri != data['redirect_uri']:
raise HttpError(400, 'Invalid redirect uri')
app = conn.update_app(app, request['user'], False)
return Response.new(app.get_api_data(True), ctype = 'json')
@register_route('/oauth/revoke')
class OauthRevoke(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret', 'token'], [])
with self.database.session(True) as conn:
if (app := conn.get_app(**data)) is None:
raise HttpError(404, 'Could not find token')
if app.user != request['token'].username:
raise HttpError(403, 'Invalid token')
if not conn.del_app(**data):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/app')
class App(View):
async def get(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(False) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
raise HttpError(404, 'Application cannot be found')
return Response.new(app.get_api_data(), ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name', 'redirect_uri'], ['website'])
with self.database.session(True) as conn:
app = conn.put_app(
name = data['name'],
redirect_uri = data['redirect_uri'],
website = data.get('website')
)
return Response.new(app.get_api_data(), ctype = 'json')
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['client_id', 'client_secret'], [])
with self.database.session(True) as conn:
if not conn.del_app(data['client_id'], data['client_secret'], request['token'].code):
raise HttpError(400, 'Token not removed')
return Response.new({'msg': 'Token deleted'}, ctype = 'json')
@register_route('/api/v1/login')
@register_route('/api/v1/token')
class Login(View):
async def get(self, request: Request) -> Response:
return Response.new({'message': 'Token valid'}, ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], [])
if isinstance(data, Response):
return data
with self.database.session(True) as conn:
if not (user := conn.get_user(data['username'])):
raise HttpError(401, 'User not found')
return Response.new_error(401, 'User not found', 'json')
try:
conn.hasher.verify(user['hash'], data['password'])
except VerifyMismatchError:
raise HttpError(401, 'Invalid password')
return Response.new_error(401, 'Invalid password', 'json')
app = conn.put_app_login(user)
token = conn.put_token(data['username'])
resp = Response.new({'token': app.token}, ctype = 'json')
resp = Response.new({'token': token.code}, ctype = 'json')
resp.set_cookie(
'user-token',
app.token, # type: ignore[arg-type]
token.code,
max_age = 60 * 60 * 24 * 365,
domain = self.config.domain,
path = '/',
@ -225,6 +106,13 @@ class Login(View):
return resp
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
conn.del_token(request['token'])
return Response.new({'message': 'Token revoked'}, ctype = 'json')
@register_route('/api/v1/relay')
class RelayInfo(View):
async def get(self, request: Request) -> Response:
@ -267,10 +155,14 @@ class Config(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['key', 'value'], [])
if isinstance(data, Response):
return data
data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS():
raise HttpError(400, 'Invalid key')
return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn:
conn.put_config(data['key'], data['value'])
@ -281,8 +173,11 @@ class Config(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['key'], [])
if isinstance(data, Response):
return data
if data['key'] not in ConfigData.USER_KEYS():
raise HttpError(400, 'Invalid key')
return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn:
conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
@ -301,11 +196,15 @@ class Inbox(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn:
if conn.get_inbox(data['domain']) is not None:
raise HttpError(404, 'Instance already in database')
return Response.new_error(404, 'Instance already in database', 'json')
data['domain'] = data['domain'].encode('idna').decode()
@ -315,7 +214,7 @@ class Inbox(View):
except Exception:
traceback.print_exc()
raise HttpError(500, 'Failed to fetch actor')
return Response.new_error(500, 'Failed to fetch actor', 'json')
data['inbox'] = actor_data.shared_inbox
@ -341,10 +240,14 @@ class Inbox(View):
async def patch(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
if (instance := conn.get_inbox(data['domain'])) is None:
raise HttpError(404, 'Instance with domain not found')
return Response.new_error(404, 'Instance with domain not found', 'json')
instance = conn.put_inbox(
instance.domain,
@ -359,10 +262,14 @@ class Inbox(View):
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_inbox(data['domain']):
raise HttpError(404, 'Instance with domain not found')
return Response.new_error(404, 'Instance with domain not found', 'json')
conn.del_inbox(data['domain'])
@ -379,21 +286,26 @@ class RequestView(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain', 'accept'], [])
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
if isinstance(data, Response):
return data
data['accept'] = boolean(data['accept'])
data['domain'] = data['domain'].encode('idna').decode()
try:
with self.database.session(True) as conn:
instance = conn.put_request_response(data['domain'], boolean(data['accept']))
instance = conn.put_request_response(data['domain'], data['accept'])
except KeyError:
raise HttpError(404, 'Request not found')
return Response.new_error(404, 'Request not found', 'json')
message = Message.new_response(
host = self.config.domain,
actor = instance.actor,
followid = instance.followid,
accept = boolean(data['accept'])
accept = data['accept']
)
self.app.push_message(instance.inbox, message, instance)
@ -421,11 +333,15 @@ class DomainBan(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_ban(data['domain']) is not None:
raise HttpError(400, 'Domain already banned')
return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_domain_ban(
domain = data['domain'],
@ -440,13 +356,16 @@ class DomainBan(View):
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
return data
if not any([data.get('note'), data.get('reason')]):
raise HttpError(400, 'Must include note and/or reason parameters')
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None:
raise HttpError(404, 'Domain not banned')
return Response.new_error(404, 'Domain not banned', 'json')
ban = conn.update_domain_ban(
domain = data['domain'],
@ -460,10 +379,14 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None:
raise HttpError(404, 'Domain not banned')
return Response.new_error(404, 'Domain not banned', 'json')
conn.del_domain_ban(data['domain'])
@ -482,9 +405,12 @@ class SoftwareBan(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is not None:
raise HttpError(400, 'Domain already banned')
return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_software_ban(
name = data['name'],
@ -498,12 +424,15 @@ class SoftwareBan(View):
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response):
return data
if not any([data.get('note'), data.get('reason')]):
raise HttpError(400, 'Must include note and/or reason parameters')
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None:
raise HttpError(404, 'Software not banned')
return Response.new_error(404, 'Software not banned', 'json')
ban = conn.update_software_ban(
name = data['name'],
@ -517,9 +446,12 @@ class SoftwareBan(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['name'], [])
if isinstance(data, Response):
return data
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None:
raise HttpError(404, 'Software not banned')
return Response.new_error(404, 'Software not banned', 'json')
conn.del_software_ban(data['name'])
@ -542,9 +474,12 @@ class User(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], ['handle'])
if isinstance(data, Response):
return data
with self.database.session() as conn:
if conn.get_user(data['username']) is not None:
raise HttpError(404, 'User already exists')
return Response.new_error(404, 'User already exists', 'json')
user = conn.put_user(
username = data['username'],
@ -559,6 +494,9 @@ class User(View):
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', 'handle'])
if isinstance(data, Response):
return data
with self.database.session(True) as conn:
user = conn.put_user(
username = data['username'],
@ -573,9 +511,12 @@ class User(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['username'], [])
if isinstance(data, Response):
return data
with self.database.session(True) as conn:
if conn.get_user(data['username']) is None:
raise HttpError(404, 'User does not exist')
return Response.new_error(404, 'User does not exist', 'json')
conn.del_user(data['username'])
@ -594,11 +535,14 @@ class Whitelist(View):
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is not None:
raise HttpError(400, 'Domain already added to whitelist')
return Response.new_error(400, 'Domain already added to whitelist', 'json')
item = conn.put_domain_whitelist(domain)
@ -608,11 +552,14 @@ class Whitelist(View):
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
return data
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_whitelist(domain) is None:
raise HttpError(404, 'Domain not in whitelist')
return Response.new_error(404, 'Domain not in whitelist', 'json')
conn.del_domain_whitelist(domain)

View file

@ -1,8 +1,10 @@
from __future__ import annotations
from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request
from aiohttp.web import HTTPMethodNotAllowed, Request
from base64 import b64encode
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property
@ -13,7 +15,7 @@ from ..cache import Cache
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
from ..misc import HttpError, Response, get_app
from ..misc import Response, get_app
if TYPE_CHECKING:
from typing import Self
@ -41,10 +43,10 @@ def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS:
raise HttpError(405, f'"{self.request.method}" method not allowed')
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
if not (handler := self.handlers.get(self.request.method)):
raise HttpError(405, f'"{self.request.method}" method not allowed')
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
return self._run_handler(handler).__await__()
@ -56,6 +58,7 @@ class View(AbstractView):
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs)
@ -114,18 +117,17 @@ class View(AbstractView):
async def get_api_data(self,
required: list[str],
optional: list[str]) -> dict[str, str]:
optional: list[str]) -> dict[str, str] | Response:
if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post())
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json':
try:
post_data = convert_data(await self.request.json())
except JSONDecodeError:
raise HttpError(400, 'Invalid JSON data')
return Response.new_error(400, 'Invalid JSON data', 'json')
else:
post_data = convert_data(self.request.query)
@ -137,9 +139,9 @@ class View(AbstractView):
data[key] = post_data[key]
except KeyError as e:
raise HttpError(400, f'Missing {str(e)} pararmeter')
return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
for key in optional:
data[key] = post_data.get(key) # type: ignore[assignment]
data[key] = post_data.get(key, '')
return data

View file

@ -1,13 +1,18 @@
from aiohttp import web
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import unquote
from .base import View, register_route
from ..database import THEMES
from ..logger import LogLevel
from ..misc import TOKEN_PATHS, Response
from ..misc import Response, get_app
UNAUTH_ROUTES = {
'/',
'/login'
}
@web.middleware
@ -15,24 +20,27 @@ async def handle_frontend_path(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
if request['user'] is not None and request.path == '/login':
return Response.new_redir('/')
app = get_app()
if request.path.startswith(TOKEN_PATHS[:2]) and request['user'] is None:
if request.path == '/logout':
return Response.new_redir('/')
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
request['token'] = request.cookies.get('user-token')
request['user'] = None
response = Response.new_redir(f'/login?redir={request.path}')
if request['token']:
with app.database.session(False) as conn:
request['user'] = conn.get_user_by_token(request['token'])
if request['token'] is not None:
if request['user'] and request.path == '/login':
return Response.new('', 302, {'Location': '/'})
if not request['user'] and request.path.startswith('/admin'):
response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
response.del_cookie('user-token')
return response
response = await handler(request)
if not request.path.startswith('/api'):
if request['user'] is None and request['token'] is not None:
if not request.path.startswith('/api') and not request['user'] and request['token']:
response.del_cookie('user-token')
return response
@ -46,15 +54,14 @@ class HomeView(View):
'instances': tuple(conn.get_inboxes())
}
data = self.template.render('page/home.haml', self.request, **context)
data = self.template.render('page/home.haml', self, **context)
return Response.new(data, ctype='html')
@register_route('/login')
class Login(View):
async def get(self, request: web.Request) -> Response:
redir = unquote(request.query.get('redir', '/'))
data = self.template.render('page/login.haml', self.request, redir = redir)
data = self.template.render('page/login.haml', self)
return Response.new(data, ctype = 'html')
@ -62,7 +69,7 @@ class Login(View):
class Logout(View):
async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn:
conn.del_app(request['token'].client_id, request['token'].client_secret)
conn.del_token(request['token'])
resp = Response.new_redir('/')
resp.del_cookie('user-token', domain = self.config.domain, path = '/')
@ -72,7 +79,7 @@ class Logout(View):
@register_route('/admin')
class Admin(View):
async def get(self, request: web.Request) -> Response:
return Response.new_redir(f'/login?redir={request.path}', 301)
return Response.new('', 302, {'Location': '/admin/instances'})
@register_route('/admin/instances')
@ -94,7 +101,7 @@ class AdminInstances(View):
if message:
context['message'] = message
data = self.template.render('page/admin-instances.haml', self.request, **context)
data = self.template.render('page/admin-instances.haml', self, **context)
return Response.new(data, ctype = 'html')
@ -116,7 +123,7 @@ class AdminWhitelist(View):
if message:
context['message'] = message
data = self.template.render('page/admin-whitelist.haml', self.request, **context)
data = self.template.render('page/admin-whitelist.haml', self, **context)
return Response.new(data, ctype = 'html')
@ -138,7 +145,7 @@ class AdminDomainBans(View):
if message:
context['message'] = message
data = self.template.render('page/admin-domain_bans.haml', self.request, **context)
data = self.template.render('page/admin-domain_bans.haml', self, **context)
return Response.new(data, ctype = 'html')
@ -160,7 +167,7 @@ class AdminSoftwareBans(View):
if message:
context['message'] = message
data = self.template.render('page/admin-software_bans.haml', self.request, **context)
data = self.template.render('page/admin-software_bans.haml', self, **context)
return Response.new(data, ctype = 'html')
@ -182,7 +189,7 @@ class AdminUsers(View):
if message:
context['message'] = message
data = self.template.render('page/admin-users.haml', self.request, **context)
data = self.template.render('page/admin-users.haml', self, **context)
return Response.new(data, ctype = 'html')
@ -206,7 +213,7 @@ class AdminConfig(View):
}
}
data = self.template.render('page/admin-config.haml', self.request, **context)
data = self.template.render('page/admin-config.haml', self, **context)
return Response.new(data, ctype = 'html')
@ -244,5 +251,5 @@ class ThemeCss(View):
except KeyError:
return Response.new('Invalid theme', 404)
data = self.template.render('variables.css', self.request, **context)
data = self.template.render('variables.css', self, **context)
return Response.new(data, ctype = 'css')