raise exceptions instead of returning None from HttpClient methods

Izalia Mae 2024-06-18 23:14:21 -04:00
parent b308b03546
commit c508257981
4 changed files with 75 additions and 79 deletions

File 1 of 4:

@@ -7,9 +7,11 @@ import time
 import traceback
 from aiohttp import web
+from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
 from aiohttp.web import StaticResource
 from aiohttp_swagger import setup_swagger
 from aputils.signer import Signer
+from asyncio.exceptions import TimeoutError as AsyncTimeoutError
 from bsql import Database, Row
 from collections.abc import Awaitable, Callable
 from datetime import datetime, timedelta
@@ -18,6 +20,7 @@ from pathlib import Path
 from queue import Empty
 from threading import Event, Thread
 from typing import Any
+from urllib.parse import urlparse

 from . import logger as logging
 from .cache import Cache, get_cache
@@ -331,6 +334,15 @@ class PushWorker(multiprocessing.Process):
             except Empty:
                 await asyncio.sleep(0)

+            except ClientSSLError as e:
+                logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e))
+
+            except (AsyncTimeoutError, ClientConnectionError) as e:
+                logging.error(
+                    'Failed to connect to %s for message push: %s',
+                    urlparse(inbox).netloc, str(e)
+                )
+
             # make sure an exception doesn't bring down the worker
             except Exception:
                 traceback.print_exc()
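
The net effect on the worker loop: connection-level failures from HttpClient.post() now propagate up to PushWorker instead of being swallowed inside the client. A minimal sketch of the resulting pattern (the queue handling and stop_event name are paraphrased from context; exact names in the relay may differ):

    # inside the worker's queue loop, assuming `self.queue` and `self.client`
    while not self.stop_event.is_set():
        try:
            inbox, message, instance = self.queue.get(block=True, timeout=0.1)
            await self.client.post(inbox, message, instance)
        except Empty:
            await asyncio.sleep(0)
        except ClientSSLError as e:
            # TLS failures are logged per-inbox; the worker keeps running
            logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e))
        except (AsyncTimeoutError, ClientConnectionError) as e:
            logging.error('Failed to connect to %s for message push: %s', urlparse(inbox).netloc, str(e))
        except Exception:
            # make sure an unexpected exception doesn't bring down the worker
            traceback.print_exc()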

File 2 of 4:

@@ -1,17 +1,12 @@
 from __future__ import annotations

 import json
-import traceback

 from aiohttp import ClientSession, ClientTimeout, TCPConnector
-from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
 from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
-from asyncio.exceptions import TimeoutError as AsyncTimeoutError
 from blib import JsonBase
 from bsql import Row
-from json.decoder import JSONDecodeError
-from typing import TYPE_CHECKING, Any, TypeVar
-from urllib.parse import urlparse
+from typing import TYPE_CHECKING, Any, TypeVar, overload

 from . import __version__, logger as logging
 from .cache import Cache
@@ -107,7 +102,7 @@ class HttpClient:
             url: str,
             sign_headers: bool,
             force: bool,
-            old_algo: bool) -> dict[str, Any] | None:
+            old_algo: bool) -> str | None:

         if not self._session:
             raise RuntimeError('Client not open')
@@ -121,7 +116,7 @@ class HttpClient:
         if not force:
             try:
                 if not (item := self.cache.get('request', url)).older_than(48):
-                    return json.loads(item.value) # type: ignore[no-any-return]
+                    return item.value # type: ignore [no-any-return]

             except KeyError:
                 logging.verbose('No cached data for url: %s', url)
@@ -132,59 +127,61 @@ class HttpClient:
         algo = AlgorithmType.RSASHA256 if old_algo else AlgorithmType.HS2019
         headers = self.signer.sign_headers('GET', url, algorithm = algo)

-        try:
-            logging.debug('Fetching resource: %s', url)
+        logging.debug('Fetching resource: %s', url)

-            async with self._session.get(url, headers = headers) as resp:
-                # Not expecting a response with 202s, so just return
-                if resp.status == 202:
-                    return None
+        async with self._session.get(url, headers = headers) as resp:
+            # Not expecting a response with 202s, so just return
+            if resp.status == 202:
+                return None

-                data = await resp.text()
+            data = await resp.text()

-            if resp.status != 200:
-                logging.verbose('Received error when requesting %s: %i', url, resp.status)
-                logging.debug(data)
-                return None
+        if resp.status != 200:
+            logging.verbose('Received error when requesting %s: %i', url, resp.status)
+            logging.debug(data)
+            return None

-            self.cache.set('request', url, data, 'str')
-            logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
-            return json.loads(data) # type: ignore [no-any-return]
-
-        except JSONDecodeError:
-            logging.verbose('Failed to parse JSON')
-            logging.debug(data)
-            return None
-
-        except ClientSSLError as e:
-            logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
-            logging.warning(str(e))
-
-        except (AsyncTimeoutError, ClientConnectionError) as e:
-            logging.verbose('Failed to connect to %s', urlparse(url).netloc)
-            logging.warning(str(e))
-
-        except Exception:
-            traceback.print_exc()
-
-        return None
+        self.cache.set('request', url, data, 'str')
+        return data
+
+
+    @overload
+    async def get(self, # type: ignore[overload-overlap]
+            url: str,
+            sign_headers: bool,
+            cls: None = None,
+            force: bool = False,
+            old_algo: bool = True) -> None: ...
+
+
+    @overload
+    async def get(self,
+            url: str,
+            sign_headers: bool,
+            cls: type[T] = JsonBase, # type: ignore[assignment]
+            force: bool = False,
+            old_algo: bool = True) -> T: ...


     async def get(self,
             url: str,
             sign_headers: bool,
-            cls: type[T],
+            cls: type[T] | None = None,
             force: bool = False,
             old_algo: bool = True) -> T | None:

-        if not issubclass(cls, JsonBase):
+        if cls is not None and not issubclass(cls, JsonBase):
             raise TypeError('cls must be a sub-class of "blib.JsonBase"')

-        if (data := (await self._get(url, sign_headers, force, old_algo))) is None:
-            return None
+        data = await self._get(url, sign_headers, force, old_algo)

-        return cls.parse(data)
+        if cls is not None:
+            if data is None:
+                raise ValueError("Empty response")
+
+            return cls.parse(data)
+
+        return None


     async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
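
With the overloads in place, a caller that passes a cls gets a non-optional return type and must rely on exceptions for failure; a caller that omits cls always gets None. A hedged sketch of the resulting call-site typing (Message is one of the relay's JsonBase subclasses; the URL is illustrative):

    # cls given: type checks as Message, and an empty 200/202 response raises ValueError
    actor = await client.get('https://example.com/actor', True, Message)

    # cls omitted: always None; useful only for warming the request cache
    await client.get('https://example.com/actor', True)

Note that network errors (ClientConnectionError, ClientSSLError, the asyncio TimeoutError) now escape _get() entirely, since its try/except block was removed, so each call site decides how to handle them.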
@@ -218,35 +215,22 @@ class HttpClient:
                 algorithm = algorithm
             )

-        try:
-            logging.verbose('Sending "%s" to %s', mtype, url)
+        logging.verbose('Sending "%s" to %s', mtype, url)

-            async with self._session.post(url, headers = headers, data = body) as resp:
-                # Not expecting a response, so just return
-                if resp.status in {200, 202}:
-                    logging.verbose('Successfully sent "%s" to %s', mtype, url)
-                    return
+        async with self._session.post(url, headers = headers, data = body) as resp:
+            # Not expecting a response, so just return
+            if resp.status in {200, 202}:
+                logging.verbose('Successfully sent "%s" to %s', mtype, url)
+                return

-                logging.verbose('Received error when pushing to %s: %i', url, resp.status)
-                logging.debug(await resp.read())
-                logging.debug("message: %s", body.decode("utf-8"))
-                logging.debug("headers: %s", json.dumps(headers, indent = 4))
-                return
-
-        except ClientSSLError as e:
-            logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
-            logging.warning(str(e))
-
-        except (AsyncTimeoutError, ClientConnectionError) as e:
-            logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
-            logging.warning(str(e))
-
-        # prevent workers from being brought down
-        except Exception:
-            traceback.print_exc()
+            logging.error('Received error when pushing to %s: %i', url, resp.status)
+            logging.debug(await resp.read())
+            logging.debug("message: %s", body.decode("utf-8"))
+            logging.debug("headers: %s", json.dumps(headers, indent = 4))
+            return


-    async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None:
+    async def fetch_nodeinfo(self, domain: str) -> Nodeinfo:
         nodeinfo_url = None

         wk_nodeinfo = await self.get(
             f'https://{domain}/.well-known/nodeinfo',
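
As with _get(), post() no longer catches its own network errors; the "prevent workers from being brought down" duty moves to PushWorker (see the first file). A hedged sketch of what any caller outside the worker now has to do (exception classes per aiohttp/asyncio; the call shape follows the diff):

    try:
        await client.post(inbox, message, instance)
    except ClientSSLError as e:
        logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e))
    except (AsyncTimeoutError, ClientConnectionError) as e:
        logging.error('Failed to connect to %s: %s', urlparse(inbox).netloc, str(e))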
@@ -254,10 +238,6 @@ class HttpClient:
             WellKnownNodeinfo
         )

-        if wk_nodeinfo is None:
-            logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
-            return None
-
         for version in ('20', '21'):
             try:
                 nodeinfo_url = wk_nodeinfo.get_url(version)
@@ -266,8 +246,7 @@ class HttpClient:
                 pass

         if nodeinfo_url is None:
-            logging.verbose('Failed to fetch nodeinfo url for %s', domain)
-            return None
+            raise ValueError(f'Failed to fetch nodeinfo url for {domain}')

         return await self.get(nodeinfo_url, False, Nodeinfo)
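
fetch_nodeinfo() now returns a Nodeinfo unconditionally and signals failure by raising: ValueError when no nodeinfo url can be resolved (or the response is empty), plus whatever get() lets propagate underneath. A hedged usage sketch (the domain is illustrative):

    try:
        nodeinfo = await client.fetch_nodeinfo('example.com')
        logging.info('software: %s', nodeinfo.sw_name)
    except ValueError as e:
        # no nodeinfo url advertised, or an empty response
        logging.verbose(str(e))
    except (AsyncTimeoutError, ClientConnectionError) as e:
        # network-level failures propagate from get()
        logging.warning(str(e))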

File 3 of 4:

@@ -95,9 +95,10 @@ class ActorView(View):
             logging.verbose('actor not in message')
             return Response.new_error(400, 'no actor in message', 'json')

-        actor: Message | None = await self.client.get(self.signature.keyid, True, Message)
+        try:
+            self.actor = await self.client.get(self.signature.keyid, True, Message)

-        if actor is None:
+        except Exception:
             # ld signatures aren't handled atm, so just ignore it
             if self.message.type == 'Delete':
                 logging.verbose('Instance sent a delete which cannot be handled')
@@ -106,8 +107,6 @@ class ActorView(View):
             logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
             return Response.new_error(400, 'failed to fetch actor', 'json')

-        self.actor = actor
-
         try:
             self.signer = self.actor.signer
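
Because get() with cls=Message never returns None under the new contract, the None check becomes a try/except and self.actor can be assigned directly. The bare except Exception is deliberately broad; a hedged sketch of the failure modes it now has to cover:

    try:
        self.actor = await self.client.get(self.signature.keyid, True, Message)
    except ValueError:
        ...  # 202 or empty body: nothing to parse into a Message
    except (ClientSSLError, ClientConnectionError, AsyncTimeoutError):
        ...  # network failures now escape HttpClient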

File 4 of 4:

@@ -1,3 +1,5 @@
+import traceback
+
 from aiohttp.web import Request, middleware
 from argon2.exceptions import VerifyMismatchError
 from collections.abc import Awaitable, Callable, Sequence
@@ -206,19 +208,23 @@ class Inbox(View):
         data['domain'] = data['domain'].encode('idna').decode()

         if not data.get('inbox'):
-            actor_data: Message | None = await self.client.get(data['actor'], True, Message)
+            try:
+                actor_data = await self.client.get(data['actor'], True, Message)

-            if actor_data is None:
+            except Exception:
+                traceback.print_exc()
                 return Response.new_error(500, 'Failed to fetch actor', 'json')

             data['inbox'] = actor_data.shared_inbox

         if not data.get('software'):
-            nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
+            try:
+                nodeinfo = await self.client.fetch_nodeinfo(data['domain'])

-            if nodeinfo is not None:
                 data['software'] = nodeinfo.sw_name

+            except Exception:
+                pass
+
         row = conn.put_inbox(**data) # type: ignore[arg-type]

         return Response.new(row, ctype = 'json')
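
Read together, the handler's failure policy is now explicit at the call sites: a failed actor fetch is fatal (500, with the traceback printed), while a failed nodeinfo fetch is best-effort. A hypothetical helper condensing the same logic, for illustration only (the function name is invented):

    async def resolve_inbox_fields(client: HttpClient, data: dict[str, Any]) -> None:
        # hypothetical name; mirrors the handler above
        if not data.get('inbox'):
            actor = await client.get(data['actor'], True, Message)  # may raise -> caller answers 500
            data['inbox'] = actor.shared_inbox

        if not data.get('software'):
            try:
                data['software'] = (await client.fetch_nodeinfo(data['domain'])).sw_name
            except Exception:
                pass  # 'software' is optional metadata; skip on any failure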