From c508257981fa562be40e276f5f6f57f72588b75b Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 18 Jun 2024 23:14:21 -0400 Subject: [PATCH] raise exceptions instead of returning None from HttpClient methods --- relay/application.py | 12 ++++ relay/http_client.py | 119 +++++++++++++++---------------------- relay/views/activitypub.py | 7 +-- relay/views/api.py | 16 +++-- 4 files changed, 75 insertions(+), 79 deletions(-) diff --git a/relay/application.py b/relay/application.py index b12c64f..6c8c1e7 100644 --- a/relay/application.py +++ b/relay/application.py @@ -7,9 +7,11 @@ import time import traceback from aiohttp import web +from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.web import StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer +from asyncio.exceptions import TimeoutError as AsyncTimeoutError from bsql import Database, Row from collections.abc import Awaitable, Callable from datetime import datetime, timedelta @@ -18,6 +20,7 @@ from pathlib import Path from queue import Empty from threading import Event, Thread from typing import Any +from urllib.parse import urlparse from . import logger as logging from .cache import Cache, get_cache @@ -331,6 +334,15 @@ class PushWorker(multiprocessing.Process): except Empty: await asyncio.sleep(0) + except ClientSSLError as e: + logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e)) + + except (AsyncTimeoutError, ClientConnectionError) as e: + logging.error( + 'Failed to connect to %s for message push: %s', + urlparse(inbox).netloc, str(e) + ) + # make sure an exception doesn't bring down the worker except Exception: traceback.print_exc() diff --git a/relay/http_client.py b/relay/http_client.py index 54cea3c..610b8a9 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -1,17 +1,12 @@ from __future__ import annotations import json -import traceback from aiohttp import ClientSession, ClientTimeout, TCPConnector -from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo -from asyncio.exceptions import TimeoutError as AsyncTimeoutError from blib import JsonBase from bsql import Row -from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING, Any, TypeVar -from urllib.parse import urlparse +from typing import TYPE_CHECKING, Any, TypeVar, overload from . import __version__, logger as logging from .cache import Cache @@ -107,7 +102,7 @@ class HttpClient: url: str, sign_headers: bool, force: bool, - old_algo: bool) -> dict[str, Any] | None: + old_algo: bool) -> str | None: if not self._session: raise RuntimeError('Client not open') @@ -121,7 +116,7 @@ class HttpClient: if not force: try: if not (item := self.cache.get('request', url)).older_than(48): - return json.loads(item.value) # type: ignore[no-any-return] + return item.value # type: ignore [no-any-return] except KeyError: logging.verbose('No cached data for url: %s', url) @@ -132,59 +127,61 @@ class HttpClient: algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019 headers = self.signer.sign_headers('GET', url, algorithm = algo) - try: - logging.debug('Fetching resource: %s', url) + logging.debug('Fetching resource: %s', url) - async with self._session.get(url, headers = headers) as resp: - # Not expecting a response with 202s, so just return - if resp.status == 202: - return None - - data = await resp.text() - - if resp.status != 200: - logging.verbose('Received error when requesting %s: %i', url, resp.status) - logging.debug(data) + async with self._session.get(url, headers = headers) as resp: + # Not expecting a response with 202s, so just return + if resp.status == 202: return None - self.cache.set('request', url, data, 'str') - logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4)) + data = await resp.text() - return json.loads(data) # type: ignore [no-any-return] - - except JSONDecodeError: - logging.verbose('Failed to parse JSON') + if resp.status != 200: + logging.verbose('Received error when requesting %s: %i', url, resp.status) logging.debug(data) return None - except ClientSSLError as e: - logging.verbose('SSL error when connecting to %s', urlparse(url).netloc) - logging.warning(str(e)) + self.cache.set('request', url, data, 'str') + return data - except (AsyncTimeoutError, ClientConnectionError) as e: - logging.verbose('Failed to connect to %s', urlparse(url).netloc) - logging.warning(str(e)) - except Exception: - traceback.print_exc() + @overload + async def get(self, # type: ignore[overload-overlap] + url: str, + sign_headers: bool, + cls: None = None, + force: bool = False, + old_algo: bool = True) -> None: ... - return None + + @overload + async def get(self, + url: str, + sign_headers: bool, + cls: type[T] = JsonBase, # type: ignore[assignment] + force: bool = False, + old_algo: bool = True) -> T: ... async def get(self, url: str, sign_headers: bool, - cls: type[T], + cls: type[T] | None = None, force: bool = False, old_algo: bool = True) -> T | None: - if not issubclass(cls, JsonBase): + if cls is not None and not issubclass(cls, JsonBase): raise TypeError('cls must be a sub-class of "blib.JsonBase"') - if (data := (await self._get(url, sign_headers, force, old_algo))) is None: - return None + data = await self._get(url, sign_headers, force, old_algo) - return cls.parse(data) + if cls is not None: + if data is None: + raise ValueError("Empty response") + + return cls.parse(data) + + return None async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None: @@ -218,35 +215,22 @@ class HttpClient: algorithm = algorithm ) - try: - logging.verbose('Sending "%s" to %s', mtype, url) + logging.verbose('Sending "%s" to %s', mtype, url) - async with self._session.post(url, headers = headers, data = body) as resp: - # Not expecting a response, so just return - if resp.status in {200, 202}: - logging.verbose('Successfully sent "%s" to %s', mtype, url) - return - - logging.verbose('Received error when pushing to %s: %i', url, resp.status) - logging.debug(await resp.read()) - logging.debug("message: %s", body.decode("utf-8")) - logging.debug("headers: %s", json.dumps(headers, indent = 4)) + async with self._session.post(url, headers = headers, data = body) as resp: + # Not expecting a response, so just return + if resp.status in {200, 202}: + logging.verbose('Successfully sent "%s" to %s', mtype, url) return - except ClientSSLError as e: - logging.warning('SSL error when pushing to %s', urlparse(url).netloc) - logging.warning(str(e)) - - except (AsyncTimeoutError, ClientConnectionError) as e: - logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) - logging.warning(str(e)) - - # prevent workers from being brought down - except Exception: - traceback.print_exc() + logging.error('Received error when pushing to %s: %i', url, resp.status) + logging.debug(await resp.read()) + logging.debug("message: %s", body.decode("utf-8")) + logging.debug("headers: %s", json.dumps(headers, indent = 4)) + return - async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None: + async def fetch_nodeinfo(self, domain: str) -> Nodeinfo: nodeinfo_url = None wk_nodeinfo = await self.get( f'https://{domain}/.well-known/nodeinfo', @@ -254,10 +238,6 @@ class HttpClient: WellKnownNodeinfo ) - if wk_nodeinfo is None: - logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) - return None - for version in ('20', '21'): try: nodeinfo_url = wk_nodeinfo.get_url(version) @@ -266,8 +246,7 @@ class HttpClient: pass if nodeinfo_url is None: - logging.verbose('Failed to fetch nodeinfo url for %s', domain) - return None + raise ValueError(f'Failed to fetch nodeinfo url for {domain}') return await self.get(nodeinfo_url, False, Nodeinfo) diff --git a/relay/views/activitypub.py b/relay/views/activitypub.py index b19b7e1..f568d17 100644 --- a/relay/views/activitypub.py +++ b/relay/views/activitypub.py @@ -95,9 +95,10 @@ class ActorView(View): logging.verbose('actor not in message') return Response.new_error(400, 'no actor in message', 'json') - actor: Message | None = await self.client.get(self.signature.keyid, True, Message) + try: + self.actor = await self.client.get(self.signature.keyid, True, Message) - if actor is None: + except Exception: # ld signatures aren't handled atm, so just ignore it if self.message.type == 'Delete': logging.verbose('Instance sent a delete which cannot be handled') @@ -106,8 +107,6 @@ class ActorView(View): logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') return Response.new_error(400, 'failed to fetch actor', 'json') - self.actor = actor - try: self.signer = self.actor.signer diff --git a/relay/views/api.py b/relay/views/api.py index 70a9f0e..074dc04 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -1,3 +1,5 @@ +import traceback + from aiohttp.web import Request, middleware from argon2.exceptions import VerifyMismatchError from collections.abc import Awaitable, Callable, Sequence @@ -206,19 +208,23 @@ class Inbox(View): data['domain'] = data['domain'].encode('idna').decode() if not data.get('inbox'): - actor_data: Message | None = await self.client.get(data['actor'], True, Message) + try: + actor_data = await self.client.get(data['actor'], True, Message) - if actor_data is None: + except Exception: + traceback.print_exc() return Response.new_error(500, 'Failed to fetch actor', 'json') data['inbox'] = actor_data.shared_inbox if not data.get('software'): - nodeinfo = await self.client.fetch_nodeinfo(data['domain']) - - if nodeinfo is not None: + try: + nodeinfo = await self.client.fetch_nodeinfo(data['domain']) data['software'] = nodeinfo.sw_name + except Exception: + pass + row = conn.put_inbox(**data) # type: ignore[arg-type] return Response.new(row, ctype = 'json')