create HttpClient class to avoid creating a new session every request

This commit is contained in:
Izalia Mae 2022-11-26 18:56:34 -05:00
parent 32764a1f93
commit b85b4ab80b
8 changed files with 272 additions and 224 deletions

View file

@ -6,12 +6,12 @@ import signal
import threading
from aiohttp import web
from cachetools import LRUCache
from datetime import datetime, timedelta
from .config import RelayConfig
from .database import RelayDatabase
from .misc import DotDict, check_open_port, request, set_app
from .http_client import HttpClient
from .misc import DotDict, check_open_port, set_app
from .views import routes
@ -27,8 +27,6 @@ class Application(web.Application):
if not self['config'].load():
self['config'].save()
self['cache'] = DotDict({key: Cache(maxsize=self['config'][key]) for key in self['config'].cachekeys})
self['semaphore'] = asyncio.Semaphore(self['config'].push_limit)
self['workers'] = []
self['last_worker'] = 0
@ -37,12 +35,18 @@ class Application(web.Application):
self['database'] = RelayDatabase(self['config'])
self['database'].load()
self['client'] = HttpClient(
limit = self.config.push_limit,
timeout = self.config.timeout,
cache_size = self.config.json_cache
)
self.set_signal_handler()
@property
def cache(self):
return self['cache']
def client(self):
return self['client']
@property
@ -76,6 +80,9 @@ class Application(web.Application):
def push_message(self, inbox, message):
if self.config.workers <= 0:
return asyncio.ensure_future(self.client.post(inbox, message))
worker = self['workers'][self['last_worker']]
worker.queue.put((inbox, message))
@ -145,11 +152,6 @@ class Application(web.Application):
self['workers'].clear()
class Cache(LRUCache):
def set_maxsize(self, value):
self.__maxsize = int(value)
class PushWorker(threading.Thread):
def __init__(self, app):
threading.Thread.__init__(self)
@ -158,6 +160,12 @@ class PushWorker(threading.Thread):
def run(self):
self.client = HttpClient(
limit = self.app.config.push_limit,
timeout = self.app.config.timeout,
cache_size = self.app.config.json_cache
)
asyncio.run(self.handle_queue())
@ -166,13 +174,14 @@ class PushWorker(threading.Thread):
try:
inbox, message = self.queue.get(block=True, timeout=0.25)
self.queue.task_done()
await request(inbox, message)
logging.verbose(f'New push from Thread-{threading.get_ident()}')
await self.client.post(inbox, message)
except queue.Empty:
pass
await self.client.close()
## Can't sub-class web.Request, so let's just add some properties
def request_actor(self):
@ -203,7 +212,6 @@ setattr(web.Request, 'instance', property(request_instance))
setattr(web.Request, 'message', property(request_message))
setattr(web.Request, 'signature', property(request_signature))
setattr(web.Request, 'cache', property(lambda self: self.app.cache))
setattr(web.Request, 'config', property(lambda self: self.app.config))
setattr(web.Request, 'database', property(lambda self: self.app.database))
setattr(web.Request, 'semaphore', property(lambda self: self.app.semaphore))

View file

@ -24,12 +24,6 @@ class RelayConfig(DotDict):
'whitelist'
}
cachekeys = {
'json',
'objects',
'digests'
}
def __init__(self, path, is_docker):
DotDict.__init__(self, {})
@ -50,7 +44,7 @@ class RelayConfig(DotDict):
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json', 'objects', 'digests']:
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
@ -94,15 +88,14 @@ class RelayConfig(DotDict):
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': [],
'whitelist_enabled': False,
'json': 1024,
'objects': 1024,
'digests': 1024
'whitelist': []
})
def ban_instance(self, instance):
@ -211,7 +204,7 @@ class RelayConfig(DotDict):
return False
for key, value in config.items():
if key in ['ap', 'cache']:
if key in ['ap']:
for k, v in value.items():
if k not in self:
continue
@ -239,8 +232,9 @@ class RelayConfig(DotDict):
'note': self.note,
'push_limit': self.push_limit,
'workers': self.workers,
'ap': {key: self[key] for key in self.apkeys},
'cache': {key: self[key] for key in self.cachekeys}
'json_cache': self.json_cache,
'timeout': self.timeout,
'ap': {key: self[key] for key in self.apkeys}
}
with open(self._path, 'w') as fd:

