rework api endpoints

This commit is contained in:
Izalia Mae 2024-10-13 17:30:12 -04:00
parent b00daa5a78
commit a25df0ccc4
9 changed files with 783 additions and 632 deletions

142
relay/api_objects.py Normal file
View file

@ -0,0 +1,142 @@
from __future__ import annotations
from blib import Date, JsonBase
from bsql import Row
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any
from . import logger as logging
from .database import ConfigData
from .misc import utf_to_idna
if TYPE_CHECKING:
try:
from typing import Self
except ImportError:
from typing_extensions import Self
class ApiObject:
def __str__(self) -> str:
return self.to_json()
@classmethod
def from_row(cls: type[Self], row: Row, *exclude: str) -> Self:
return cls(**{k: v for k, v in row.items() if k not in exclude})
def to_dict(self, *exclude: str) -> dict[str, Any]:
return {k: v for k, v in asdict(self).items() if k not in exclude} # type: ignore[call-overload]
def to_json(self, *exclude: str, indent: int | str | None = None) -> str:
data = self.to_dict(*exclude)
return JsonBase(data).to_json(indent = indent)
@dataclass(slots = True)
class Message(ApiObject):
msg: str
@dataclass(slots = True)
class Application(ApiObject):
client_id: str
client_secret: str
name: str
website: str | None
redirect_uri: str
token: str | None
created: Date
updated: Date
@dataclass(slots = True)
class Config(ApiObject):
approval_required: bool
log_level: logging.LogLevel
name: str
note: str
theme: str
whitelist_enabled: bool
@classmethod
def from_config(cls: type[Self], cfg: ConfigData) -> Self:
return cls(
cfg.approval_required,
cfg.log_level,
cfg.name,
cfg.note,
cfg.theme,
cfg.whitelist_enabled
)
@dataclass(slots = True)
class ConfigItem(ApiObject):
key: str
value: Any
type: str
@dataclass(slots = True)
class DomainBan(ApiObject):
domain: str
reason: str | None
note: str | None
created: Date
@dataclass(slots = True)
class Instance(ApiObject):
domain: str
actor: str
inbox: str
followid: str
software: str
accepted: Date
created: Date
def __post_init__(self) -> None:
self.domain = utf_to_idna(self.domain)
self.actor = utf_to_idna(self.actor)
self.inbox = utf_to_idna(self.inbox)
self.followid = utf_to_idna(self.followid)
@dataclass(slots = True)
class Relay(ApiObject):
domain: str
name: str
description: str
version: str
whitelist_enabled: bool
email: str | None
admin: str | None
icon: str | None
instances: list[str]
@dataclass(slots = True)
class SoftwareBan(ApiObject):
name: str
reason: str | None
note: str | None
created: Date
@dataclass(slots = True)
class User(ApiObject):
username: str
handle: str | None
created: Date
@dataclass(slots = True)
class Whitelist(ApiObject):
domain: str
created: Date

View file

@ -29,8 +29,7 @@ from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response
from .template import Template from .template import Template
from .views import ROUTES, VIEWS from .views import ROUTES
from .views.api import handle_api_path
from .views.frontend import handle_frontend_path from .views.frontend import handle_frontend_path
from .workers import PushWorkers from .workers import PushWorkers
@ -59,8 +58,7 @@ class Application(web.Application):
web.Application.__init__(self, web.Application.__init__(self,
middlewares = [ middlewares = [
handle_response_headers, # type: ignore[list-item] handle_response_headers, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item] handle_frontend_path # type: ignore[list-item]
handle_api_path # type: ignore[list-item]
] ]
) )
@ -84,9 +82,6 @@ class Application(web.Application):
self.cache.setup() self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore self.on_cleanup.append(handle_cleanup) # type: ignore
for path, view in VIEWS:
self.router.add_view(path, view)
for method, path, handler in ROUTES: for method, path, handler in ROUTES:
self.router.add_route(method, path, handler) self.router.add_route(method, path, handler)

View file

@ -116,8 +116,10 @@ class ConfigData:
@classmethod @classmethod
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]: def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
parsed_key = key.replace('-', '_')
for field in fields(cls): for field in fields(cls):
if field.name == key.replace('-', '_'): if field.name == parsed_key:
return field return field
raise KeyError(key) raise KeyError(key)

View file

