From e4bcbdeccb7f1968c78b938e5922f9f99980fc76 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 14 Feb 2024 14:17:53 -0500 Subject: [PATCH] don't get a database connection at the start of every request --- relay/processors.py | 10 +- relay/views/activitypub.py | 43 +++---- relay/views/api.py | 249 +++++++++++++++++++------------------ relay/views/base.py | 8 +- relay/views/frontend.py | 7 +- relay/views/misc.py | 24 ++-- 6 files changed, 180 insertions(+), 161 deletions(-) diff --git a/relay/processors.py b/relay/processors.py index d9780d1..7fc3423 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -170,7 +170,7 @@ processors = { } -async def run_processor(view: ActorView, conn: Connection) -> None: +async def run_processor(view: ActorView) -> None: if view.message.type not in processors: logging.verbose( 'Message type "%s" from actor cannot be handled: %s', @@ -180,8 +180,8 @@ async def run_processor(view: ActorView, conn: Connection) -> None: return - if view.instance: - with conn.transaction(): + with view.database.connection(False) as conn: + if view.instance: if not view.instance['software']: if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): view.instance = conn.update_inbox( @@ -195,5 +195,5 @@ async def run_processor(view: ActorView, conn: Connection) -> None: actor = view.actor.id ) - logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) - await processors[view.message.type](view, conn) + logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) + await processors[view.message.type](view, conn) diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index be51047..70f759a 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -34,7 +34,7 @@ class ActorView(View): self.signer: Signer = None - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: data = Message.new_actor( host = self.config.domain, pubkey = self.app.signer.pubkey @@ -43,35 +43,36 @@ class ActorView(View): return Response.new(data, ctype='activity') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: if response := await self.get_post_data(): return response - self.instance = conn.get_inbox(self.actor.shared_inbox) - config = conn.get_config_all() + with self.database.connection(False) as conn: + self.instance = conn.get_inbox(self.actor.shared_inbox) + config = conn.get_config_all() - ## reject if the actor isn't whitelisted while the whiltelist is enabled - if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): - logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if the actor isn't whitelisted while the whiltelist is enabled + if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): + logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if actor is banned - if conn.get_domain_ban(self.actor.domain): - logging.verbose('Ignored request from banned actor: %s', self.actor.id) - return Response.new_error(403, 'access denied', 'json') + ## reject if actor is banned + if conn.get_domain_ban(self.actor.domain): + logging.verbose('Ignored request from banned actor: %s', self.actor.id) + return Response.new_error(403, 'access denied', 'json') - ## reject if activity type isn't 'Follow' and the actor isn't following - if self.message.type != 'Follow' and not self.instance: - logging.verbose( - 'Rejected actor for trying to post while not following: %s', - self.actor.id - ) + ## reject if activity type isn't 'Follow' and the actor isn't following + if self.message.type != 'Follow' and not self.instance: + logging.verbose( + 'Rejected actor for trying to post while not following: %s', + self.actor.id + ) - return Response.new_error(401, 'access denied', 'json') + return Response.new_error(401, 'access denied', 'json') logging.debug('>> payload %s', self.message.to_json(4)) - await run_processor(self, conn) + await run_processor(self) return Response.new(status = 202) @@ -162,7 +163,7 @@ class ActorView(View): @register_route('/.well-known/webfinger') class WebfingerView(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: try: subject = request.query['resource'] diff --git a/relay/views/api.py b/relay/views/api.py index 56b5ed7..ade1d3b 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -67,33 +67,33 @@ async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Respo @register_route('/api/v1/token') class Login(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: return Response.new({'message': 'Token valid :3'}) - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['username', 'password'], []) if isinstance(data, Response): return data - if not (user := conn.get_user(data['username'])): - return Response.new_error(401, 'User not found', 'json') + with self.database.connction(True) as conn: + if not (user := conn.get_user(data['username'])): + return Response.new_error(401, 'User not found', 'json') - try: - conn.hasher.verify(user['hash'], data['password']) + try: + conn.hasher.verify(user['hash'], data['password']) - except VerifyMismatchError: - return Response.new_error(401, 'Invalid password', 'json') + except VerifyMismatchError: + return Response.new_error(401, 'Invalid password', 'json') - with conn.transaction(): token = conn.put_token(data['username']) return Response.new({'token': token['code']}, ctype = 'json') - async def delete(self, request: Request, conn: Connection) -> Response: - with conn.transaction(): + async def delete(self, request: Request) -> Response: + with self.database.connection(True) as conn: conn.del_token(request['token']) return Response.new({'message': 'Token revoked'}, ctype = 'json') @@ -101,9 +101,10 @@ class Login(View): @register_route('/api/v1/relay') class RelayInfo(View): - async def get(self, request: Request, conn: Connection) -> Response: - config = conn.get_config_all() - inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + config = conn.get_config_all() + inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] data = { 'domain': self.config.domain, @@ -122,9 +123,10 @@ class RelayInfo(View): @register_route('/api/v1/config') class Config(View): - async def get(self, request: Request, conn: Connection) -> Response: - data = conn.get_config_all() - data['log-level'] = data['log-level'].name + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + data = conn.get_config_all() + data['log-level'] = data['log-level'].name for key in CONFIG_IGNORE: del data[key] @@ -132,7 +134,7 @@ class Config(View): return Response.new(data, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['key', 'value'], []) if isinstance(data, Response): @@ -141,13 +143,13 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - with conn.transaction(): + with self.database.connection(True) as conn: conn.put_config(data['key'], data['value']) return Response.new({'message': 'Updated config'}, ctype = 'json') - async def delete(self, request: Request, conn: Connection) -> Response: + async def delete(self, request: Request) -> Response: data = await self.get_api_data(['key'], []) if isinstance(data, Response): @@ -156,7 +158,7 @@ class Config(View): if data['key'] not in CONFIG_VALID: return Response.new_error(400, 'Invalid key', 'json') - with conn.transaction(): + with self.database.connection(True) as conn: conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) return Response.new({'message': 'Updated config'}, ctype = 'json') @@ -164,23 +166,24 @@ class Config(View): @register_route('/api/v1/instance') class Inbox(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: data = [] - for inbox in conn.execute('SELECT * FROM inboxes'): - try: - created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) + with self.database.connection(False) as conn: + for inbox in conn.execute('SELECT * FROM inboxes'): + try: + created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) - except TypeError: - created = datetime.fromisoformat(inbox['created']) + except TypeError: + created = datetime.fromisoformat(inbox['created']) - inbox['created'] = created.isoformat() - data.append(inbox) + inbox['created'] = created.isoformat() + data.append(inbox) return Response.new(data, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['actor'], ['inbox', 'software', 'followid']) if isinstance(data, Response): @@ -188,24 +191,24 @@ class Inbox(View): data['domain'] = urlparse(data["actor"]).netloc - if conn.get_inbox(data['domain']): - return Response.new_error(404, 'Instance already in database', 'json') + with self.database.connection(True) as conn: + if conn.get_inbox(data['domain']): + return Response.new_error(404, 'Instance already in database', 'json') - if not data.get('inbox'): - try: - actor_data = await self.client.get( - data['actor'], - sign_headers = True, - loads = Message.parse - ) + if not data.get('inbox'): + try: + actor_data = await self.client.get( + data['actor'], + sign_headers = True, + loads = Message.parse + ) - data['inbox'] = actor_data.shared_inbox + data['inbox'] = actor_data.shared_inbox - except Exception as e: - logging.error('Failed to fetch actor: %s', str(e)) - return Response.new_error(500, 'Failed to fetch actor', 'json') + except Exception as e: + logging.error('Failed to fetch actor: %s', str(e)) + return Response.new_error(500, 'Failed to fetch actor', 'json') - with conn.transaction(): row = conn.put_inbox(**data) return Response.new(row, ctype = 'json') @@ -213,37 +216,38 @@ class Inbox(View): @register_route('/api/v1/instance/{domain}') class InboxSingle(View): - async def get(self, request: Request, conn: Connection, domain: str) -> Response: - if not (row := conn.get_inbox(domain)): - return Response.new_error(404, 'Instance with domain not found', 'json') + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (row := conn.get_inbox(domain)): + return Response.new_error(404, 'Instance with domain not found', 'json') row['created'] = datetime.fromtimestamp(row['created'], tz = timezone.utc).isoformat() return Response.new(row, ctype = 'json') - async def patch(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_inbox(domain): - return Response.new_error(404, 'Instance with domain not found', 'json') + async def patch(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Instance with domain not found', 'json') - data = await self.get_api_data([], ['actor', 'software', 'followid']) + data = await self.get_api_data([], ['actor', 'software', 'followid']) - if isinstance(data, Response): - return data + if isinstance(data, Response): + return data - if not (instance := conn.get_inbox(domain)): - return Response.new_error(404, 'Instance with domain not found', 'json') + if not (instance := conn.get_inbox(domain)): + return Response.new_error(404, 'Instance with domain not found', 'json') - with conn.transaction(): instance = conn.update_inbox(instance['inbox'], **data) return Response.new(instance, ctype = 'json') - async def delete(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_inbox(domain): - return Response.new_error(404, 'Instance with domain not found', 'json') + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_inbox(domain): + return Response.new_error(404, 'Instance with domain not found', 'json') - with conn.transaction(): conn.del_inbox(domain) return Response.new({'message': 'Deleted instance'}, ctype = 'json') @@ -251,21 +255,23 @@ class InboxSingle(View): @register_route('/api/v1/domain_ban') class DomainBan(View): - async def get(self, request: Request, conn: Connection) -> Response: - bans = conn.execute('SELECT * FROM domain_bans').all() + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + bans = conn.execute('SELECT * FROM domain_bans').all() + return Response.new(bans, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['domain'], ['note', 'reason']) if isinstance(data, Response): return data - if conn.get_domain_ban(data['domain']): - return Response.new_error(400, 'Domain already banned', 'json') + with self.database.connection(True) as conn: + if conn.get_domain_ban(data['domain']): + return Response.new_error(400, 'Domain already banned', 'json') - with conn.transaction(): ban = conn.put_domain_ban(**data) return Response.new(ban, ctype = 'json') @@ -273,36 +279,37 @@ class DomainBan(View): @register_route('/api/v1/domain_ban/{domain}') class DomainBanSingle(View): - async def get(self, request: Request, conn: Connection, domain: str) -> Response: - if not (ban := conn.get_domain_ban(domain)): - return Response.new_error(404, 'Domain ban not found', 'json') + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (ban := conn.get_domain_ban(domain)): + return Response.new_error(404, 'Domain ban not found', 'json') return Response.new(ban, ctype = 'json') - async def patch(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_domain_ban(domain): - return Response.new_error(404, 'Domain not banned', 'json') + async def patch(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') - data = await self.get_api_data([], ['note', 'reason']) + data = await self.get_api_data([], ['note', 'reason']) - if isinstance(data, Response): - return data + if isinstance(data, Response): + return data - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - with conn.transaction(): ban = conn.update_domain_ban(domain, **data) return Response.new(ban, ctype = 'json') - async def delete(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_domain_ban(domain): - return Response.new_error(404, 'Domain not banned', 'json') + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_domain_ban(domain): + return Response.new_error(404, 'Domain not banned', 'json') - with conn.transaction(): conn.del_domain_ban(domain) return Response.new({'message': 'Unbanned domain'}, ctype = 'json') @@ -310,21 +317,23 @@ class DomainBanSingle(View): @register_route('/api/v1/software_ban') class SoftwareBan(View): - async def get(self, request: Request, conn: Connection) -> Response: - bans = conn.execute('SELECT * FROM software_bans').all() + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + bans = conn.execute('SELECT * FROM software_bans').all() + return Response.new(bans, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: + async def post(self, request: Request) -> Response: data = await self.get_api_data(['name'], ['note', 'reason']) if isinstance(data, Response): return data - if conn.get_software_ban(data['name']): - return Response.new_error(400, 'Domain already banned', 'json') + with self.database.connection(True) as conn: + if conn.get_software_ban(data['name']): + return Response.new_error(400, 'Domain already banned', 'json') - with conn.transaction(): ban = conn.put_software_ban(**data) return Response.new(ban, ctype = 'json') @@ -332,36 +341,37 @@ class SoftwareBan(View): @register_route('/api/v1/software_ban/{name}') class SoftwareBanSingle(View): - async def get(self, request: Request, conn: Connection, name: str) -> Response: - if not (ban := conn.get_software_ban(name)): - return Response.new_error(404, 'Software ban not found', 'json') + async def get(self, request: Request, name: str) -> Response: + with self.database.connection(False) as conn: + if not (ban := conn.get_software_ban(name)): + return Response.new_error(404, 'Software ban not found', 'json') return Response.new(ban, ctype = 'json') - async def patch(self, request: Request, conn: Connection, name: str) -> Response: - if not conn.get_software_ban(name): - return Response.new_error(404, 'Software not banned', 'json') + async def patch(self, request: Request, name: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_software_ban(name): + return Response.new_error(404, 'Software not banned', 'json') - data = await self.get_api_data([], ['note', 'reason']) + data = await self.get_api_data([], ['note', 'reason']) - if isinstance(data, Response): - return data + if isinstance(data, Response): + return data - if not any([data.get('note'), data.get('reason')]): - return Response.new_error(400, 'Must include note and/or reason parameters', 'json') + if not any([data.get('note'), data.get('reason')]): + return Response.new_error(400, 'Must include note and/or reason parameters', 'json') - with conn.transaction(): ban = conn.update_software_ban(name, **data) return Response.new(ban, ctype = 'json') - async def delete(self, request: Request, conn: Connection, name: str) -> Response: - if not conn.get_software_ban(name): - return Response.new_error(404, 'Software not banned', 'json') + async def delete(self, request: Request, name: str) -> Response: + with self.database.connection(True) as conn: + if not conn.get_software_ban(name): + return Response.new_error(404, 'Software not banned', 'json') - with conn.transaction(): conn.del_software_ban(name) return Response.new({'message': 'Unbanned software'}, ctype = 'json') @@ -369,21 +379,23 @@ class SoftwareBanSingle(View): @register_route('/api/v1/whitelist') class Whitelist(View): - async def get(self, request: Request, conn: Connection) -> Response: - items = conn.execute('SELECT * FROM whitelist').all() + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + items = conn.execute('SELECT * FROM whitelist').all() + return Response.new(items, ctype = 'json') - async def post(self, request: Request, conn: Connection) -> Response: - data = await self.get_api_data(['domain']) + async def post(self, request: Request) -> Response: + data = await self.get_api_data(['domain'], []) if isinstance(data, Response): return data - if conn.get_domain_whitelist(data['domain']): - return Response.new_error(400, 'Domain already added to whitelist', 'json') + with self.database.connection(True) as conn: + if conn.get_domain_whitelist(data['domain']): + return Response.new_error(400, 'Domain already added to whitelist', 'json') - with conn.transaction(): item = conn.put_domain_whitelist(**data) return Response.new(item, ctype = 'json') @@ -391,18 +403,19 @@ class Whitelist(View): @register_route('/api/v1/domain/{domain}') class WhitelistSingle(View): - async def get(self, request: Request, conn: Connection, domain: str) -> Response: - if not (item := conn.get_domain_whitelist(domain)): - return Response.new_error(404, 'Domain not in whitelist', 'json') + async def get(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not (item := conn.get_domain_whitelist(domain)): + return Response.new_error(404, 'Domain not in whitelist', 'json') return Response.new(item, ctype = 'json') - async def delete(self, request: Request, conn: Connection, domain: str) -> Response: - if not conn.get_domain_whitelist(domain): - return Response.new_error(404, 'Domain not in whitelist', 'json') + async def delete(self, request: Request, domain: str) -> Response: + with self.database.connection(False) as conn: + if not conn.get_domain_whitelist(domain): + return Response.new_error(404, 'Domain not in whitelist', 'json') - with conn.transaction(): conn.del_domain_whitelist(domain) return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py index 093fcb7..ce72e4b 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -43,8 +43,7 @@ class View(AbstractView): async def _run_handler(self, handler: Coroutine) -> Response: - with self.database.connection(False) as conn: - return await handler(self.request, conn, **self.request.match_info) + return await handler(self.request, **self.request.match_info) @cached_property @@ -92,7 +91,10 @@ class View(AbstractView): return self.app.database - async def get_api_data(self, required: list[str], optional: list[str]) -> dict[str, str] | Response: + async def get_api_data(self, + required: list[str], + optional: list[str]) -> dict[str, str] | Response: + if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: post_data = await self.request.post() diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 987b9b0..663edd4 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -48,9 +48,10 @@ HOME_TEMPLATE = """ @register_route('/') class HomeView(View): - async def get(self, request: Request, conn: Connection) -> Response: - config = conn.get_config_all() - inboxes = conn.execute('SELECT * FROM inboxes').all() + async def get(self, request: Request) -> Response: + with self.database.connection(False) as conn: + config = conn.get_config_all() + inboxes = conn.execute('SELECT * FROM inboxes').all() text = HOME_TEMPLATE.format( host = self.config.domain, diff --git a/relay/views/misc.py b/relay/views/misc.py index e41ae2b..7c2e65c 100644 --- a/relay/views/misc.py +++ b/relay/views/misc.py @@ -33,17 +33,18 @@ if Path(__file__).parent.parent.joinpath('.git').exists(): @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') class NodeinfoView(View): # pylint: disable=no-self-use - async def get(self, request: Request, conn: Connection, niversion: str) -> Response: - inboxes = conn.execute('SELECT * FROM inboxes').all() + async def get(self, request: Request, niversion: str) -> Response: + with self.database.connection(False) as conn: + inboxes = conn.execute('SELECT * FROM inboxes').all() - data = { - 'name': 'activityrelay', - 'version': VERSION, - 'protocols': ['activitypub'], - 'open_regs': not conn.get_config('whitelist-enabled'), - 'users': 1, - 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} - } + data = { + 'name': 'activityrelay', + 'version': VERSION, + 'protocols': ['activitypub'], + 'open_regs': not conn.get_config('whitelist-enabled'), + 'users': 1, + 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]} + } if niversion == '2.1': data['repo'] = 'https://git.pleroma.social/pleroma/relay' @@ -53,6 +54,7 @@ class NodeinfoView(View): @register_route('/.well-known/nodeinfo') class WellknownNodeinfoView(View): - async def get(self, request: Request, conn: Connection) -> Response: + async def get(self, request: Request) -> Response: data = WellKnownNodeinfo.new_template(self.config.domain) + return Response.new(data, ctype = 'json')