relay/relay/http_client.py
2025-02-12 15:14:46 -05:00

263 lines
5.8 KiB
Python

from __future__ import annotations
import json
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from blib import HttpError, JsonBase
from typing import TYPE_CHECKING, Any, TypeVar, overload
from . import __version__, logger as logging
from .cache import Cache
from .database.schema import Instance
from .errors import EmptyBodyError
from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING:
from .application import Application
# Type variable for the ``JsonBase`` subclass that ``HttpClient.get`` parses a response into
T = TypeVar("T", bound=JsonBase[Any])
# Default headers handed to the ``ClientSession`` created in ``HttpClient.open``
HEADERS = {
    "Accept": f"{MIMETYPES['activity']}, {MIMETYPES['json']};q=0.9",
    "User-Agent": f"ActivityRelay/{__version__}",
}
# Software names for which ``HttpClient.post`` signs with the HS2019 algorithm; any other
# software (or an unknown instance) is signed with RSASHA256 instead.
#
# NOTE: every entry needs its own trailing comma. Adjacent string literals without a comma
# are silently concatenated by Python, which previously merged "hubzilla" and "mastodon"
# into the single bogus entry "hubzillamastodon".
SUPPORTS_HS2019 = {
    "friendica",
    "gotosocial",
    "hubzilla",
    "mastodon",
    "socialhome",
    "misskey",
    "catodon",
    "cherrypick",
    "firefish",
    "foundkey",
    "iceshrimp",
    "sharkey",
}
class HttpClient:
    """Wrapper around an aiohttp ``ClientSession`` that adds request signing and caching.

    The underlying session does not exist until :meth:`open` is called (entering the
    instance with ``async with`` does this) and is released again by :meth:`close`.
    """

    def __init__(self, limit: int = 100, timeout: int = 10):
        """Remember the connection settings without creating any network resources.

        :param limit: Passed as ``limit`` to the ``TCPConnector`` built in :meth:`open`
        :param timeout: Passed as ``total`` to the ``ClientTimeout`` built in :meth:`open`
        """
        self._conn: TCPConnector | None = None
        self._session: ClientSession | None = None
        self.limit = limit
        self.timeout = timeout

    async def __aenter__(self) -> HttpClient:
        """Open the client and return it for use inside an ``async with`` block."""
        self.open()
        return self

    async def __aexit__(self, *_: Any) -> None:
        """Close the client when the ``async with`` block exits."""
        await self.close()

    @property
    def app(self) -> Application:
        """The application object returned by ``get_app()``."""
        return get_app()

    @property
    def cache(self) -> Cache:
        """The application's cache."""
        return self.app.cache

    @property
    def signer(self) -> Signer:
        """The application's signer, used to sign outgoing requests."""
        return self.app.signer

    def open(self) -> None:
        """Create the connector and session. Does nothing when a session already exists."""
        if self._session:
            return

        self._conn = TCPConnector(limit=self.limit, ttl_dns_cache=300)
        request_timeout = ClientTimeout(total=self.timeout)
        self._session = ClientSession(
            connector=self._conn,
            headers=HEADERS,
            connector_owner=True,
            timeout=request_timeout,
        )

    async def close(self) -> None:
        """Close the session and connector (when present) and forget both."""
        if self._session:
            await self._session.close()

        # The session is created with ``connector_owner=True``; the connector is still
        # closed explicitly here as well.
        if self._conn:
            await self._conn.close()

        self._conn = None
        self._session = None

    async def _get(
        self,
        url: str,
        sign_headers: bool,
        force: bool,
        old_algo: bool,
    ) -> str | None:
        """Fetch ``url`` and return the response body as text.

        :param url: Resource to fetch; anything from the first ``#`` onwards is dropped
        :param sign_headers: Sign the request with the application's signer
        :param force: Skip the cache lookup and always hit the network
        :param old_algo: Sign with ``RSASHA256`` when true, ``HS2019`` otherwise
        :returns: The body text, or ``None`` when the server answers with a 202
        :raises RuntimeError: If the client has not been opened
        :raises HttpError: If the response status is neither 200 nor 202
        """
        if not self._session:
            raise RuntimeError("Client not open")

        url = url.partition("#")[0]

        if not force:
            try:
                cached = self.cache.get("request", url)

                # 48 is presumably a number of hours -- confirm against ``Cache``
                if not cached.older_than(48):
                    return cached.value  # type: ignore [no-any-return]

            except KeyError:
                logging.verbose("No cached data for url: %s", url)

        if sign_headers:
            if old_algo:
                algo = AlgorithmType.RSASHA256
            else:
                algo = AlgorithmType.HS2019

            headers = self.signer.sign_headers("GET", url, algorithm=algo)

        else:
            headers = {}

        logging.debug("Fetching resource: %s", url)

        async with self._session.get(url, headers=headers) as resp:
            # Not expecting a response with 202s, so just return
            if resp.status == 202:
                return None

            text = await resp.text()

        # A 202 has already returned above, so anything other than 200 is an error here
        if resp.status != 200:
            try:
                error = json.loads(text)["error"]

            except Exception:
                # Body is not a JSON object with an "error" key: report the raw text
                error = text

            raise HttpError(resp.status, error)

        self.cache.set("request", url, text, "str")
        return text

    @overload
    async def get(
        self,
        url: str,
        sign_headers: bool,
        cls: None = None,
        force: bool = False,
        old_algo: bool = True,
    ) -> str | None: ...

    @overload
    async def get(
        self,
        url: str,
        sign_headers: bool,
        cls: type[T] = JsonBase,  # type: ignore[assignment]
        force: bool = False,
        old_algo: bool = True,
    ) -> T: ...

    async def get(
        self,
        url: str,
        sign_headers: bool,
        cls: type[T] | None = None,
        force: bool = False,
        old_algo: bool = True,
    ) -> T | str | None:
        """Fetch ``url`` and optionally parse the body with a ``JsonBase`` subclass.

        :param url: Resource to fetch
        :param sign_headers: Sign the request with the application's signer
        :param cls: ``JsonBase`` subclass whose ``parse`` is applied to the body; when
            ``None`` the raw text (or ``None`` for a 202) is returned
        :param force: Skip the cache lookup
        :param old_algo: Sign with ``RSASHA256`` when true, ``HS2019`` otherwise
        :raises TypeError: If ``cls`` is given but is not a ``JsonBase`` subclass
        :raises EmptyBodyError: If ``cls`` is given and no body was returned
        """
        if cls is None:
            return await self._get(url, sign_headers, force, old_algo)

        if not issubclass(cls, JsonBase):
            raise TypeError("cls must be a sub-class of \"blib.JsonBase\"")

        raw = await self._get(url, sign_headers, force, old_algo)

        if raw is None:
            # this shouldn't actually get raised, but keeping just in case
            raise EmptyBodyError(f"GET {url}")

        return cls.parse(raw)

    async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
        """Sign ``data`` and POST it to ``url``.

        :param url: Destination of the request
        :param data: Message to send, either as an object or as already-serialized bytes
        :param instance: Receiving instance; its ``software`` selects the signing algorithm
        :raises RuntimeError: If the client has not been opened
        :raises HttpError: If the response status is neither 200 nor 202
        """
        if not self._session:
            raise RuntimeError("Client not open")

        # akkoma and pleroma do not support HS2019 and other software still needs to be tested
        use_hs2019 = instance is not None and instance.software in SUPPORTS_HS2019
        algorithm = AlgorithmType.HS2019 if use_hs2019 else AlgorithmType.RSASHA256

        body: bytes
        message: Message

        if not isinstance(data, bytes):
            message = data
            body = data.to_json().encode("utf-8")

        else:
            # Raw bytes are sent untouched; they are parsed only to log the message type
            message = Message.parse(data)
            body = data

        msg_type = message.type
        mtype = msg_type.value if isinstance(msg_type, ObjectType) else msg_type

        headers = self.signer.sign_headers(
            "POST",
            url,
            body,
            headers={"Content-Type": "application/activity+json"},
            algorithm=algorithm,
        )

        logging.verbose("Sending \"%s\" to %s", mtype, url)

        async with self._session.post(url, headers=headers, data=body) as resp:
            if resp.status in (200, 202):
                return

            raise HttpError(
                resp.status,
                await resp.text(),
                headers={k: v for k, v in resp.headers.items()},
            )

    async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:
        """Fetch the nodeinfo document advertised by ``domain``.

        The well-known document is fetched unsigned and then queried for versions
        ``"20"`` and ``"21"`` in that order; a later hit replaces an earlier one.

        :param domain: Host name of the instance to query
        :param force: Skip the cache lookup for both requests
        :raises ValueError: If the well-known document has no url for either version
        """
        well_known = await self.get(
            f"https://{domain}/.well-known/nodeinfo", False, WellKnownNodeinfo, force
        )

        nodeinfo_url = None

        for version in ("20", "21"):
            try:
                nodeinfo_url = well_known.get_url(version)

            except KeyError:
                continue

        if nodeinfo_url is None:
            raise ValueError(f"Failed to fetch nodeinfo url for {domain}")

        return await self.get(nodeinfo_url, False, Nodeinfo, force)
async def get(*args: Any, **kwargs: Any) -> Any:
    """Call ``HttpClient.get`` on a temporary client that is closed once the call finishes.

    All positional and keyword arguments are forwarded unchanged.
    """
    async with HttpClient() as client:
        result = await client.get(*args, **kwargs)

    return result
async def post(*args: Any, **kwargs: Any) -> None:
    """Call ``HttpClient.post`` on a temporary client that is closed once the call finishes.

    All positional and keyword arguments are forwarded unchanged.
    """
    async with HttpClient() as client:
        result = await client.post(*args, **kwargs)

    return result
async def fetch_nodeinfo(*args: Any, **kwargs: Any) -> Nodeinfo | None:
    """Call ``HttpClient.fetch_nodeinfo`` on a temporary client closed after the call.

    All positional and keyword arguments are forwarded unchanged.
    """
    async with HttpClient() as client:
        nodeinfo = await client.fetch_nodeinfo(*args, **kwargs)

    return nodeinfo