changes to api and update tinysql

* rename `View.get_data` to `View.get_api_data`
* normally aquire database connection on request
* rename `/api/v1/inbox` to `/api/v1/instance`
* rework `POST /api/v1/instance`
* add `Connection.transaction` calls
* add `PATCH /api/v1/instance/{domain}`
This commit is contained in:
Izalia Mae 2024-02-10 16:14:44 -05:00
parent d10c864a00
commit bd4790212e
3 changed files with 101 additions and 40 deletions

View file

@ -5,12 +5,14 @@ import typing
from aiohttp import web from aiohttp import web
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from datetime import datetime, timezone from datetime import datetime, timezone
from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
from .. import __version__ from .. import __version__
from .. import logger as logging
from ..database.config import CONFIG_DEFAULTS from ..database.config import CONFIG_DEFAULTS
from ..misc import Response from ..misc import Message, Response
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
@ -27,13 +29,13 @@ CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE}
PUBLIC_API_PATHS: tuple[tuple[str, str]] = ( PUBLIC_API_PATHS: tuple[tuple[str, str]] = (
('GET', '/api/v1/relay'), ('GET', '/api/v1/relay'),
('GET', '/api/v1/instance'),
('POST', '/api/v1/token') ('POST', '/api/v1/token')
) )
def check_api_path(method: str, path: str) -> bool: def check_api_path(method: str, path: str) -> bool:
for m, p in PUBLIC_API_PATHS: if (method, path) in PUBLIC_API_PATHS:
if m == method and p == path:
return False return False
return path.startswith('/api') return path.startswith('/api')
@ -70,7 +72,7 @@ class Login(View):
async def post(self, request: Request, conn: Connection) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
data = await self.get_data(['username', 'password'], []) data = await self.get_api_data(['username', 'password'], [])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
@ -84,12 +86,16 @@ class Login(View):
except VerifyMismatchError: except VerifyMismatchError:
return Response.new_error(401, 'Invalid password', 'json') return Response.new_error(401, 'Invalid password', 'json')
with conn.transaction():
token = conn.put_token(data['username']) token = conn.put_token(data['username'])
return Response.new({'token': token['code']}, ctype = 'json') return Response.new({'token': token['code']}, ctype = 'json')
async def delete(self, request: Request, conn: Connection) -> Response: async def delete(self, request: Request, conn: Connection) -> Response:
with conn.transaction():
conn.del_token(request['token']) conn.del_token(request['token'])
return Response.new({'message': 'Token revoked'}, ctype = 'json') return Response.new({'message': 'Token revoked'}, ctype = 'json')
@ -127,7 +133,7 @@ class Config(View):
async def post(self, request: Request, conn: Connection) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
data = await self.get_data(['key', 'value'], []) data = await self.get_api_data(['key', 'value'], [])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
@ -135,12 +141,14 @@ class Config(View):
if data['key'] not in CONFIG_VALID: if data['key'] not in CONFIG_VALID:
return Response.new_error(400, 'Invalid key', 'json') return Response.new_error(400, 'Invalid key', 'json')
with conn.transaction():
conn.put_config(data['key'], data['value']) conn.put_config(data['key'], data['value'])
return Response.new({'message': 'Updated config'}, ctype = 'json') return Response.new({'message': 'Updated config'}, ctype = 'json')
async def delete(self, request: Request, conn: Connection) -> Response: async def delete(self, request: Request, conn: Connection) -> Response:
data = await self.get_data(['key'], []) data = await self.get_api_data(['key'], [])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
@ -148,11 +156,13 @@ class Config(View):
if data['key'] not in CONFIG_VALID: if data['key'] not in CONFIG_VALID:
return Response.new_error(400, 'Invalid key', 'json') return Response.new_error(400, 'Invalid key', 'json')
with conn.transaction():
conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1])
return Response.new({'message': 'Updated config'}, ctype = 'json') return Response.new({'message': 'Updated config'}, ctype = 'json')
@register_route('/api/v1/inbox') @register_route('/api/v1/instance')
class Inbox(View): class Inbox(View):
async def get(self, request: Request, conn: Connection) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
data = [] data = []
@ -171,34 +181,72 @@ class Inbox(View):
async def post(self, request: Request, conn: Connection) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
data = await self.get_data(['domain', 'inbox', 'actor'], ['software']) data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid'])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = urlparse(data["actor"]).netloc
if conn.get_inbox(data['domain']): if conn.get_inbox(data['domain']):
return Response.new_error(404, 'Inbox already in database', 'json') return Response.new_error(404, 'Instance already in database', 'json')
if not data.get('inbox'):
try:
actor_data = await self.client.get(
data['actor'],
sign_headers = True,
loads = Message.parse
)
data['inbox'] = actor_data.shared_inbox
except Exception as e:
logging.error('Failed to fetch actor: %s', str(e))
return Response.new_error(500, 'Failed to fetch actor', 'json')
with conn.transaction():
row = conn.put_inbox(**data) row = conn.put_inbox(**data)
return Response.new(row.to_json(), ctype = 'json')
return Response.new(row, ctype = 'json')
@register_route('/api/v1/inbox/{domain}') @register_route('/api/v1/instance/{domain}')
class InboxSingle(View): class InboxSingle(View):
async def get(self, request: Request, conn: Connection, domain: str) -> Response: async def get(self, request: Request, conn: Connection, domain: str) -> Response:
if not (row := conn.get_inbox(domain)): if not (row := conn.get_inbox(domain)):
return Response.new_error(404, 'Inbox with domain not found', 'json') return Response.new_error(404, 'Instance with domain not found', 'json')
row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat() row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat()
return Response.new(row, ctype = 'json') return Response.new(row, ctype = 'json')
async def patch(self, request: Request, conn: Connection, domain: str) -> Response:
if not conn.get_inbox(domain):
return Response.new_error(404, 'Instance with domain not found', 'json')
data = await self.get_api_data([], ['actor', 'software', 'followid'])
if isinstance(data, Response):
return data
if not (instance := conn.get_inbox(domain)):
return Response.new_error(404, 'Instance with domain not found', 'json')
with conn.transaction():
instance = conn.update_inbox(instance['inbox'], **data)
return Response.new(instance, ctype = 'json')
async def delete(self, request: Request, conn: Connection, domain: str) -> Response: async def delete(self, request: Request, conn: Connection, domain: str) -> Response:
if not conn.get_inbox(domain): if not conn.get_inbox(domain):
return Response.new_error(404, 'Inbox with domain not found', 'json') return Response.new_error(404, 'Instance with domain not found', 'json')
with conn.transaction():
conn.del_inbox(domain) conn.del_inbox(domain)
return Response.new({'message': 'Deleted inbox'}, ctype = 'json')
return Response.new({'message': 'Deleted instance'}, ctype = 'json')
@register_route('/api/v1/domain_ban') @register_route('/api/v1/domain_ban')
@ -209,7 +257,7 @@ class DomainBan(View):
async def post(self, request: Request, conn: Connection) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
data = await self.get_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
@ -217,7 +265,9 @@ class DomainBan(View):
if conn.get_domain_ban(data['domain']): if conn.get_domain_ban(data['domain']):
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
with conn.transaction():
ban = conn.put_domain_ban(**data) ban = conn.put_domain_ban(**data)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -234,12 +284,14 @@ class DomainBanSingle(View):
if not conn.get_domain_ban(domain): if not conn.get_domain_ban(domain):
return Response.new_error(404, 'Domain not banned', 'json') return Response.new_error(404, 'Domain not banned', 'json')
data = await self.get_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with conn.transaction():
ban = conn.update_domain_ban(**data) ban = conn.update_domain_ban(**data)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -247,7 +299,9 @@ class DomainBanSingle(View):
if not conn.get_domain_ban(domain): if not conn.get_domain_ban(domain):
return Response.new_error(404, 'Domain not banned', 'json') return Response.new_error(404, 'Domain not banned', 'json')
with conn.transaction():
conn.del_domain_ban(domain) conn.del_domain_ban(domain)
return Response.new({'message': 'Unbanned domain'}, ctype = 'json') return Response.new({'message': 'Unbanned domain'}, ctype = 'json')
@ -259,7 +313,7 @@ class SoftwareBan(View):
async def post(self, request: Request, conn: Connection) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
data = await self.get_data(['name'], ['note', 'reason']) data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
@ -267,7 +321,9 @@ class SoftwareBan(View):
if conn.get_software_ban(data['name']): if conn.get_software_ban(data['name']):
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
with conn.transaction():
ban = conn.put_software_ban(**data) ban = conn.put_software_ban(**data)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -284,12 +340,14 @@ class SoftwareBanSingle(View):
if not conn.get_software_ban(name): if not conn.get_software_ban(name):
return Response.new_error(404, 'Software not banned', 'json') return Response.new_error(404, 'Software not banned', 'json')
data = await self.get_data(['name'], ['note', 'reason']) data = await self.get_api_data(['name'], ['note', 'reason'])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with conn.transaction():
ban = conn.update_software_ban(**data) ban = conn.update_software_ban(**data)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -297,7 +355,9 @@ class SoftwareBanSingle(View):
if not conn.get_software_ban(domain): if not conn.get_software_ban(domain):
return Response.new_error(404, 'Software not banned', 'json') return Response.new_error(404, 'Software not banned', 'json')
with conn.transaction():
conn.del_software_ban(domain) conn.del_software_ban(domain)
return Response.new({'message': 'Unbanned software'}, ctype = 'json') return Response.new({'message': 'Unbanned software'}, ctype = 'json')
@ -309,7 +369,7 @@ class Whitelist(View):
async def post(self, request: Request, conn: Connection) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
data = await self.get_data(['domain']) data = await self.get_api_data(['domain'])
if isinstance(data, Response): if isinstance(data, Response):
return data return data
@ -317,7 +377,9 @@ class Whitelist(View):
if conn.get_domain_whitelist(data['domain']): if conn.get_domain_whitelist(data['domain']):
return Response.new_error(400, 'Domain already added to whitelist', 'json') return Response.new_error(400, 'Domain already added to whitelist', 'json')
with conn.transaction():
item = conn.put_domain_whitelist(**data) item = conn.put_domain_whitelist(**data)
return Response.new(item, ctype = 'json') return Response.new(item, ctype = 'json')
@ -334,5 +396,7 @@ class WhitelistSingle(View):
if not conn.get_domain_whitelist(domain): if not conn.get_domain_whitelist(domain):
return Response.new_error(404, 'Domain not in whitelist', 'json') return Response.new_error(404, 'Domain not in whitelist', 'json')
with conn.transaction():
conn.del_domain_whitelist(domain) conn.del_domain_whitelist(domain)
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')

View file

@ -43,10 +43,7 @@ class View(AbstractView):
async def _run_handler(self, handler: Coroutine) -> Response: async def _run_handler(self, handler: Coroutine) -> Response:
with self.database.config.connection_class(self.database) as conn: with self.database.connection(False) as conn:
# todo: remove on next tinysql release
conn.open()
return await handler(self.request, conn, **self.request.match_info) return await handler(self.request, conn, **self.request.match_info)
@ -95,7 +92,7 @@ class View(AbstractView):
return self.app.database return self.app.database
async def get_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: async def get_api_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response:
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
post_data = await self.request.post() post_data = await self.request.post()

View file

@ -6,6 +6,6 @@ gunicorn==21.1.0
hiredis==2.3.2 hiredis==2.3.2
pyyaml>=6.0 pyyaml>=6.0
redis==5.0.1 redis==5.0.1
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/f8db814084dded0a46bd3a9576e09fca860f2166.tar.gz
importlib_resources==6.1.1;python_version<'3.9' importlib_resources==6.1.1;python_version<'3.9'