rework api endpoints

This commit is contained in:
Izalia Mae 2024-10-13 17:30:12 -04:00
parent b00daa5a78
commit a25df0ccc4
9 changed files with 783 additions and 632 deletions

142
relay/api_objects.py Normal file
View file

@ -0,0 +1,142 @@
from __future__ import annotations
from blib import Date, JsonBase
from bsql import Row
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any
from . import logger as logging
from .database import ConfigData
from .misc import utf_to_idna
if TYPE_CHECKING:
try:
from typing import Self
except ImportError:
from typing_extensions import Self
class ApiObject:
def __str__(self) -> str:
return self.to_json()
@classmethod
def from_row(cls: type[Self], row: Row, *exclude: str) -> Self:
return cls(**{k: v for k, v in row.items() if k not in exclude})
def to_dict(self, *exclude: str) -> dict[str, Any]:
return {k: v for k, v in asdict(self).items() if k not in exclude} # type: ignore[call-overload]
def to_json(self, *exclude: str, indent: int | str | None = None) -> str:
data = self.to_dict(*exclude)
return JsonBase(data).to_json(indent = indent)
@dataclass(slots = True)
class Message(ApiObject):
msg: str
@dataclass(slots = True)
class Application(ApiObject):
client_id: str
client_secret: str
name: str
website: str | None
redirect_uri: str
token: str | None
created: Date
updated: Date
@dataclass(slots = True)
class Config(ApiObject):
approval_required: bool
log_level: logging.LogLevel
name: str
note: str
theme: str
whitelist_enabled: bool
@classmethod
def from_config(cls: type[Self], cfg: ConfigData) -> Self:
return cls(
cfg.approval_required,
cfg.log_level,
cfg.name,
cfg.note,
cfg.theme,
cfg.whitelist_enabled
)
@dataclass(slots = True)
class ConfigItem(ApiObject):
key: str
value: Any
type: str
@dataclass(slots = True)
class DomainBan(ApiObject):
domain: str
reason: str | None
note: str | None
created: Date
@dataclass(slots = True)
class Instance(ApiObject):
domain: str
actor: str
inbox: str
followid: str
software: str
accepted: Date
created: Date
def __post_init__(self) -> None:
self.domain = utf_to_idna(self.domain)
self.actor = utf_to_idna(self.actor)
self.inbox = utf_to_idna(self.inbox)
self.followid = utf_to_idna(self.followid)
@dataclass(slots = True)
class Relay(ApiObject):
domain: str
name: str
description: str
version: str
whitelist_enabled: bool
email: str | None
admin: str | None
icon: str | None
instances: list[str]
@dataclass(slots = True)
class SoftwareBan(ApiObject):
name: str
reason: str | None
note: str | None
created: Date
@dataclass(slots = True)
class User(ApiObject):
username: str
handle: str | None
created: Date
@dataclass(slots = True)
class Whitelist(ApiObject):
domain: str
created: Date

View file

@ -29,8 +29,7 @@ from .database.schema import Instance
from .http_client import HttpClient
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response
from .template import Template
from .views import ROUTES, VIEWS
from .views.api import handle_api_path
from .views import ROUTES
from .views.frontend import handle_frontend_path
from .workers import PushWorkers
@ -59,8 +58,7 @@ class Application(web.Application):
web.Application.__init__(self,
middlewares = [
handle_response_headers, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item]
handle_api_path # type: ignore[list-item]
handle_frontend_path # type: ignore[list-item]
]
)
@ -84,9 +82,6 @@ class Application(web.Application):
self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore
for path, view in VIEWS:
self.router.add_view(path, view)
for method, path, handler in ROUTES:
self.router.add_route(method, path, handler)

View file

@ -116,8 +116,10 @@ class ConfigData:
@classmethod
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
parsed_key = key.replace('-', '_')
for field in fields(cls):
if field.name == key.replace('-', '_'):
if field.name == parsed_key:
return field
raise KeyError(key)

View file

