mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-14 11:37:59 +00:00
rework api endpoints
This commit is contained in:
parent
b00daa5a78
commit
a25df0ccc4
142
relay/api_objects.py
Normal file
142
relay/api_objects.py
Normal file
|
@ -0,0 +1,142 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from blib import Date, JsonBase
|
||||
from bsql import Row
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from . import logger as logging
|
||||
from .database import ConfigData
|
||||
from .misc import utf_to_idna
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ApiObject:
|
||||
def __str__(self) -> str:
|
||||
return self.to_json()
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_row(cls: type[Self], row: Row, *exclude: str) -> Self:
|
||||
return cls(**{k: v for k, v in row.items() if k not in exclude})
|
||||
|
||||
|
||||
def to_dict(self, *exclude: str) -> dict[str, Any]:
|
||||
return {k: v for k, v in asdict(self).items() if k not in exclude} # type: ignore[call-overload]
|
||||
|
||||
|
||||
def to_json(self, *exclude: str, indent: int | str | None = None) -> str:
|
||||
data = self.to_dict(*exclude)
|
||||
return JsonBase(data).to_json(indent = indent)
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class Message(ApiObject):
|
||||
msg: str
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class Application(ApiObject):
|
||||
client_id: str
|
||||
client_secret: str
|
||||
name: str
|
||||
website: str | None
|
||||
redirect_uri: str
|
||||
token: str | None
|
||||
created: Date
|
||||
updated: Date
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class Config(ApiObject):
|
||||
approval_required: bool
|
||||
log_level: logging.LogLevel
|
||||
name: str
|
||||
note: str
|
||||
theme: str
|
||||
whitelist_enabled: bool
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls: type[Self], cfg: ConfigData) -> Self:
|
||||
return cls(
|
||||
cfg.approval_required,
|
||||
cfg.log_level,
|
||||
cfg.name,
|
||||
cfg.note,
|
||||
cfg.theme,
|
||||
cfg.whitelist_enabled
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class ConfigItem(ApiObject):
|
||||
key: str
|
||||
value: Any
|
||||
type: str
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class DomainBan(ApiObject):
|
||||
domain: str
|
||||
reason: str | None
|
||||
note: str | None
|
||||
created: Date
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class Instance(ApiObject):
|
||||
domain: str
|
||||
actor: str
|
||||
inbox: str
|
||||
followid: str
|
||||
software: str
|
||||
accepted: Date
|
||||
created: Date
|
||||
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.domain = utf_to_idna(self.domain)
|
||||
self.actor = utf_to_idna(self.actor)
|
||||
self.inbox = utf_to_idna(self.inbox)
|
||||
self.followid = utf_to_idna(self.followid)
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class Relay(ApiObject):
|
||||
domain: str
|
||||
name: str
|
||||
description: str
|
||||
version: str
|
||||
whitelist_enabled: bool
|
||||
email: str | None
|
||||
admin: str | None
|
||||
icon: str | None
|
||||
instances: list[str]
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class SoftwareBan(ApiObject):
|
||||
name: str
|
||||
reason: str | None
|
||||
note: str | None
|
||||
created: Date
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class User(ApiObject):
|
||||
username: str
|
||||
handle: str | None
|
||||
created: Date
|
||||
|
||||
|
||||
@dataclass(slots = True)
|
||||
class Whitelist(ApiObject):
|
||||
domain: str
|
||||
created: Date
|
|
@ -29,8 +29,7 @@ from .database.schema import Instance
|
|||
from .http_client import HttpClient
|
||||
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response
|
||||
from .template import Template
|
||||
from .views import ROUTES, VIEWS
|
||||
from .views.api import handle_api_path
|
||||
from .views import ROUTES
|
||||
from .views.frontend import handle_frontend_path
|
||||
from .workers import PushWorkers
|
||||
|
||||
|
@ -59,8 +58,7 @@ class Application(web.Application):
|
|||
web.Application.__init__(self,
|
||||
middlewares = [
|
||||
handle_response_headers, # type: ignore[list-item]
|
||||
handle_frontend_path, # type: ignore[list-item]
|
||||
handle_api_path # type: ignore[list-item]
|
||||
handle_frontend_path # type: ignore[list-item]
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -84,9 +82,6 @@ class Application(web.Application):
|
|||
self.cache.setup()
|
||||
self.on_cleanup.append(handle_cleanup) # type: ignore
|
||||
|
||||
for path, view in VIEWS:
|
||||
self.router.add_view(path, view)
|
||||
|
||||
for method, path, handler in ROUTES:
|
||||
self.router.add_route(method, path, handler)
|
||||
|
||||
|
|
|
@ -116,8 +116,10 @@ class ConfigData:
|
|||
|
||||
@classmethod
|
||||
def FIELD(cls: type[Self], key: str) -> Field[str | int | bool]:
|
||||
parsed_key = key.replace('-', '_')
|
||||
|
||||
for field in fields(cls):
|
||||
if field.name == key.replace('-', '_'):
|
||||
if field.name == parsed_key:
|
||||
return field
|
||||
|
||||
raise KeyError(key)
|
||||
|
|
|
@ -111,20 +111,23 @@ class Connection(SqlConnection):
|
|||
|
||||
def put_config(self, key: str, value: Any) -> Any:
|
||||
field = ConfigData.FIELD(key)
|
||||
key = field.name.replace('_', '-')
|
||||
|
||||
if key == 'private-key':
|
||||
match field.name:
|
||||
case "private_key":
|
||||
self.app.signer = value
|
||||
|
||||
elif key == 'log-level':
|
||||
case "log_level":
|
||||
value = logging.LogLevel.parse(value)
|
||||
logging.set_level(value)
|
||||
self.app['workers'].set_log_level(value)
|
||||
|
||||
elif key in {'approval-required', 'whitelist-enabled'}:
|
||||
case "approval_required":
|
||||
value = convert_to_boolean(value)
|
||||
|
||||
elif key == 'theme':
|
||||
case "whitelist_enabled":
|
||||
value = convert_to_boolean(value)
|
||||
|
||||
case "theme":
|
||||
if value not in THEMES:
|
||||
raise ValueError(f'"{value}" is not a valid theme')
|
||||
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
let a = `; ${document.cookie}`.match(";\\s*user-token=([^;]+)");
|
||||
const token = a ? a[1] : null;
|
||||
|
||||
|
||||
// toast notifications
|
||||
|
||||
const notifications = document.querySelector("#notifications")
|
||||
|
@ -60,6 +64,7 @@ for (const elem of document.querySelectorAll("#menu-open div")) {
|
|||
|
||||
// misc
|
||||
|
||||
|
||||
function get_date_string(date) {
|
||||
var year = date.getUTCFullYear().toString();
|
||||
var month = (date.getUTCMonth() + 1).toString().padStart(2, "0");
|
||||
|
@ -94,9 +99,13 @@ async function request(method, path, body = null) {
|
|||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
if (token !== null) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
if (body !== null) {
|
||||
headers["Content-Type"] = "application/json"
|
||||
body = JSON.stringify(body)
|
||||
headers["Content-Type"] = "application/json";
|
||||
body = JSON.stringify(body);
|
||||
}
|
||||
|
||||
const response = await fetch("/api/" + path, {
|
||||
|
|
|
@ -8,7 +8,7 @@ import platform
|
|||
from aiohttp.web import Request, Response as AiohttpResponse
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, overload
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -85,6 +85,40 @@ def get_app() -> Application:
|
|||
return Application.DEFAULT
|
||||
|
||||
|
||||
@overload
|
||||
def idna_to_utf(string: str) -> str:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def idna_to_utf(string: None) -> None:
|
||||
...
|
||||
|
||||
|
||||
def idna_to_utf(string: str | None) -> str | None:
|
||||
if string is None:
|
||||
return None
|
||||
|
||||
return string.encode("idna").decode("utf-8")
|
||||
|
||||
|
||||
@overload
|
||||
def utf_to_idna(string: str) -> str:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def utf_to_idna(string: None) -> None:
|
||||
...
|
||||
|
||||
|
||||
def utf_to_idna(string: str | None) -> str | None:
|
||||
if string is None:
|
||||
return None
|
||||
|
||||
return string.encode("utf-8").decode("idna")
|
||||
|
||||
|
||||
class JsonEncoder(json.JSONEncoder):
|
||||
def default(self, o: Any) -> str:
|
||||
if isinstance(o, datetime):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from . import activitypub, api, frontend, misc
|
||||
from .base import ROUTES, VIEWS, View
|
||||
from .base import ROUTES
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,47 +1,42 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from aiohttp.abc import AbstractView
|
||||
from aiohttp.hdrs import METH_ALL as METHODS
|
||||
from aiohttp.web import Request
|
||||
from aiohttp.web import Request, StreamResponse
|
||||
from blib import HttpError, HttpMethod
|
||||
from bsql import Database
|
||||
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
|
||||
from functools import cached_property
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, overload
|
||||
|
||||
from ..cache import Cache
|
||||
from ..config import Config
|
||||
from ..database import Connection
|
||||
from ..http_client import HttpClient
|
||||
from ..api_objects import ApiObject
|
||||
from ..misc import Response, get_app
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Self
|
||||
from ..application import Application
|
||||
from ..template import Template
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
ApiRouteHandler = Callable[..., Awaitable[ApiObject | list[Any] | StreamResponse]]
|
||||
RouteHandler = Callable[[Application, Request], Awaitable[Response]]
|
||||
HandlerCallback = Callable[[Request], Awaitable[Response]]
|
||||
|
||||
|
||||
VIEWS: list[tuple[str, type[View]]] = []
|
||||
ROUTES: list[tuple[str, str, HandlerCallback]] = []
|
||||
|
||||
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
|
||||
ALLOWED_HEADERS: set[str] = {
|
||||
'accept',
|
||||
'authorization',
|
||||
'content-type'
|
||||
}
|
||||
|
||||
|
||||
def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
|
||||
return {key: str(value) for key, value in data.items()}
|
||||
|
||||
|
||||
def register_view(*paths: str) -> Callable[[type[View]], type[View]]:
|
||||
def wrapper(view: type[View]) -> type[View]:
|
||||
for path in paths:
|
||||
VIEWS.append((path, view))
|
||||
|
||||
return view
|
||||
return wrapper
|
||||
|
||||
|
||||
def register_route(
|
||||
method: HttpMethod | str, *paths: str) -> Callable[[RouteHandler], HandlerCallback]:
|
||||
|
||||
|
@ -56,108 +51,107 @@ def register_route(
|
|||
return wrapper
|
||||
|
||||
|
||||
class View(AbstractView):
|
||||
def __await__(self) -> Generator[Any, None, Response]:
|
||||
if self.request.method not in METHODS:
|
||||
raise HttpError(405, f'"{self.request.method}" method not allowed')
|
||||
class Route:
|
||||
handler: ApiRouteHandler
|
||||
|
||||
if not (handler := self.handlers.get(self.request.method)):
|
||||
raise HttpError(405, f'"{self.request.method}" method not allowed')
|
||||
def __init__(self,
|
||||
method: HttpMethod,
|
||||
path: str,
|
||||
category: str,
|
||||
require_token: bool) -> None:
|
||||
|
||||
return self._run_handler(handler).__await__()
|
||||
self.method: HttpMethod = HttpMethod.parse(method)
|
||||
self.path: str = path
|
||||
self.category: str = category
|
||||
self.require_token: bool = require_token
|
||||
|
||||
ROUTES.append((self.method, self.path, self)) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@classmethod
|
||||
async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Response:
|
||||
view = cls(request)
|
||||
return await view.handlers[method](request, **kwargs)
|
||||
@overload
|
||||
def __call__(self, obj: Request) -> Awaitable[StreamResponse]:
|
||||
...
|
||||
|
||||
|
||||
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
|
||||
return await handler(self.request, **self.request.match_info, **kwargs)
|
||||
@overload
|
||||
def __call__(self, obj: ApiRouteHandler) -> Self:
|
||||
...
|
||||
|
||||
|
||||
async def options(self, request: Request) -> Response:
|
||||
return Response.new()
|
||||
def __call__(self, obj: Request | ApiRouteHandler) -> Self | Awaitable[StreamResponse]:
|
||||
if isinstance(obj, Request):
|
||||
return self.handle_request(obj)
|
||||
|
||||
self.handler = obj
|
||||
return self
|
||||
|
||||
|
||||
@cached_property
|
||||
def allowed_methods(self) -> Sequence[str]:
|
||||
return tuple(self.handlers.keys())
|
||||
async def handle_request(self, request: Request) -> StreamResponse:
|
||||
request["application"] = None
|
||||
|
||||
if request.method != "OPTIONS" and self.require_token:
|
||||
if (auth := request.headers.getone("Authorization", None)) is None:
|
||||
raise HttpError(401, 'Missing token')
|
||||
|
||||
@cached_property
|
||||
def handlers(self) -> dict[str, HandlerCallback]:
|
||||
data = {}
|
||||
|
||||
for method in METHODS:
|
||||
try:
|
||||
data[method] = getattr(self, method.lower())
|
||||
authtype, code = auth.split(" ", 1)
|
||||
|
||||
except AttributeError:
|
||||
continue
|
||||
except IndexError:
|
||||
raise HttpError(401, "Invalid authorization heder format")
|
||||
|
||||
return data
|
||||
if authtype != "Bearer":
|
||||
raise HttpError(401, f"Invalid authorization type: {authtype}")
|
||||
|
||||
if not code:
|
||||
raise HttpError(401, "Missing token")
|
||||
|
||||
@property
|
||||
def app(self) -> Application:
|
||||
return get_app()
|
||||
with get_app().database.session(False) as s:
|
||||
if (application := s.get_app_by_token(code)) is None:
|
||||
raise HttpError(401, "Invalid token")
|
||||
|
||||
if application.auth_code is not None:
|
||||
raise HttpError(401, "Invalid token")
|
||||
|
||||
@property
|
||||
def cache(self) -> Cache:
|
||||
return self.app.cache
|
||||
request["application"] = application
|
||||
|
||||
if request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
|
||||
post_data = {key: value for key, value in (await request.post()).items()}
|
||||
|
||||
@property
|
||||
def client(self) -> HttpClient:
|
||||
return self.app.client
|
||||
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
return self.app.config
|
||||
|
||||
|
||||
@property
|
||||
def database(self) -> Database[Connection]:
|
||||
return self.app.database
|
||||
|
||||
|
||||
@property
|
||||
def template(self) -> Template:
|
||||
return self.app['template'] # type: ignore[no-any-return]
|
||||
|
||||
|
||||
async def get_api_data(self,
|
||||
required: list[str],
|
||||
optional: list[str]) -> dict[str, str]:
|
||||
|
||||
if self.request.content_type in {'application/x-www-form-urlencoded', 'multipart/form-data'}:
|
||||
post_data = convert_data(await self.request.post())
|
||||
# post_data = {key: value for key, value in parse_qsl(await self.request.text())}
|
||||
|
||||
elif self.request.content_type == 'application/json':
|
||||
elif request.content_type == 'application/json':
|
||||
try:
|
||||
post_data = convert_data(await self.request.json())
|
||||
post_data = await request.json()
|
||||
|
||||
except JSONDecodeError:
|
||||
raise HttpError(400, 'Invalid JSON data')
|
||||
|
||||
else:
|
||||
post_data = convert_data(self.request.query)
|
||||
|
||||
data = {}
|
||||
post_data = {key: str(value) for key, value in request.query.items()}
|
||||
|
||||
try:
|
||||
for key in required:
|
||||
data[key] = post_data[key]
|
||||
response = await self.handler(get_app(), request, **post_data)
|
||||
|
||||
except KeyError as e:
|
||||
raise HttpError(400, f'Missing {str(e)} pararmeter') from None
|
||||
except HttpError as error:
|
||||
return Response.new({'error': error.message}, error.status, ctype = "json")
|
||||
|
||||
for key in optional:
|
||||
data[key] = post_data.get(key) # type: ignore[assignment]
|
||||
headers = {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Headers": ", ".join(ALLOWED_HEADERS)
|
||||
}
|
||||
|
||||
return data
|
||||
if isinstance(response, StreamResponse):
|
||||
response.headers.update(headers)
|
||||
return response
|
||||
|
||||
if isinstance(response, ApiObject):
|
||||
return Response.new(response.to_json(), headers = headers, ctype = "json")
|
||||
|
||||
if isinstance(response, list):
|
||||
data = []
|
||||
|
||||
for item in response:
|
||||
if isinstance(item, ApiObject):
|
||||
data.append(item.to_dict())
|
||||
|
||||
response = data
|
||||
|
||||
return Response.new(response, headers = headers, ctype = "json")
|
||||
|
|
Loading…
Reference in a new issue