@ -111,20 +111,23 @@ class Connection(SqlConnection):
def put_config(self, key: str, value: Any) -> Any: def put_config(self, key: str, value: Any) -> Any:
field = ConfigData.FIELD(key) field = ConfigData.FIELD(key)
key = field.name.replace('_', '-')
if key == 'private-key': match field.name:
case "private_key":
self.app.signer = value self.app.signer = value
elif key == 'log-level': case "log_level":
value = logging.LogLevel.parse(value) value = logging.LogLevel.parse(value)
logging.set_level(value) logging.set_level(value)
self.app['workers'].set_log_level(value) self.app['workers'].set_log_level(value)
elif key in {'approval-required', 'whitelist-enabled'}: case "approval_required":
value = convert_to_boolean(value) value = convert_to_boolean(value)
elif key == 'theme': case "whitelist_enabled":
value = convert_to_boolean(value)
case "theme":
if value not in THEMES: if value not in THEMES:
raise ValueError(f'"{value}" is not a valid theme') raise ValueError(f'"{value}" is not a valid theme')

View file

@ -1,3 +1,7 @@
let a = `; ${document.cookie}`.match(";\\s*user-token=([^;]+)");
const token = a ? a[1] : null;
// toast notifications // toast notifications
const notifications = document.querySelector("#notifications") const notifications = document.querySelector("#notifications")
@ -60,6 +64,7 @@ for (const elem of document.querySelectorAll("#menu-open div")) {
// misc // misc
function get_date_string(date) { function get_date_string(date) {
var year = date.getUTCFullYear().toString(); var year = date.getUTCFullYear().toString();
var month = (date.getUTCMonth() + 1).toString().padStart(2, "0"); var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
@ -94,9 +99,13 @@ async function request(method, path, body = null) {
"Accept": "application/json" "Accept": "application/json"
} }
if (token !== null) {
headers["Authorization"] = `Bearer ${token}`;
}
if (body !== null) { if (body !== null) {
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json";
body = JSON.stringify(body) body = JSON.stringify(body);
} }
const response = await fetch("/api/" + path, { const response = await fetch("/api/" + path, {

View file

@ -8,7 +8,7 @@ import platform
from aiohttp.web import Request, Response as AiohttpResponse from aiohttp.web import Request, Response as AiohttpResponse
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, overload
from uuid import uuid4 from uuid import uuid4
if TYPE_CHECKING: if TYPE_CHECKING:
@ -85,6 +85,40 @@ def get_app() -> Application:
return Application.DEFAULT return Application.DEFAULT
@overload
def idna_to_utf(string: str) -> str:
...
@overload
def idna_to_utf(string: None) -> None:
...
def idna_to_utf(string: str | None) -> str | None:
if string is None:
return None
return string.encode("idna").decode("utf-8")
@overload
def utf_to_idna(string: str) -> str:
...
@overload
def utf_to_idna(string: None) -> None:
...
def utf_to_idna(string: str | None) -> str | None:
if string is None:
return None
return string.encode("utf-8").decode("idna")
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str: def default(self, o: Any) -> str:
if isinstance(o, datetime): if isinstance(o, datetime):

View file

@ -1,4 +1,4 @@
from __future__ import annotations from __future__ import annotations
from . import activitypub, api, frontend, misc from . import activitypub, api, frontend, misc
from .base import ROUTES, VIEWS, View from .base import ROUTES

File diff suppressed because it is too large Load diff

View file

@ -1,47 +1,42 @@
from __future__ import annotations from __future__ import annotations
from aiohttp.abc import AbstractView from aiohttp.web import Request, StreamResponse
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request
from blib import HttpError, HttpMethod from blib import HttpError, HttpMethod
from bsql import Database from collections.abc import Awaitable, Callable, Mapping
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, overload
from ..cache import Cache from ..api_objects import ApiObject
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
from ..misc import Response, get_app from ..misc import Response, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self
from ..application import Application from ..application import Application
from ..template import Template
try:
from typing import Self
except ImportError:
from typing_extensions import Self
ApiRouteHandler = Callable[..., Awaitable[ApiObject | list[Any] | StreamResponse]]
RouteHandler = Callable[[Application, Request], Awaitable[Response]] RouteHandler = Callable[[Application, Request], Awaitable[Response]]
HandlerCallback = Callable[[Request], Awaitable[Response]] HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = []
ROUTES: list[tuple[str, str, HandlerCallback]] = [] ROUTES: list[tuple[str, str, HandlerCallback]] = []
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
ALLOWED_HEADERS: set[str] = {
'accept',
'authorization',
'content-type'
}
def convert_data(data: Mapping[str, Any]) -> dict[str, str]: def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
return {key: str(value) for key, value in data.items()} return {key: str(value) for key, value in data.items()}
def register_view(*paths: str) -> Callable[[type[View]], type[View]]:
def wrapper(view: type[View]) -> type[View]:
for path in paths:
VIEWS.append((path, view))
return view
return wrapper
def register_route( def register_route(
method: HttpMethod | str, *paths: str) -> Callable[[RouteHandler], HandlerCallback]: method: HttpMethod | str, *paths: str) -> Callable[[RouteHandler], HandlerCallback]:
@ -56,108 +51,107 @@ def register_route(
return wrapper return wrapper
class View(AbstractView): class Route:
def __await__(self) -> Generator[Any, None, Response]: handler: ApiRouteHandler
if self.request.method not in METHODS:
raise HttpError(405, f'"{self.request.method}" method not allowed')
if not (handler := self.handlers.get(self.request.method)): def __init__(self,
raise HttpError(405, f'"{self.request.method}" method not allowed') method: HttpMethod,
path: str,
category: str,
require_token: bool) -> None:
return self._run_handler(handler).__await__() self.method: HttpMethod = HttpMethod.parse(method)
self.path: str = path
self.category: str = category
self.require_token: bool = require_token
ROUTES.append((self.method, self.path, self)) # type: ignore[arg-type]
@classmethod @overload
async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Response: def __call__(self, obj: Request) -> Awaitable[StreamResponse]:
view = cls(request) ...
return await view.handlers[method](request, **kwargs)
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: @overload
return await handler(self.request, **self.request.match_info, **kwargs) def __call__(self, obj: ApiRouteHandler) -> Self:
...
async def options(self, request: Request) -> Response: def __call__(self, obj: Request | ApiRouteHandler) -> Self | Awaitable[StreamResponse]:
return Response.new() if isinstance(obj, Request):
return self.handle_request(obj)
self.handler = obj
return self
@cached_property async def handle_request(self, request: Request) -> StreamResponse:
def allowed_methods(self) -> Sequence[str]: request["application"] = None
return tuple(self.handlers.keys())
if request.method != "OPTIONS" and self.require_token:
if (auth := request.headers.getone("Authorization", None)) is None:
raise HttpError(401, 'Missing token')
@cached_property
def handlers(self) -> dict[str, HandlerCallback]:
data = {}
for method in METHODS:
try: try:
data[method] = getattr(self, method.lower()) authtype, code = auth.split(" ", 1)
except AttributeError: except IndexError:
continue raise HttpError(401, "Invalid authorization heder format")
return data if authtype != "Bearer":
raise HttpError(401, f"Invalid authorization type: {authtype}")
if not code:
raise HttpError(401, "Missing token")
@property with get_app().database.session(False) as s:
def app(self) -> Application: if (application := s.get_app_by_token(code)) is None:
return get_app() raise HttpError(401, "Invalid token")
if application.auth_code is not None:
raise HttpError(401, "Invalid token")
@property request["application"] = application
def cache(self) -> Cache:
return self.app.cache
if request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = {key: value for key, value in (await request.post()).items()}
@property elif request.content_type == 'application/json':
def client(self) -> HttpClient:
return self.app.client
@property
def config(self) -> Config:
return self.app.config
@property
def database(self) -> Database[Connection]:
return self.app.database
@property
def template(self) -> Template:
return self.app['template'] # type: ignore[no-any-return]
async def get_api_data(self,
required: list[str],
optional: list[str]) -> dict[str, str]:
if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post())
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json':
try: try:
post_data = convert_data(await self.request.json()) post_data = await request.json()
except JSONDecodeError: except JSONDecodeError:
raise HttpError(400, 'Invalid JSON data') raise HttpError(400, 'Invalid JSON data')
else: else:
post_data = convert_data(self.request.query) post_data = {key: str(value) for key, value in request.query.items()}
data = {}
try: try:
for key in required: response = await self.handler(get_app(), request, **post_data)
data[key] = post_data[key]
except KeyError as e: except HttpError as error:
raise HttpError(400, f'Missing {str(e)} pararmeter') from None return Response.new({'error': error.message}, error.status, ctype = "json")
for key in optional: headers = {
data[key] = post_data.get(key) # type: ignore[assignment] "Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": ", ".join(ALLOWED_HEADERS)
}
return data if isinstance(response, StreamResponse):
response.headers.update(headers)
return response
if isinstance(response, ApiObject):
return Response.new(response.to_json(), headers = headers, ctype = "json")
if isinstance(response, list):
data = []
for item in response:
if isinstance(item, ApiObject):
data.append(item.to_dict())
response = data
return Response.new(response, headers = headers, ctype = "json")