@ -111,20 +111,23 @@ class Connection(SqlConnection):
def put_config(self, key: str, value: Any) -> Any:
field = ConfigData.FIELD(key)
key = field.name.replace('_', '-')
if key == 'private-key':
match field.name:
case "private_key":
self.app.signer = value
elif key == 'log-level':
case "log_level":
value = logging.LogLevel.parse(value)
logging.set_level(value)
self.app['workers'].set_log_level(value)
elif key in {'approval-required', 'whitelist-enabled'}:
case "approval_required":
value = convert_to_boolean(value)
elif key == 'theme':
case "whitelist_enabled":
value = convert_to_boolean(value)
case "theme":
if value not in THEMES:
raise ValueError(f'"{value}" is not a valid theme')

View file

@ -1,3 +1,7 @@
let a = `; ${document.cookie}`.match(";\\s*user-token=([^;]+)");
const token = a ? a[1] : null;
// toast notifications
const notifications = document.querySelector("#notifications")
@ -60,6 +64,7 @@ for (const elem of document.querySelectorAll("#menu-open div")) {
// misc
function get_date_string(date) {
var year = date.getUTCFullYear().toString();
var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
@ -94,9 +99,13 @@ async function request(method, path, body = null) {
"Accept": "application/json"
}
if (token !== null) {
headers["Authorization"] = `Bearer ${token}`;
}
if (body !== null) {
headers["Content-Type"] = "application/json"
body = JSON.stringify(body)
headers["Content-Type"] = "application/json";
body = JSON.stringify(body);
}
const response = await fetch("/api/" + path, {

View file

@ -8,7 +8,7 @@ import platform
from aiohttp.web import Request, Response as AiohttpResponse
from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, overload
from uuid import uuid4
if TYPE_CHECKING:
@ -85,6 +85,40 @@ def get_app() -> Application:
return Application.DEFAULT
@overload
def idna_to_utf(string: str) -> str:
...
@overload
def idna_to_utf(string: None) -> None:
...
def idna_to_utf(string: str | None) -> str | None:
if string is None:
return None
return string.encode("idna").decode("utf-8")
@overload
def utf_to_idna(string: str) -> str:
...
@overload
def utf_to_idna(string: None) -> None:
...
def utf_to_idna(string: str | None) -> str | None:
if string is None:
return None
return string.encode("utf-8").decode("idna")
class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str:
if isinstance(o, datetime):

View file

@ -1,4 +1,4 @@
from __future__ import annotations
from . import activitypub, api, frontend, misc
from .base import ROUTES, VIEWS, View
from .base import ROUTES

File diff suppressed because it is too large Load diff

View file

@ -1,47 +1,42 @@
from __future__ import annotations
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request
from aiohttp.web import Request, StreamResponse
from blib import HttpError, HttpMethod
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property
from collections.abc import Awaitable, Callable, Mapping
from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, overload
from ..cache import Cache
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
from ..api_objects import ApiObject
from ..misc import Response, get_app
if TYPE_CHECKING:
from typing import Self
from ..application import Application
from ..template import Template
try:
from typing import Self
except ImportError:
from typing_extensions import Self
ApiRouteHandler = Callable[..., Awaitable[ApiObject | list[Any] | StreamResponse]]
RouteHandler = Callable[[Application, Request], Awaitable[Response]]
HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = []
ROUTES: list[tuple[str, str, HandlerCallback]] = []
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
ALLOWED_HEADERS: set[str] = {
'accept',
'authorization',
'content-type'
}
def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
return {key: str(value) for key, value in data.items()}
def register_view(*paths: str) -> Callable[[type[View]], type[View]]:
def wrapper(view: type[View]) -> type[View]:
for path in paths:
VIEWS.append((path, view))
return view
return wrapper
def register_route(
method: HttpMethod | str, *paths: str) -> Callable[[RouteHandler], HandlerCallback]:
@ -56,108 +51,107 @@ def register_route(
return wrapper
class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS:
raise HttpError(405, f'"{self.request.method}" method not allowed')
class Route:
handler: ApiRouteHandler
if not (handler := self.handlers.get(self.request.method)):
raise HttpError(405, f'"{self.request.method}" method not allowed')
def __init__(self,
method: HttpMethod,
path: str,
category: str,
require_token: bool) -> None:
return self._run_handler(handler).__await__()
self.method: HttpMethod = HttpMethod.parse(method)
self.path: str = path
self.category: str = category
self.require_token: bool = require_token
ROUTES.append((self.method, self.path, self)) # type: ignore[arg-type]
@classmethod
async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Response:
view = cls(request)
return await view.handlers[method](request, **kwargs)
@overload
def __call__(self, obj: Request) -> Awaitable[StreamResponse]:
...
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
return await handler(self.request, **self.request.match_info, **kwargs)
@overload
def __call__(self, obj: ApiRouteHandler) -> Self:
...
async def options(self, request: Request) -> Response:
return Response.new()
def __call__(self, obj: Request | ApiRouteHandler) -> Self | Awaitable[StreamResponse]:
if isinstance(obj, Request):
return self.handle_request(obj)
self.handler = obj
return self
@cached_property
def allowed_methods(self) -> Sequence[str]:
return tuple(self.handlers.keys())
async def handle_request(self, request: Request) -> StreamResponse:
request["application"] = None
if request.method != "OPTIONS" and self.require_token:
if (auth := request.headers.getone("Authorization", None)) is None:
raise HttpError(401, 'Missing token')
@cached_property
def handlers(self) -> dict[str, HandlerCallback]:
data = {}
for method in METHODS:
try:
data[method] = getattr(self, method.lower())
authtype, code = auth.split(" ", 1)
except AttributeError:
continue
except IndexError:
raise HttpError(401, "Invalid authorization heder format")
return data
if authtype != "Bearer":
raise HttpError(401, f"Invalid authorization type: {authtype}")
if not code:
raise HttpError(401, "Missing token")
@property
def app(self) -> Application:
return get_app()
with get_app().database.session(False) as s:
if (application := s.get_app_by_token(code)) is None:
raise HttpError(401, "Invalid token")
if application.auth_code is not None:
raise HttpError(401, "Invalid token")
@property
def cache(self) -> Cache:
return self.app.cache
request["application"] = application
if request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = {key: value for key, value in (await request.post()).items()}
@property
def client(self) -> HttpClient:
return self.app.client
@property
def config(self) -> Config:
return self.app.config
@property
def database(self) -> Database[Connection]:
return self.app.database
@property
def template(self) -> Template:
return self.app['template'] # type: ignore[no-any-return]
async def get_api_data(self,
required: list[str],
optional: list[str]) -> dict[str, str]:
if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post())
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
elif self.request.content_type == 'application/json':
elif request.content_type == 'application/json':
try:
post_data = convert_data(await self.request.json())
post_data = await request.json()
except JSONDecodeError:
raise HttpError(400, 'Invalid JSON data')
else:
post_data = convert_data(self.request.query)
data = {}
post_data = {key: str(value) for key, value in request.query.items()}
try:
for key in required:
data[key] = post_data[key]
response = await self.handler(get_app(), request, **post_data)
except KeyError as e:
raise HttpError(400, f'Missing {str(e)} pararmeter') from None
except HttpError as error:
return Response.new({'error': error.message}, error.status, ctype = "json")
for key in optional:
data[key] = post_data.get(key) # type: ignore[assignment]
headers = {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": ", ".join(ALLOWED_HEADERS)
}
return data
if isinstance(response, StreamResponse):
response.headers.update(headers)
return response
if isinstance(response, ApiObject):
return Response.new(response.to_json(), headers = headers, ctype = "json")
if isinstance(response, list):
data = []
for item in response:
if isinstance(item, ApiObject):
data.append(item.to_dict())
response = data
return Response.new(response, headers = headers, ctype = "json")