From 9a3e3768e75ac8e756bb7157a4429982bcf18087 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 19 Jun 2024 12:59:08 -0400 Subject: [PATCH] modify workers * move all worker-related classes and functions to workers.py * change the log level in worker processes * create QueueItem and PostItem classes --- relay/application.py | 17 ++-- relay/config.py | 2 +- relay/database/connection.py | 1 + relay/workers.py | 150 +++++++++++++++++++++++++++++++++++ 4 files changed, 157 insertions(+), 13 deletions(-) create mode 100644 relay/workers.py diff --git a/relay/application.py b/relay/application.py index 6c8c1e7..a3d9925 100644 --- a/relay/application.py +++ b/relay/application.py @@ -22,7 +22,7 @@ from threading import Event, Thread from typing import Any from urllib.parse import urlparse -from . import logger as logging +from . import logger as logging, workers from .cache import Cache, get_cache from .config import Config from .database import Connection, get_database @@ -78,7 +78,7 @@ class Application(web.Application): self['cache'].setup() self['template'] = Template(self) self['push_queue'] = multiprocessing.Queue() - self['workers'] = [] + self['workers'] = workers.PushWorkers(self.config.workers) self.cache.setup() self.on_cleanup.append(handle_cleanup) # type: ignore @@ -143,7 +143,7 @@ class Application(web.Application): def push_message(self, inbox: str, message: Message, instance: Row) -> None: - self['push_queue'].put((inbox, message, instance)) + self['workers'].push_message(inbox, message, instance) def register_static_routes(self) -> None: @@ -198,12 +198,7 @@ class Application(web.Application): self['cache'].setup() self['cleanup_thread'] = CacheCleanupThread(self) self['cleanup_thread'].start() - - for _ in range(self.config.workers): - worker = PushWorker(self['push_queue']) - worker.start() - - self['workers'].append(worker) + self['workers'].start() runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() @@ -223,15 +218,13 @@ class Application(web.Application): await site.stop() - for worker in self['workers']: - worker.stop() + self['workers'].stop() self.set_signal_handler(False) self['starttime'] = None self['running'] = False self['cleanup_thread'].stop() - self['workers'].clear() self['database'].disconnect() self['cache'].close() diff --git a/relay/config.py b/relay/config.py index ac2bbb6..7e86ef7 100644 --- a/relay/config.py +++ b/relay/config.py @@ -61,7 +61,7 @@ class Config: def __init__(self, path: Path | None = None, load: bool = False): - self.path = Config.get_config_dir(path) + self.path: Path = Config.get_config_dir(path) self.reset() if load: diff --git a/relay/database/connection.py b/relay/database/connection.py index 614f307..864ad27 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -75,6 +75,7 @@ class Connection(SqlConnection): elif key == 'log-level': value = logging.LogLevel.parse(value) logging.set_level(value) + self.app['workers'].set_log_level(value) elif key in {'approval-required', 'whitelist-enabled'}: value = boolean(value) diff --git a/relay/workers.py b/relay/workers.py new file mode 100644 index 0000000..8d88ad7 --- /dev/null +++ b/relay/workers.py @@ -0,0 +1,150 @@ +import asyncio +import traceback +import typing + +from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError +from asyncio.exceptions import TimeoutError as AsyncTimeoutError +from bsql import Row +from dataclasses import dataclass +from multiprocessing import Event, Process, Queue, Value +from multiprocessing.synchronize import Event as EventType +from pathlib import Path +from queue import Empty, Queue as QueueType +from urllib.parse import urlparse + +from . import application, logger as logging +from .http_client import HttpClient +from .misc import IS_WINDOWS, Message, get_app + +if typing.TYPE_CHECKING: + from .multiprocessing.synchronize import Syncronized + + +@dataclass +class QueueItem: + pass + + +@dataclass +class PostItem(QueueItem): + inbox: str + message: Message + instance: Row | None + + @property + def domain(self) -> str: + return urlparse(self.inbox).netloc + + +class PushWorker(Process): + client: HttpClient + + + def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None: + Process.__init__(self) + + self.queue: QueueType[QueueItem] = queue + self.shutdown: EventType = Event() + self.path: Path = get_app().config.path + self.log_level: "Syncronized[str]" = log_level + self._log_level_changed: EventType = Event() + + + def stop(self) -> None: + self.shutdown.set() + + + def run(self) -> None: + asyncio.run(self.handle_queue()) + + + async def handle_queue(self) -> None: + if IS_WINDOWS: + app = application.Application(self.path) + self.client = app.client + + self.client.open() + app.database.connect() + app.cache.setup() + + else: + self.client = HttpClient() + self.client.open() + + logging.verbose("[%i] Starting worker", self.pid) + + while not self.shutdown.is_set(): + try: + if self._log_level_changed.is_set(): + logging.set_level(logging.LogLevel.parse(self.log_level.value)) + self._log_level_changed.clear() + + item = self.queue.get(block=True, timeout=0.1) + + if isinstance(item, PostItem): + asyncio.create_task(self.handle_post(item)) + + except Empty: + await asyncio.sleep(0) + + except Exception: + traceback.print_exc() + + if IS_WINDOWS: + app.database.disconnect() + app.cache.close() + + await self.client.close() + + + async def handle_post(self, item: PostItem) -> None: + try: + await self.client.post(item.inbox, item.message, item.instance) + + except AsyncTimeoutError: + logging.error('Timeout when pushing to %s', item.domain) + + except ClientConnectionError as e: + logging.error('Failed to connect to %s for message push: %s', item.domain, str(e)) + + except ClientSSLError as e: + logging.error('SSL error when pushing to %s: %s', item.domain, str(e)) + + +class PushWorkers(list[PushWorker]): + def __init__(self, count: int) -> None: + self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment] + self._log_level: "Syncronized[str]" = Value("i", logging.get_level()) + self._count: int = count + + + def push_item(self, item: QueueItem) -> None: + self.queue.put(item) + + + def push_message(self, inbox: str, message: Message, instance: Row) -> None: + self.queue.put(PostItem(inbox, message, instance)) + + + def set_log_level(self, value: logging.LogLevel) -> None: + self._log_level.value = value + + for worker in self: + worker._log_level_changed.set() + + + def start(self) -> None: + if len(self) > 0: + return + + for _ in range(self._count): + worker = PushWorker(self.queue, self._log_level) + worker.start() + self.append(worker) + + + def stop(self) -> None: + for worker in self: + worker.stop() + + self.clear()