From b85b4ab80b1e7f6d0e0966749eb89c5f709f9c2d Mon Sep 17 00:00:00 2001
From: Izalia Mae
Date: Sat, 26 Nov 2022 18:56:34 -0500
Subject: [PATCH] create HttpClient class to avoid creating a new session
 every request

---
 relay/application.py |  36 +++++---
 relay/config.py      |  24 ++---
 relay/database.py    |   2 -
 relay/http_client.py | 203 +++++++++++++++++++++++++++++++++++++++++++
 relay/manage.py      |  12 +--
 relay/misc.py        | 148 ------------------------------
 relay/processors.py  |  69 +++++++-------
 relay/views.py       |   2 +-
 8 files changed, 272 insertions(+), 224 deletions(-)
 create mode 100644 relay/http_client.py

diff --git a/relay/application.py b/relay/application.py
index a216584..8133d7c 100644
--- a/relay/application.py
+++ b/relay/application.py
@@ -6,12 +6,12 @@ import signal
 import threading
 
 from aiohttp import web
-from cachetools import LRUCache
 from datetime import datetime, timedelta
 
 from .config import RelayConfig
 from .database import RelayDatabase
-from .misc import DotDict, check_open_port, request, set_app
+from .http_client import HttpClient
+from .misc import DotDict, check_open_port, set_app
 from .views import routes
 
 
@@ -27,8 +27,6 @@ class Application(web.Application):
 		if not self['config'].load():
 			self['config'].save()
 
-		self['cache'] = DotDict({key: Cache(maxsize=self['config'][key]) for key in self['config'].cachekeys})
-
 		self['semaphore'] = asyncio.Semaphore(self['config'].push_limit)
 		self['workers'] = []
 		self['last_worker'] = 0
@@ -37,12 +35,18 @@ class Application(web.Application):
 		self['database'] = RelayDatabase(self['config'])
 		self['database'].load()
 
+		self['client'] = HttpClient(
+			limit = self.config.push_limit,
+			timeout = self.config.timeout,
+			cache_size = self.config.json_cache
+		)
+
 		self.set_signal_handler()
 
 
 	@property
-	def cache(self):
-		return self['cache']
+	def client(self):
+		return self['client']
 
 
 	@property
@@ -76,6 +80,9 @@ class Application(web.Application):
 
 	def push_message(self, inbox, message):
+		if self.config.workers <= 0:
+			return asyncio.ensure_future(self.client.post(inbox, message))
+
 		worker = self['workers'][self['last_worker']]
 		worker.queue.put((inbox, message))
 
@@ -145,11 +152,6 @@ class Application(web.Application):
 		self['workers'].clear()
 
 
-class Cache(LRUCache):
-	def set_maxsize(self, value):
-		self.__maxsize = int(value)
-
-
class PushWorker(threading.Thread):
 	def __init__(self, app):
 		threading.Thread.__init__(self)
@@ -158,6 +160,12 @@ class PushWorker(threading.Thread):
 
 	def run(self):
+		self.client = HttpClient(
+			limit = self.app.config.push_limit,
+			timeout = self.app.config.timeout,
+			cache_size = self.app.config.json_cache
+		)
+
 		asyncio.run(self.handle_queue())
 
 
@@ -166,13 +174,14 @@ class PushWorker(threading.Thread):
 			try:
 				inbox, message = self.queue.get(block=True, timeout=0.25)
 				self.queue.task_done()
-				await request(inbox, message)
-				logging.verbose(f'New push from Thread-{threading.get_ident()}')
+				await self.client.post(inbox, message)
 
 			except queue.Empty:
 				pass
 
+		await self.client.close()
+
 
 ## Can't sub-class web.Request, so let's just add some properties
 def request_actor(self):
@@ -203,7 +212,6 @@ setattr(web.Request, 'instance', property(request_instance))
 setattr(web.Request, 'message', property(request_message))
 setattr(web.Request, 'signature', property(request_signature))
 
-setattr(web.Request, 'cache', property(lambda self: self.app.cache))
 setattr(web.Request, 'config', property(lambda self: self.app.config))
 setattr(web.Request, 'database', property(lambda self: self.app.database))
 setattr(web.Request, 'semaphore', property(lambda self: self.app.semaphore))
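
With the worker pool now optional, push_message becomes the single entry point for outbound delivery: with no workers configured, the post coroutine is scheduled directly on the running loop; otherwise the job is queued to a worker thread. A minimal sketch of that dispatch shape (Dispatcher is a hypothetical stand-in, and push_message must be called while an event loop is running):

    import asyncio
    import queue

    class Dispatcher:
        def __init__(self, client, workers=0):
            self.client = client
            self.workers = workers
            self.jobs = queue.Queue()

        def push_message(self, inbox, message):
            if self.workers <= 0:
                ## fire-and-forget: schedule delivery on the current loop
                return asyncio.ensure_future(self.client.post(inbox, message))

            ## a PushWorker-style thread drains this queue on its own loop
            self.jobs.put((inbox, message))

Note the two paths return a Future and None respectively, which is why the processors further down treat push_message as fire-and-forget rather than awaiting its result.
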
diff --git a/relay/config.py b/relay/config.py
index e4ee5f8..090b63c 100644
--- a/relay/config.py
+++ b/relay/config.py
@@ -24,12 +24,6 @@ class RelayConfig(DotDict):
 		'whitelist'
 	}
 
-	cachekeys = {
-		'json',
-		'objects',
-		'digests'
-	}
-
 	def __init__(self, path, is_docker):
 		DotDict.__init__(self, {})
 
@@ -50,7 +44,7 @@ class RelayConfig(DotDict):
 		if key in ['blocked_instances', 'blocked_software', 'whitelist']:
 			assert isinstance(value, (list, set, tuple))
 
-		elif key in ['port', 'workers', 'json', 'objects', 'digests']:
+		elif key in ['port', 'workers', 'json_cache', 'timeout']:
 			if not isinstance(value, int):
 				value = int(value)
 
@@ -94,15 +88,14 @@ class RelayConfig(DotDict):
 			'port': 8080,
 			'note': 'Make a note about your instance here.',
 			'push_limit': 512,
+			'json_cache': 1024,
+			'timeout': 10,
 			'workers': 0,
 			'host': 'relay.example.com',
+			'whitelist_enabled': False,
 			'blocked_software': [],
 			'blocked_instances': [],
-			'whitelist': [],
-			'whitelist_enabled': False,
-			'json': 1024,
-			'objects': 1024,
-			'digests': 1024
+			'whitelist': []
 		})
 
 	def ban_instance(self, instance):
@@ -211,7 +204,7 @@ class RelayConfig(DotDict):
 			return False
 
 		for key, value in config.items():
-			if key in ['ap', 'cache']:
+			if key in ['ap']:
 				for k, v in value.items():
 					if k not in self:
 						continue
@@ -239,8 +232,9 @@ class RelayConfig(DotDict):
 			'note': self.note,
 			'push_limit': self.push_limit,
 			'workers': self.workers,
-			'ap': {key: self[key] for key in self.apkeys},
-			'cache': {key: self[key] for key in self.cachekeys}
+			'json_cache': self.json_cache,
+			'timeout': self.timeout,
+			'ap': {key: self[key] for key in self.apkeys}
 		}
 
 		with open(self._path, 'w') as fd:
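
The two new integer options, json_cache and timeout, ride the same coercion branch as port and workers, so string values from a hand-edited config file are normalized on assignment. A self-contained sketch of that behaviour (IntKeysConfig is a hypothetical stand-in for RelayConfig):

    class IntKeysConfig(dict):
        INT_KEYS = {'port', 'workers', 'json_cache', 'timeout'}

        def __setitem__(self, key, value):
            if key in self.INT_KEYS and not isinstance(value, int):
                value = int(value)  ## e.g. '2048' read from the config becomes 2048

            dict.__setitem__(self, key, value)

    config = IntKeysConfig()
    config['json_cache'] = '2048'
    assert config['json_cache'] == 2048
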
diff --git a/relay/database.py b/relay/database.py
index 85daf6b..82adce4 100644
--- a/relay/database.py
+++ b/relay/database.py
@@ -6,8 +6,6 @@ import traceback
 from Crypto.PublicKey import RSA
 from urllib.parse import urlparse
 
-from .misc import fetch_nodeinfo
-
 
 class RelayDatabase(dict):
 	def __init__(self, config):
diff --git a/relay/http_client.py b/relay/http_client.py
new file mode 100644
index 0000000..d664a88
--- /dev/null
+++ b/relay/http_client.py
@@ -0,0 +1,203 @@
+import logging
+import traceback
+
+from aiohttp import ClientSession, ClientTimeout, TCPConnector
+from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
+from datetime import datetime
+from cachetools import LRUCache
+from json.decoder import JSONDecodeError
+from urllib.parse import urlparse
+
+from . import __version__
+from .misc import (
+	MIMETYPES,
+	DotDict,
+	Message,
+	Nodeinfo,
+	WKNodeinfo,
+	create_signature_header,
+	generate_body_digest
+)
+
+
+HEADERS = {
+	'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
+	'User-Agent': f'ActivityRelay/{__version__}'
+}
+
+
+class Cache(LRUCache):
+	def set_maxsize(self, value):
+		## resolves to the same mangled attribute (_Cache__maxsize) set by
+		## cachetools' own Cache base class, since both classes share the name
+		self.__maxsize = int(value)
+
+
+class HttpClient:
+	def __init__(self, limit=100, timeout=10, cache_size=1024):
+		self.cache = Cache(cache_size)
+		self.cfg = {'limit': limit, 'timeout': timeout}
+		self._conn = None
+		self._session = None
+
+
+	@property
+	def limit(self):
+		return self.cfg['limit']
+
+
+	@property
+	def timeout(self):
+		return self.cfg['timeout']
+
+
+	def sign_headers(self, method, url, message=None):
+		parsed = urlparse(url)
+		headers = {
+			'(request-target)': f'{method.lower()} {parsed.path}',
+			'Date': datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'),
+			'Host': parsed.netloc
+		}
+
+		if message:
+			data = message.to_json()
+			headers.update({
+				'Digest': f'SHA-256={generate_body_digest(data)}',
+				'Content-Length': str(len(data.encode('utf-8')))
+			})
+
+		headers['Signature'] = create_signature_header(headers)
+
+		del headers['(request-target)']
+		del headers['Host']
+
+		return headers
+
+
+	async def open(self):
+		if self._session:
+			return
+
+		self._conn = TCPConnector(
+			limit = self.limit,
+			ttl_dns_cache = 300,
+		)
+
+		self._session = ClientSession(
+			connector = self._conn,
+			headers = HEADERS,
+			connector_owner = True,
+			timeout = ClientTimeout(total=self.timeout)
+		)
+
+
+	async def close(self):
+		if not self._session:
+			return
+
+		await self._session.close()
+		await self._conn.close()
+
+		self._conn = None
+		self._session = None
+
+
+	async def get(self, url, sign_headers=False, loads=None, force=False):
+		await self.open()
+
+		try: url, _ = url.split('#', 1)
+		except ValueError: pass
+
+		if not force and url in self.cache:
+			return self.cache[url]
+
+		headers = {}
+
+		if sign_headers:
+			headers.update(self.sign_headers('GET', url))
+
+		try:
+			logging.verbose(f'Fetching resource: {url}')
+
+			async with self._session.get(url, headers=headers) as resp:
+				## Not expecting a response with 202s, so just return
+				if resp.status == 202:
+					return
+
+				elif resp.status != 200:
+					logging.verbose(f'Received error when requesting {url}: {resp.status}')
+					logging.verbose(await resp.read()) # change this to debug
+					return
+
+				if loads:
+					if issubclass(loads, DotDict):
+						message = await resp.json(loads=loads.new_from_json)
+
+					else:
+						message = await resp.json(loads=loads)
+
+				elif resp.content_type == MIMETYPES['activity']:
+					message = await resp.json(loads=Message.new_from_json)
+
+				elif resp.content_type == MIMETYPES['json']:
+					message = await resp.json(loads=DotDict.new_from_json)
+
+				else:
+					# todo: raise TypeError or something
+					logging.verbose(f'Invalid Content-Type for "{url}": {resp.content_type}')
+					return logging.debug(f'Response: {await resp.read()}')
+
+				logging.debug(f'{url} >> resp {message.to_json(4)}')
+
+				self.cache[url] = message
+				return message
+
+		except JSONDecodeError:
+			logging.verbose('Failed to parse JSON')
+
+		except (ClientConnectorError, ServerTimeoutError):
+			logging.verbose(f'Failed to connect to {urlparse(url).netloc}')
+
+		except Exception as e:
+			traceback.print_exc()
+			raise e
+
+
+	async def post(self, url, message):
+		await self.open()
+
+		headers = {'Content-Type': 'application/activity+json'}
+		headers.update(self.sign_headers('POST', url, message))
+
+		try:
+			logging.verbose(f'Sending "{message.type}" to {url}')
+
+			async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
+				## Not expecting a response, so just return
+				if resp.status in {200, 202}:
+					return logging.verbose(f'Successfully sent "{message.type}" to {url}')
+
+				logging.verbose(f'Received error when pushing to {url}: {resp.status}')
+				return logging.verbose(await resp.read()) # change this to debug
+
+		except (ClientConnectorError, ServerTimeoutError):
+			logging.verbose(f'Failed to connect to {urlparse(url).netloc}')
+
+		## prevent workers from being brought down
+		except Exception as e:
+			traceback.print_exc()
+
+
+	## Additional methods ##
+	async def fetch_nodeinfo(self, domain):
+		nodeinfo_url = None
+		wk_nodeinfo = await self.get(f'https://{domain}/.well-known/nodeinfo', loads=WKNodeinfo)
+
+		if not wk_nodeinfo:
+			return False
+
+		for version in ['20', '21']:
+			try:
+				nodeinfo_url = wk_nodeinfo.get_url(version)
+
+			except KeyError:
+				pass
+
+		if not nodeinfo_url:
+			logging.verbose(f'Failed to fetch nodeinfo url for domain: {domain}')
+			return False
+
+		return await self.get(nodeinfo_url, loads=Nodeinfo) or False
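
The client opens its session lazily on first use, caches every successfully parsed GET by URL, and reuses one ClientSession until close() is called. A rough usage sketch, assuming an initialized relay (header signing reads the relay's stored key, and logging.verbose is a custom level registered elsewhere in the package); the actor URL and host below are placeholders:

    import asyncio

    from relay.http_client import HttpClient
    from relay.misc import Message

    async def main():
        client = HttpClient(limit=100, timeout=10, cache_size=1024)

        ## open() runs lazily here; the parsed response is cached by url
        actor = await client.get('https://example.com/actor')

        if not actor:
            return await client.close()

        ## post() signs the headers and attaches a body digest per request
        await client.post(actor.shared_inbox, Message.new_follow(
            host = 'relay.example.com',
            actor = actor.id
        ))

        ## the same ClientSession served both calls; close it explicitly
        await client.close()

    asyncio.run(main())
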
diff --git a/relay/manage.py b/relay/manage.py
index 4fb9614..48f3700 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -145,7 +145,7 @@ def cli_inbox_follow(actor):
 		inbox = inbox_data['inbox']
 
 	except KeyError:
-		actor_data = asyncio.run(misc.request(actor))
+		actor_data = asyncio.run(app.client.get(actor, sign_headers=True))
 
 		if not actor_data:
 			return click.echo(f'Failed to fetch actor: {actor}')
@@ -157,7 +157,7 @@ def cli_inbox_follow(actor):
 		actor = actor
 	)
 
-	asyncio.run(misc.request(inbox, message))
+	asyncio.run(app.client.post(inbox, message))
 	click.echo(f'Sent follow message to actor: {actor}')
 
 
@@ -183,7 +183,7 @@ def cli_inbox_unfollow(actor):
 		)
 
 	except KeyError:
-		actor_data = asyncio.run(misc.request(actor))
+		actor_data = asyncio.run(app.client.get(actor, sign_headers=True))
 		inbox = actor_data.shared_inbox
 		message = misc.Message.new_unfollow(
 			host = app.config.host,
@@ -195,7 +195,7 @@ def cli_inbox_unfollow(actor):
 		}
 	)
 
-	asyncio.run(misc.request(inbox, message))
+	asyncio.run(app.client.post(inbox, message))
 	click.echo(f'Sent unfollow message to: {actor}')
 
 
@@ -319,7 +319,7 @@ def cli_software_ban(name, fetch_nodeinfo):
 		return click.echo('Banned all relay software')
 
 	if fetch_nodeinfo:
-		nodeinfo = asyncio.run(misc.fetch_nodeinfo(name))
+		nodeinfo = asyncio.run(app.client.fetch_nodeinfo(name))
 
 		if not software:
 			click.echo(f'Failed to fetch software name from domain: {name}')
@@ -347,7 +347,7 @@ def cli_software_unban(name, fetch_nodeinfo):
 		return click.echo('Unbanned all relay software')
 
 	if fetch_nodeinfo:
-		nodeinfo = asyncio.run(misc.fetch_nodeinfo(name))
+		nodeinfo = asyncio.run(app.client.fetch_nodeinfo(name))
 
 		if not nodeinfo:
 			click.echo(f'Failed to fetch software name from domain: {name}')
diff --git a/relay/misc.py b/relay/misc.py
index 7f2bb56..628800d 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -9,8 +9,6 @@ import uuid
 from Crypto.Hash import SHA, SHA256, SHA512
 from Crypto.PublicKey import RSA
 from Crypto.Signature import PKCS1_v1_5
-from aiohttp import ClientSession, ClientTimeout
-from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
 from aiohttp.hdrs import METH_ALL as METHODS
 from aiohttp.web import Response as AiohttpResponse, View as AiohttpView
 from datetime import datetime
@@ -117,14 +115,8 @@ def distill_inboxes(actor, object_id):
 
 
 def generate_body_digest(body):
-	bodyhash = app.cache.digests.get(body)
-
-	if bodyhash:
-		return bodyhash
-
 	h = SHA256.new(body.encode('utf-8'))
 	bodyhash = base64.b64encode(h.digest()).decode('utf-8')
-	app.cache.digests[body] = bodyhash
 
 	return bodyhash
@@ -138,141 +130,6 @@ def sign_signing_string(sigstring, key):
 	return base64.b64encode(sigdata).decode('utf-8')
 
 
-async def fetch_actor_key(actor):
-	actor_data = await request(actor)
-
-	if not actor_data:
-		return None
-
-	try:
-		return RSA.importKey(actor_data['publicKey']['publicKeyPem'])
-
-	except Exception as e:
-		logging.debug(f'Exception occured while fetching actor key: {e}')
-
-
-async def fetch_nodeinfo(domain):
-	nodeinfo_url = None
-	wk_nodeinfo = await request(f'https://{domain}/.well-known/nodeinfo', sign_headers=False, activity=False)
-
-	if not wk_nodeinfo:
-		return
-
-	wk_nodeinfo = WKNodeinfo(wk_nodeinfo)
-
-	for version in ['20', '21']:
-		try:
-			nodeinfo_url = wk_nodeinfo.get_url(version)
-
-		except KeyError:
-			pass
-
-	if not nodeinfo_url:
-		logging.verbose(f'Failed to fetch nodeinfo url for domain: {domain}')
-		return False
-
-	nodeinfo = await request(nodeinfo_url, sign_headers=False, activity=False)
-
-	if not nodeinfo:
-		return False
-
-	return Nodeinfo(nodeinfo)
-
-
-async def request(uri, data=None, force=False, sign_headers=True, activity=True, timeout=10):
-	## If a get request and not force, try to use the cache first
-	if not data and not force:
-		try:
-			return app.cache.json[uri]
-
-		except KeyError:
-			pass
-
-	url = urlparse(uri)
-	method = 'POST' if data else 'GET'
-	action = data.get('type') if data else None
-	headers = {
-		'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
-		'User-Agent': 'ActivityRelay',
-	}
-
-	if data:
-		headers['Content-Type'] = MIMETYPES['activity' if activity else 'json']
-
-	if sign_headers:
-		signing_headers = {
-			'(request-target)': f'{method.lower()} {url.path}',
-			'Date': datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'),
-			'Host': url.netloc
-		}
-
-		if data:
-			assert isinstance(data, dict)
-
-			data = json.dumps(data)
-			signing_headers.update({
-				'Digest': f'SHA-256={generate_body_digest(data)}',
-				'Content-Length': str(len(data.encode('utf-8')))
-			})
-
-		signing_headers['Signature'] = create_signature_header(signing_headers)
-
-		del signing_headers['(request-target)']
-		del signing_headers['Host']
-
-		headers.update(signing_headers)
-
-	try:
-		if data:
-			logging.verbose(f'Sending "{action}" to inbox: {uri}')
-
-		else:
-			logging.verbose(f'Sending GET request to url: {uri}')
-
-		timeout_cfg = ClientTimeout(connect=timeout)
-		async with ClientSession(trace_configs=http_debug(), timeout=timeout_cfg) as session, app.semaphore:
-			async with session.request(method, uri, headers=headers, data=data) as resp:
-				## aiohttp has been known to leak if the response hasn't been read,
-				## so we're just gonna read the request no matter what
-				resp_data = await resp.read()
-
-				## Not expecting a response, so just return
-				if resp.status == 202:
-					return
-
-				elif resp.status != 200:
-					if not resp_data:
-						return logging.verbose(f'Received error when requesting {uri}: {resp.status} {resp_data}')
-
-					return logging.verbose(f'Received error when sending {action} to {uri}: {resp.status} {resp_data}')
-
-				if resp.content_type == MIMETYPES['activity']:
-					resp_data = await resp.json(loads=Message.new_from_json)
-
-				elif resp.content_type == MIMETYPES['json']:
-					resp_data = await resp.json(loads=DotDict.new_from_json)
-
-				else:
-					logging.verbose(f'Invalid Content-Type for "{url}": {resp.content_type}')
-					return logging.debug(f'Response: {resp_data}')
-
-				logging.debug(f'{uri} >> resp {resp_data}')
-
-				app.cache.json[uri] = resp_data
-				return resp_data
-
-	except JSONDecodeError:
-		logging.verbose(f'Failed to parse JSON')
-		return
-
-	except (ClientConnectorError, ServerTimeoutError):
-		logging.verbose(f'Failed to connect to {url.netloc}')
-		return
-
-	except Exception:
-		traceback.print_exc()
-
-
 async def validate_signature(actor, signature, http_request):
 	headers = {key.lower(): value for key, value in http_request.headers.items()}
 	headers['(request-target)'] = ' '.join([http_request.method.lower(), http_request.path])
@@ -559,11 +416,6 @@ class View(AiohttpView):
 		return self._request.app
 
-	@property
-	def cache(self):
-		return self.app.cache
-
-
 	@property
 	def config(self):
 		return self.app.config
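
With the digest cache removed, generate_body_digest hashes each outbound body directly; the value is a base64-encoded SHA-256 of the JSON body, which is cheap to recompute per request. Standalone, the Digest header built for signed POSTs reduces to the following (pycryptodome, mirroring the function kept above):

    import base64

    from Crypto.Hash import SHA256

    def body_digest(body):
        ## hash the raw UTF-8 body and base64-encode the binary digest
        h = SHA256.new(body.encode('utf-8'))
        return base64.b64encode(h.digest()).decode('utf-8')

    body = '{"type": "Follow"}'
    headers = {'Digest': f'SHA-256={body_digest(body)}'}
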
diff --git a/relay/processors.py b/relay/processors.py
index 0b575b4..5b76485 100644
--- a/relay/processors.py
+++ b/relay/processors.py
@@ -1,63 +1,55 @@
 import asyncio
 import logging
 
+from cachetools import LRUCache
 from uuid import uuid4
 
-from . import misc
+from .misc import Message, distill_inboxes
+
+
+cache = LRUCache(1024)
 
 
 async def handle_relay(request):
-	if request.message.objectid in request.cache.objects:
+	if request.message.objectid in cache:
 		logging.verbose(f'already relayed {request.message.objectid}')
 		return
 
-	message = misc.Message.new_announce(
+	message = Message.new_announce(
 		host = request.config.host,
 		object = request.message.objectid
 	)
 
-	request.cache.objects[request.message.objectid] = message.id
 	logging.verbose(f'Relaying post from {request.message.actorid}')
+	cache[request.message.objectid] = message.id
 	logging.debug(f'>> relay: {message}')
 
-	inboxes = misc.distill_inboxes(request.actor, request.message.objectid)
+	inboxes = distill_inboxes(request.actor, request.message.objectid)
 
-	if request.config.workers > 0:
-		for inbox in inboxes:
-			request.app.push_message(inbox, message)
-
-	else:
-		futures = [misc.request(inbox, data=message) for inbox in inboxes]
-		asyncio.ensure_future(asyncio.gather(*futures))
+	for inbox in inboxes:
+		request.app.push_message(inbox, message)
 
 
 async def handle_forward(request):
-	if request.message.id in request.cache.objects:
+	if request.message.id in cache:
 		logging.verbose(f'already forwarded {request.message.id}')
 		return
 
-	message = misc.Message.new_announce(
+	message = Message.new_announce(
 		host = request.config.host,
 		object = request.message
 	)
 
-	request.cache.objects[request.message.id] = message.id
 	logging.verbose(f'Forwarding post from {request.actor.id}')
-	logging.debug(f'>> Relay {request.message}')
+	cache[request.message.id] = message.id
+	logging.debug(f'>> forward: {message}')
 
-	inboxes = misc.distill_inboxes(request.actor, request.message.objectid)
+	inboxes = distill_inboxes(request.actor, request.message.objectid)
 
-	if request.config.workers > 0:
-		for inbox in inboxes:
-			request.app.push_message(inbox, message)
-
-	else:
-		futures = [misc.request(inbox, data=message) for inbox in inboxes]
-		asyncio.ensure_future(asyncio.gather(*futures))
+	for inbox in inboxes:
+		request.app.push_message(inbox, message)
 
 
 async def handle_follow(request):
-	nodeinfo = await misc.fetch_nodeinfo(request.actor.domain)
+	nodeinfo = await request.app.client.fetch_nodeinfo(request.actor.domain)
 	software = nodeinfo.swname if nodeinfo else None
 
 	## reject if software used by actor is banned
@@ -67,9 +59,9 @@ async def handle_follow(request):
 	request.database.add_inbox(request.actor.shared_inbox, request.message.id, software)
 	request.database.save()
 
-	await misc.request(
+	request.app.push_message(
 		request.actor.shared_inbox,
-		misc.Message.new_response(
+		Message.new_response(
 			host = request.config.host,
 			actor = request.actor.id,
 			followid = request.message.id,
@@ -80,9 +72,9 @@ async def handle_follow(request):
 	# Are Akkoma and Pleroma the only two that expect a follow back?
 	# Ignoring only Mastodon for now
 	if software != 'mastodon':
-		await misc.request(
+		request.app.push_message(
 			request.actor.shared_inbox,
-			misc.Message.new_follow(
+			Message.new_follow(
 				host = request.config.host,
 				actor = request.actor.id
 			)
@@ -99,14 +91,15 @@ async def handle_undo(request):
 		request.database.save()
 
-	message = misc.Message.new_unfollow(
-		host = request.config.host,
-		actor = request.actor.id,
-		follow = request.message
+	request.app.push_message(
+		request.actor.shared_inbox,
+		Message.new_unfollow(
+			host = request.config.host,
+			actor = request.actor.id,
+			follow = request.message
+		)
 	)
 
-	await misc.request(request.actor.shared_inbox, message)
-
 
 processors = {
 	'Announce': handle_relay,
@@ -123,7 +116,7 @@ async def run_processor(request):
 		return
 
 	if request.instance and not request.instance.get('software'):
-		nodeinfo = await misc.fetch_nodeinfo(request.instance['domain'])
+		nodeinfo = await request.app.client.fetch_nodeinfo(request.instance['domain'])
 
 		if nodeinfo:
 			request.instance['software'] = nodeinfo.swname
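
The relay/forward dedup state now lives in a module-level LRUCache with the same 1024-entry bound the old objects cache had: once full, the least recently used object IDs are evicted, and only then could a post be announced a second time. The eviction behaviour in isolation:

    from cachetools import LRUCache

    cache = LRUCache(2)  ## relay/processors.py uses 1024

    cache['https://a.example/obj/1'] = 'announce-1'
    cache['https://b.example/obj/2'] = 'announce-2'
    cache['https://c.example/obj/3'] = 'announce-3'  ## evicts the oldest untouched key

    assert 'https://a.example/obj/1' not in cache
    assert 'https://c.example/obj/3' in cache
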
diff --git a/relay/views.py b/relay/views.py
index a91f078..76cafec 100644
--- a/relay/views.py
+++ b/relay/views.py
@@ -102,7 +102,7 @@ async def inbox(request):
 		logging.verbose('Failed to parse inbox message')
 		return Response.new_error(400, 'failed to parse message', 'json')
 
-	request['actor'] = await misc.request(request.signature.keyid)
+	request['actor'] = await request.app.client.get(request.signature.keyid, sign_headers=True)
 
 	## reject if actor is empty
 	if not request.actor: