diff --git a/relay/application.py b/relay/application.py index 52e16c9..a216584 100644 --- a/relay/application.py +++ b/relay/application.py @@ -1,7 +1,9 @@ import asyncio import logging import os +import queue import signal +import threading from aiohttp import web from cachetools import LRUCache @@ -9,7 +11,7 @@ from datetime import datetime, timedelta from .config import RelayConfig from .database import RelayDatabase -from .misc import DotDict, check_open_port, fetch_nodeinfo, set_app +from .misc import DotDict, check_open_port, request, set_app from .views import routes @@ -27,6 +29,8 @@ class Application(web.Application): self['cache'] = DotDict({key: Cache(maxsize=self['config'][key]) for key in self['config'].cachekeys}) self['semaphore'] = asyncio.Semaphore(self['config'].push_limit) + self['workers'] = [] + self['last_worker'] = 0 set_app(self) @@ -71,6 +75,16 @@ class Application(web.Application): return timedelta(seconds=uptime.seconds) + def push_message(self, inbox, message): + worker = self['workers'][self['last_worker']] + worker.queue.put((inbox, message)) + + self['last_worker'] += 1 + + if self['last_worker'] >= len(self['workers']): + self['last_worker'] = 0 + + def set_signal_handler(self): for sig in {'SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'}: try: @@ -102,6 +116,13 @@ class Application(web.Application): async def handle_run(self): self['running'] = True + if self.config.workers > 0: + for i in range(self.config.workers): + worker = PushWorker(self) + worker.start() + + self['workers'].append(worker) + runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() @@ -121,6 +142,7 @@ class Application(web.Application): self['starttime'] = None self['running'] = False + self['workers'].clear() class Cache(LRUCache): @@ -128,6 +150,30 @@ class Cache(LRUCache): self.__maxsize = int(value) +class PushWorker(threading.Thread): + def __init__(self, app): + threading.Thread.__init__(self) + self.app = app + self.queue = queue.Queue() + + + def run(self): + asyncio.run(self.handle_queue()) + + + async def handle_queue(self): + while self.app['running']: + try: + inbox, message = self.queue.get(block=True, timeout=0.25) + self.queue.task_done() + await request(inbox, message) + + logging.verbose(f'New push from Thread-{threading.get_ident()}') + + except queue.Empty: + pass + + ## Can't sub-class web.Request, so let's just add some properties def request_actor(self): try: return self['actor'] diff --git a/relay/config.py b/relay/config.py index fd22f22..998e5d6 100644 --- a/relay/config.py +++ b/relay/config.py @@ -50,7 +50,7 @@ class RelayConfig(DotDict): if key in ['blocked_instances', 'blocked_software', 'whitelist']: assert isinstance(value, (list, set, tuple)) - elif key in ['port', 'json', 'objects', 'digests']: + elif key in ['port', 'workers', 'json', 'objects', 'digests']: assert isinstance(value, (int)) elif key == 'whitelist_enabled': @@ -92,6 +92,7 @@ class RelayConfig(DotDict): 'port': 8080, 'note': 'Make a note about your instance here.', 'push_limit': 512, + 'workers': 0, 'host': 'relay.example.com', 'blocked_software': [], 'blocked_instances': [], @@ -233,6 +234,7 @@ class RelayConfig(DotDict): 'port': self.port, 'note': self.note, 'push_limit': self.push_limit, + 'workers': self.workers, 'ap': {key: self[key] for key in self.apkeys}, 'cache': {key: self[key] for key in self.cachekeys} } diff --git a/relay/processors.py b/relay/processors.py index 92df00d..46d23e4 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -11,20 +11,24 @@ async def handle_relay(request): logging.verbose(f'already relayed {request.message.objectid}') return - logging.verbose(f'Relaying post from {request.message.actorid}') - message = misc.Message.new_announce( host = request.config.host, object = request.message.objectid ) + request.cache.objects[request.message.objectid] = message.id + logging.verbose(f'Relaying post from {request.message.actorid}') logging.debug(f'>> relay: {message}') inboxes = misc.distill_inboxes(request.actor, request.message.objectid) - futures = [misc.request(inbox, data=message) for inbox in inboxes] - asyncio.ensure_future(asyncio.gather(*futures)) - request.cache.objects[request.message.objectid] = message.id + if request.config.workers > 0: + for inbox in inboxes: + request.app.push_message(inbox, message) + + else: + futures = [misc.request(inbox, data=message) for inbox in inboxes] + asyncio.ensure_future(asyncio.gather(*futures)) async def handle_forward(request): @@ -37,14 +41,19 @@ async def handle_forward(request): object = request.message ) + request.cache.objects[request.message.id] = message.id logging.verbose(f'Forwarding post from {request.actor.id}') logging.debug(f'>> Relay {request.message}') - inboxes = misc.distill_inboxes(request.actor, request.message.id) - futures = [misc.request(inbox, data=message) for inbox in inboxes] + inboxes = misc.distill_inboxes(request.actor, request.message.objectid) - asyncio.ensure_future(asyncio.gather(*futures)) - request.cache.objects[request.message.id] = message.id + if request.config.workers > 0: + for inbox in inboxes: + request.app.push_message(inbox, message) + + else: + futures = [misc.request(inbox, data=message) for inbox in inboxes] + asyncio.ensure_future(asyncio.gather(*futures)) async def handle_follow(request): diff --git a/relay/views.py b/relay/views.py index c360ca2..9727207 100644 --- a/relay/views.py +++ b/relay/views.py @@ -1,3 +1,4 @@ +import asyncio import logging import subprocess import traceback @@ -137,7 +138,7 @@ async def inbox(request): logging.debug(f">> payload {request.message.to_json(4)}") - await run_processor(request) + asyncio.ensure_future(run_processor(request)) return Response.new(status=202)