From bd4790212e5b5e120f3a0323c34182961fc35caa Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Sat, 10 Feb 2024 16:14:44 -0500 Subject: [PATCH] changes to api and update tinysql * rename `View.get_data` to `View.get_api_data` * normally aquire database connection on request * rename `/api/v1/inbox` to `/api/v1/instance` * rework `POST /api/v1/instance` * add `Connection.transaction` calls * add `PATCH /api/v1/instance/{domain}` --- relay/views/api.py | 132 ++++++++++++++++++++++++++++++++------------ relay/views/base.py | 7 +-- requirements.txt | 2 +- 3 files changed, 101 insertions(+), 40 deletions(-) diff --git a/relay/views/api.py b/relay/views/api.py index f75caf4..2fcbc6c 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -5,12 +5,14 @@ import typing from aiohttp import web from argon2.exceptions import VerifyMismatchError from datetime import datetime, timezone +from urllib.parse import urlparse from .base import View, register_route from .. import __version__ +from .. import logger as logging from ..database.config import CONFIG_DEFAULTS -from ..misc import Response +from ..misc import Message, Response if typing.TYPE_CHECKING: from aiohttp.web import Request @@ -27,14 +29,14 @@ CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE} PUBLIC_API_PATHS: tuple[tuple[str, str]] = ( ('GET', '/api/v1/relay'), + ('GET', '/api/v1/instance'), ('POST', '/api/v1/token') ) def check_api_path(method: str, path: str) -> bool: - for m, p in PUBLIC_API_PATHS: - if m == method and p == path: - return False + if (method, path) in PUBLIC_API_PATHS: + return False return path.startswith('/api') @@ -70,7 +72,7 @@ class Login(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['username', 'password'], []) + data = await self.get_api_data(['username', 'password'], []) if isinstance(data, Response): return data @@ -84,12 +86,16 @@ class Login(View): except VerifyMismatchError: return Response.new_error(401, 'Invalid password', 'json') - token = conn.put_token(data['username']) + with conn.transaction(): + token = conn.put_token(data['username']) + return Response.new({'token': token['code']}, ctype = 'json') async def delete(self, request: Request, conn: Connection) -> Response: - conn.del_token(request['token']) + with conn.transaction(): + conn.del_token(request['token']) + return Response.new({'message': 'Token revoked'}, ctype = 'json') @@ -127,7 +133,7 @@ class Config(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['key', 'value'], []) + data = await self.get_api_data(['key', 'value'], []) if isinstance(data, Response): return data @@ -135,12 +141,14 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - conn.put_config(data['key'], data['value']) + with conn.transaction(): + conn.put_config(data['key'], data['value']) + return Response.new({'message': 'Updated config'}, ctype = 'json') async def delete(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['key'], []) + data = await self.get_api_data(['key'], []) if isinstance(data, Response): return data @@ -148,11 +156,13 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) + with conn.transaction(): + conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) + return Response.new({'message': 'Updated config'}, ctype = 'json') -@register_route('/api/v1/inbox') +@register_route('/api/v1/instance') class Inbox(View): async def get(self, request: Request, conn: Connection) -> Response: data = [] @@ -171,34 +181,72 @@ class Inbox(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['domain', 'inbox', 'actor'], ['software']) + data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) if isinstance(data, Response): return data + data['domain'] = urlparse(data["actor"]).netloc + if conn.get_inbox(data['domain']): - return Response.new_error(404, 'Inbox already in database', 'json') + return Response.new_error(404, 'Instance already in database', 'json') - row = conn.put_inbox(**data) - return Response.new(row.to_json(), ctype = 'json') + if not data.get('inbox'): + try: + actor_data = await self.client.get( + data['actor'], + sign_headers = True, + loads = Message.parse + ) + + data['inbox'] = actor_data.shared_inbox + + except Exception as e: + logging.error('Failed to fetch actor: %s', str(e)) + return Response.new_error(500, 'Failed to fetch actor', 'json') + + with conn.transaction(): + row = conn.put_inbox(**data) + + return Response.new(row, ctype = 'json') -@register_route('/api/v1/inbox/{domain}') +@register_route('/api/v1/instance/{domain}') class InboxSingle(View): async def get(self, request: Request, conn: Connection, domain: str) -> Response: if not (row := conn.get_inbox(domain)): - return Response.new_error(404, 'Inbox with domain not found', 'json') + return Response.new_error(404, 'Instance with domain not found', 'json') row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat() return Response.new(row, ctype = 'json') + async def patch(self, request: Request, conn: Connection, domain: str) -> Response: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Instance with domain not found', 'json') + + data = await self.get_api_data([], ['actor', 'software', 'followid']) + + if isinstance(data, Response): + return data + + if not (instance := conn.get_inbox(domain)): + return Response.new_error(404, 'Instance with domain not found', 'json') + + with conn.transaction(): + instance = conn.update_inbox(instance['inbox'], **data) + + return Response.new(instance, ctype = 'json') + + async def delete(self, request: Request, conn: Connection, domain: str) -> Response: if not conn.get_inbox(domain): - return Response.new_error(404, 'Inbox with domain not found', 'json') + return Response.new_error(404, 'Instance with domain not found', 'json') - conn.del_inbox(domain) - return Response.new({'message': 'Deleted inbox'}, ctype = 'json') + with conn.transaction(): + conn.del_inbox(domain) + + return Response.new({'message': 'Deleted instance'}, ctype = 'json') @register_route('/api/v1/domain_ban') @@ -209,7 +257,7 @@ class DomainBan(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['domain'], ['note', 'reason']) + data = await self.get_api_data(['domain'], ['note', 'reason']) if isinstance(data, Response): return data @@ -217,7 +265,9 @@ class DomainBan(View): if conn.get_domain_ban(data['domain']): return Response.new_error(400, 'Domain already banned', 'json') - ban = conn.put_domain_ban(**data) + with conn.transaction(): + ban = conn.put_domain_ban(**data) + return Response.new(ban, ctype = 'json') @@ -234,12 +284,14 @@ class DomainBanSingle(View): if not conn.get_domain_ban(domain): return Response.new_error(404, 'Domain not banned', 'json') - data = await self.get_data(['domain'], ['note', 'reason']) + data = await self.get_api_data(['domain'], ['note', 'reason']) if isinstance(data, Response): return data - ban = conn.update_domain_ban(**data) + with conn.transaction(): + ban = conn.update_domain_ban(**data) + return Response.new(ban, ctype = 'json') @@ -247,7 +299,9 @@ class DomainBanSingle(View): if not conn.get_domain_ban(domain): return Response.new_error(404, 'Domain not banned', 'json') - conn.del_domain_ban(domain) + with conn.transaction(): + conn.del_domain_ban(domain) + return Response.new({'message': 'Unbanned domain'}, ctype = 'json') @@ -259,7 +313,7 @@ class SoftwareBan(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['name'], ['note', 'reason']) + data = await self.get_api_data(['name'], ['note', 'reason']) if isinstance(data, Response): return data @@ -267,7 +321,9 @@ class SoftwareBan(View): if conn.get_software_ban(data['name']): return Response.new_error(400, 'Domain already banned', 'json') - ban = conn.put_software_ban(**data) + with conn.transaction(): + ban = conn.put_software_ban(**data) + return Response.new(ban, ctype = 'json') @@ -284,12 +340,14 @@ class SoftwareBanSingle(View): if not conn.get_software_ban(name): return Response.new_error(404, 'Software not banned', 'json') - data = await self.get_data(['name'], ['note', 'reason']) + data = await self.get_api_data(['name'], ['note', 'reason']) if isinstance(data, Response): return data - ban = conn.update_software_ban(**data) + with conn.transaction(): + ban = conn.update_software_ban(**data) + return Response.new(ban, ctype = 'json') @@ -297,7 +355,9 @@ class SoftwareBanSingle(View): if not conn.get_software_ban(domain): return Response.new_error(404, 'Software not banned', 'json') - conn.del_software_ban(domain) + with conn.transaction(): + conn.del_software_ban(domain) + return Response.new({'message': 'Unbanned software'}, ctype = 'json') @@ -309,7 +369,7 @@ class Whitelist(View): async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_data(['domain']) + data = await self.get_api_data(['domain']) if isinstance(data, Response): return data @@ -317,7 +377,9 @@ class Whitelist(View): if conn.get_domain_whitelist(data['domain']): return Response.new_error(400, 'Domain already added to whitelist', 'json') - item = conn.put_domain_whitelist(**data) + with conn.transaction(): + item = conn.put_domain_whitelist(**data) + return Response.new(item, ctype = 'json') @@ -334,5 +396,7 @@ class WhitelistSingle(View): if not conn.get_domain_whitelist(domain): return Response.new_error(404, 'Domain not in whitelist', 'json') - conn.del_domain_whitelist(domain) + with conn.transaction(): + conn.del_domain_whitelist(domain) + return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py index 53d6fd5..093fcb7 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -43,10 +43,7 @@ class View(AbstractView): async def _run_handler(self, handler: Coroutine) -> Response: - with self.database.config.connection_class(self.database) as conn: - # todo: remove on next tinysql release - conn.open() - + with self.database.connection(False) as conn: return await handler(self.request, conn, **self.request.match_info) @@ -95,7 +92,7 @@ class View(AbstractView): return self.app.database - async def get_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: + async def get_api_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: post_data = await self.request.post() diff --git a/requirements.txt b/requirements.txt index e7c73fa..16bd9dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,6 @@ gunicorn==21.1.0 hiredis==2.3.2 pyyaml>=6.0 redis==5.0.1 -tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz +tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/f8db814084dded0a46bd3a9576e09fca860f2166.tar.gz importlib_resources==6.1.1;python_version<'3.9'