View file

@ -6,8 +6,6 @@ import traceback
from Crypto.PublicKey import RSA
from urllib.parse import urlparse
from .misc import fetch_nodeinfo
class RelayDatabase(dict):
def __init__(self, config):

203
relay/http_client.py Normal file
View file

@ -0,0 +1,203 @@
import logging
import traceback
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
from datetime import datetime
from cachetools import LRUCache
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
from . import __version__
from .misc import (
MIMETYPES,
DotDict,
Message,
create_signature_header,
generate_body_digest
)
HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}'
}
class Cache(LRUCache):
def set_maxsize(self, value):
self.__maxsize = int(value)
class HttpClient:
def __init__(self, limit=100, timeout=10, cache_size=1024):
self.cache = Cache(cache_size)
self.cfg = {'limit': limit, 'timeout': timeout}
self._conn = None
self._session = None
@property
def limit(self):
return self.cfg['limit']
@property
def timeout(self):
return self.cfg['timeout']
def sign_headers(self, method, url, message=None):
parsed = urlparse(url)
headers = {
'(request-target)': f'{method.lower()} {parsed.path}',
'Date': datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'),
'Host': parsed.netloc
}
if message:
data = message.to_json()
headers.update({
'Digest': f'SHA-256={generate_body_digest(data)}',
'Content-Length': str(len(data.encode('utf-8')))
})
headers['Signature'] = create_signature_header(headers)
del headers['(request-target)']
del headers['Host']
return headers
async def open(self):
if self._session:
return
self._conn = TCPConnector(
limit = self.limit,
ttl_dns_cache = 300,
)
self._session = ClientSession(
connector = self._conn,
headers = HEADERS,
connector_owner = True,
timeout = ClientTimeout(total=self.timeout)
)
async def close(self):
if not self._session:
return
await self._session.close()
await self._conn.close()
self._conn = None
self._session = None
async def get(self, url, sign_headers=False, loads=None, force=False):
await self.open()
try: url, _ = url.split('#', 1)
except: pass
if not force and url in self.cache:
return self.cache[url]
headers = {}
if sign_headers:
headers.update(self.sign_headers('GET', url))
try:
logging.verbose(f'Fetching resource: {url}')
async with self._session.get(url, headers=headers) as resp:
## Not expecting a response with 202s, so just return
if resp.status == 202:
return
elif resp.status != 200:
logging.verbose(f'Received error when requesting {url}: {resp.status}')
logging.verbose(await resp.read()) # change this to debug
return
if loads:
if issubclass(loads, DotDict):
message = await resp.json(loads=loads.new_from_json)
else:
message = await resp.json(loads=loads)
elif resp.content_type == MIMETYPES['activity']:
message = await resp.json(loads=Message.new_from_json)
elif resp.content_type == MIMETYPES['json']:
message = await resp.json(loads=DotDict.new_from_json)
else:
# todo: raise TypeError or something
logging.verbose(f'Invalid Content-Type for "{url}": {resp.content_type}')
return logging.debug(f'Response: {resp.read()}')
logging.debug(f'{url} >> resp {message.to_json(4)}')
self.cache[url] = message
return message
except JSONDecodeError:
logging.verbose(f'Failed to parse JSON')
except (ClientConnectorError, ServerTimeoutError):
logging.verbose(f'Failed to connect to {urlparse(url).netloc}')
except Exception as e:
traceback.print_exc()
raise e
async def post(self, url, message):
await self.open()
headers = {'Content-Type': 'application/activity+json'}
headers.update(self.sign_headers('POST', url, message))
try:
logging.verbose(f'Sending "{message.type}" to {url}')
async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
## Not expecting a response, so just return
if resp.status in {200, 202}:
return logging.verbose(f'Successfully sent "{message.type}" to {url}')
logging.verbose(f'Received error when pushing to {url}: {resp.status}')
return logging.verbose(await resp.read()) # change this to debug
except (ClientConnectorError, ServerTimeoutError):
logging.verbose(f'Failed to connect to {url.netloc}')
## prevent workers from being brought down
except Exception as e:
traceback.print_exc()
## Additional methods ##
async def fetch_nodeinfo(domain):
nodeinfo_url = None
wk_nodeinfo = await self.get(f'https://{domain}/.well-known/nodeinfo', loads=WKNodeinfo)
for version in ['20', '21']:
try:
nodeinfo_url = wk_nodeinfo.get_url(version)
except KeyError:
pass
if not nodeinfo_url:
logging.verbose(f'Failed to fetch nodeinfo url for domain: {domain}')
return False
return await request(nodeinfo_url, loads=Nodeinfo) or False

View file

@ -145,7 +145,7 @@ def cli_inbox_follow(actor):
inbox = inbox_data['inbox']
except KeyError:
actor_data = asyncio.run(misc.request(actor))
actor_data = asyncio.run(app.client.get(actor, sign_headers=True))
if not actor_data:
return click.echo(f'Failed to fetch actor: {actor}')
@ -157,7 +157,7 @@ def cli_inbox_follow(actor):
actor = actor
)
asyncio.run(misc.request(inbox, message))
asyncio.run(app.client.post(inbox, message))
click.echo(f'Sent follow message to actor: {actor}')
@ -183,7 +183,7 @@ def cli_inbox_unfollow(actor):
)
except KeyError:
actor_data = asyncio.run(misc.request(actor))
actor_data = asyncio.run(app.client.get(actor, sign_headers=True))
inbox = actor_data.shared_inbox
message = misc.Message.new_unfollow(
host = app.config.host,
@ -195,7 +195,7 @@ def cli_inbox_unfollow(actor):
}
)
asyncio.run(misc.request(inbox, message))
asyncio.run(app.client.post(inbox, message))
click.echo(f'Sent unfollow message to: {actor}')
@ -319,7 +319,7 @@ def cli_software_ban(name, fetch_nodeinfo):
return click.echo('Banned all relay software')
if fetch_nodeinfo:
nodeinfo = asyncio.run(misc.fetch_nodeinfo(name))
nodeinfo = asyncio.run(app.client.fetch_nodeinfo(name))
if not software:
click.echo(f'Failed to fetch software name from domain: {name}')
@ -347,7 +347,7 @@ def cli_software_unban(name, fetch_nodeinfo):
return click.echo('Unbanned all relay software')
if fetch_nodeinfo:
nodeinfo = asyncio.run(misc.fetch_nodeinfo(name))
nodeinfo = asyncio.run(app.client.fetch_nodeinfo(name))
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')

View file

@ -9,8 +9,6 @@ import uuid
from Crypto.Hash import SHA, SHA256, SHA512
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
from aiohttp import ClientSession, ClientTimeout
from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Response as AiohttpResponse, View as AiohttpView
from datetime import datetime
@ -117,14 +115,8 @@ def distill_inboxes(actor, object_id):
def generate_body_digest(body):
bodyhash = app.cache.digests.get(body)
if bodyhash:
return bodyhash
h = SHA256.new(body.encode('utf-8'))
bodyhash = base64.b64encode(h.digest()).decode('utf-8')
app.cache.digests[body] = bodyhash
return bodyhash
@ -138,141 +130,6 @@ def sign_signing_string(sigstring, key):
return base64.b64encode(sigdata).decode('utf-8')
async def fetch_actor_key(actor):
actor_data = await request(actor)
if not actor_data:
return None
try:
return RSA.importKey(actor_data['publicKey']['publicKeyPem'])
except Exception as e:
logging.debug(f'Exception occured while fetching actor key: {e}')
async def fetch_nodeinfo(domain):
nodeinfo_url = None
wk_nodeinfo = await request(f'https://{domain}/.well-known/nodeinfo', sign_headers=False, activity=False)
if not wk_nodeinfo:
return
wk_nodeinfo = WKNodeinfo(wk_nodeinfo)
for version in ['20', '21']:
try:
nodeinfo_url = wk_nodeinfo.get_url(version)
except KeyError:
pass
if not nodeinfo_url:
logging.verbose(f'Failed to fetch nodeinfo url for domain: {domain}')
return False
nodeinfo = await request(nodeinfo_url, sign_headers=False, activity=False)
if not nodeinfo:
return False
return Nodeinfo(nodeinfo)
async def request(uri, data=None, force=False, sign_headers=True, activity=True, timeout=10):
## If a get request and not force, try to use the cache first
if not data and not force:
try:
return app.cache.json[uri]
except KeyError:
pass
url = urlparse(uri)
method = 'POST' if data else 'GET'
action = data.get('type') if data else None
headers = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': 'ActivityRelay',
}
if data:
headers['Content-Type'] = MIMETYPES['activity' if activity else 'json']
if sign_headers:
signing_headers = {
'(request-target)': f'{method.lower()} {url.path}',
'Date': datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'),
'Host': url.netloc
}
if data:
assert isinstance(data, dict)
data = json.dumps(data)
signing_headers.update({
'Digest': f'SHA-256={generate_body_digest(data)}',
'Content-Length': str(len(data.encode('utf-8')))
})
signing_headers['Signature'] = create_signature_header(signing_headers)
del signing_headers['(request-target)']
del signing_headers['Host']
headers.update(signing_headers)
try:
if data:
logging.verbose(f'Sending "{action}" to inbox: {uri}')
else:
logging.verbose(f'Sending GET request to url: {uri}')
timeout_cfg = ClientTimeout(connect=timeout)
async with ClientSession(trace_configs=http_debug(), timeout=timeout_cfg) as session, app.semaphore:
async with session.request(method, uri, headers=headers, data=data) as resp:
## aiohttp has been known to leak if the response hasn't been read,
## so we're just gonna read the request no matter what
resp_data = await resp.read()
## Not expecting a response, so just return
if resp.status == 202:
return
elif resp.status != 200:
if not resp_data:
return logging.verbose(f'Received error when requesting {uri}: {resp.status} {resp_data}')
return logging.verbose(f'Received error when sending {action} to {uri}: {resp.status} {resp_data}')
if resp.content_type == MIMETYPES['activity']:
resp_data = await resp.json(loads=Message.new_from_json)
elif resp.content_type == MIMETYPES['json']:
resp_data = await resp.json(loads=DotDict.new_from_json)
else:
logging.verbose(f'Invalid Content-Type for "{url}": {resp.content_type}')
return logging.debug(f'Response: {resp_data}')
logging.debug(f'{uri} >> resp {resp_data}')
app.cache.json[uri] = resp_data
return resp_data
except JSONDecodeError:
logging.verbose(f'Failed to parse JSON')
return
except (ClientConnectorError, ServerTimeoutError):
logging.verbose(f'Failed to connect to {url.netloc}')
return
except Exception:
traceback.print_exc()
async def validate_signature(actor, signature, http_request):
headers = {key.lower(): value for key, value in http_request.headers.items()}
headers['(request-target)'] = ' '.join([http_request.method.lower(), http_request.path])
@ -559,11 +416,6 @@ class View(AiohttpView):
return self._request.app
@property
def cache(self):
return self.app.cache
@property
def config(self):
return self.app.config

