diff --git a/relay/misc.py b/relay/misc.py index face943..25c0a1e 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -7,6 +7,7 @@ import typing from aiohttp.web import Response as AiohttpResponse from aputils.message import Message as ApMessage +from datetime import datetime from uuid import uuid4 if typing.TYPE_CHECKING: @@ -74,6 +75,14 @@ def get_app() -> Application: return Application.DEFAULT +class JsonEncoder(json.JSONEncoder): + def default(self, obj: Any) -> str: + if isinstance(obj, datetime): + return obj.isoformat() + + return JSONEncoder.default(self, obj) + + class Message(ApMessage): @classmethod def new_actor(cls: type[Message], # pylint: disable=arguments-differ @@ -193,8 +202,8 @@ class Response(AiohttpResponse): if isinstance(body, bytes): kwargs['body'] = body - elif isinstance(body, (dict, list, tuple, set)) and ctype in {'json', 'activity'}: - kwargs['text'] = json.dumps(body) + elif isinstance(body, (dict, list, tuple, set)) or ctype in {'json', 'activity'}: + kwargs['text'] = json.dumps(body, cls = JsonEncoder) else: kwargs['text'] = body @@ -209,7 +218,7 @@ class Response(AiohttpResponse): ctype: str = 'text') -> Response: if ctype == 'json': - body = json.dumps({'error': body}) + body = {'error': body} return cls.new(body=body, status=status, ctype=ctype) diff --git a/relay/views/api.py b/relay/views/api.py index 56a5a25..07a5c9a 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -166,12 +166,8 @@ class Config(View): @register_route('/api/v1/instance') class Inbox(View): async def get(self, request: Request) -> Response: - data = [] - with self.database.session() as conn: - for inbox in conn.execute('SELECT * FROM inboxes'): - inbox['created'] = inbox['created'].isoformat() - data.append(inbox) + data = tuple(conn.execute('SELECT * FROM inboxes').all()) return Response.new(data, ctype = 'json') @@ -241,7 +237,7 @@ class Inbox(View): class DomainBan(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - bans = conn.execute('SELECT * FROM domain_bans').all() + bans = tuple(conn.execute('SELECT * FROM domain_bans').all()) return Response.new(bans, ctype = 'json') @@ -298,7 +294,7 @@ class DomainBan(View): class SoftwareBan(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - bans = conn.execute('SELECT * FROM software_bans').all() + bans = tuple(conn.execute('SELECT * FROM software_bans').all()) return Response.new(bans, ctype = 'json') @@ -355,7 +351,7 @@ class SoftwareBan(View): class Whitelist(View): async def get(self, request: Request) -> Response: with self.database.session() as conn: - items = conn.execute('SELECT * FROM whitelist').all() + items = tuple(conn.execute('SELECT * FROM whitelist').all()) return Response.new(items, ctype = 'json')