View file

@ -1,63 +1,55 @@
import asyncio
import logging
from cachetools import LRUCache
from uuid import uuid4
from . import misc
from .misc import Message, distill_inboxes
cache = LRUCache(1024)
async def handle_relay(request):
if request.message.objectid in request.cache.objects:
if request.message.objectid in cache:
logging.verbose(f'already relayed {request.message.objectid}')
return
message = misc.Message.new_announce(
message = Message.new_announce(
host = request.config.host,
object = request.message.objectid
)
request.cache.objects[request.message.objectid] = message.id
logging.verbose(f'Relaying post from {request.message.actorid}')
cache[request.message.objectid] = message.id
logging.debug(f'>> relay: {message}')
inboxes = misc.distill_inboxes(request.actor, request.message.objectid)
inboxes = distill_inboxes(request.actor, request.message.objectid)
if request.config.workers > 0:
for inbox in inboxes:
request.app.push_message(inbox, message)
else:
futures = [misc.request(inbox, data=message) for inbox in inboxes]
asyncio.ensure_future(asyncio.gather(*futures))
async def handle_forward(request):
if request.message.id in request.cache.objects:
if request.message.id in cache:
logging.verbose(f'already forwarded {request.message.id}')
return
message = misc.Message.new_announce(
message = Message.new_announce(
host = request.config.host,
object = request.message
)
request.cache.objects[request.message.id] = message.id
logging.verbose(f'Forwarding post from {request.actor.id}')
logging.debug(f'>> Relay {request.message}')
cache[request.message.id] = message.id
logging.debug(f'>> forward: {message}')
inboxes = misc.distill_inboxes(request.actor, request.message.objectid)
inboxes = distill_inboxes(request.actor, request.message.objectid)
if request.config.workers > 0:
for inbox in inboxes:
request.app.push_message(inbox, message)
else:
futures = [misc.request(inbox, data=message) for inbox in inboxes]
asyncio.ensure_future(asyncio.gather(*futures))
async def handle_follow(request):
nodeinfo = await misc.fetch_nodeinfo(request.actor.domain)
nodeinfo = await request.app.client.fetch_nodeinfo(request.actor.domain)
software = nodeinfo.swname if nodeinfo else None
## reject if software used by actor is banned
@ -67,9 +59,9 @@ async def handle_follow(request):
request.database.add_inbox(request.actor.shared_inbox, request.message.id, software)
request.database.save()
await misc.request(
await request.app.push_message(
request.actor.shared_inbox,
misc.Message.new_response(
Message.new_response(
host = request.config.host,
actor = request.actor.id,
followid = request.message.id,
@ -80,9 +72,9 @@ async def handle_follow(request):
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
await misc.request(
await request.app.push_message(
request.actor.shared_inbox,
misc.Message.new_follow(
Message.new_follow(
host = request.config.host,
actor = request.actor.id
)
@ -99,13 +91,14 @@ async def handle_undo(request):
request.database.save()
message = misc.Message.new_unfollow(
await request.app.push_message(
request.actor.shared_inbox,
Message.new_unfollow(
host = request.config.host,
actor = request.actor.id,
follow = request.message
)
await misc.request(request.actor.shared_inbox, message)
)
processors = {
@ -123,7 +116,7 @@ async def run_processor(request):
return
if request.instance and not request.instance.get('software'):
nodeinfo = await misc.fetch_nodeinfo(request.instance['domain'])
nodeinfo = await request.app.client.fetch_nodeinfo(request.instance['domain'])
if nodeinfo:
request.instance['software'] = nodeinfo.swname

View file

@ -102,7 +102,7 @@ async def inbox(request):
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
request['actor'] = await misc.request(request.signature.keyid)
request['actor'] = await request.app.client.get(request.signature.keyid, sign_headers=True)
## reject if actor is empty
if not request.actor: