modify workers

* move all worker-related classes and functions to workers.py
* change the log level in worker processes
* create QueueItem and PostItem classes
This commit is contained in:
Izalia Mae 2024-06-19 12:59:08 -04:00
parent c508257981
commit 9a3e3768e7
4 changed files with 157 additions and 13 deletions

View file

@ -22,7 +22,7 @@ from threading import Event, Thread
from typing import Any
from urllib.parse import urlparse
from . import logger as logging
from . import logger as logging, workers
from .cache import Cache, get_cache
from .config import Config
from .database import Connection, get_database
@ -78,7 +78,7 @@ class Application(web.Application):
self['cache'].setup()
self['template'] = Template(self)
self['push_queue'] = multiprocessing.Queue()
self['workers'] = []
self['workers'] = workers.PushWorkers(self.config.workers)
self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore
@ -143,7 +143,7 @@ class Application(web.Application):
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
self['push_queue'].put((inbox, message, instance))
self['workers'].push_message(inbox, message, instance)
def register_static_routes(self) -> None:
@ -198,12 +198,7 @@ class Application(web.Application):
self['cache'].setup()
self['cleanup_thread'] = CacheCleanupThread(self)
self['cleanup_thread'].start()
for _ in range(self.config.workers):
worker = PushWorker(self['push_queue'])
worker.start()
self['workers'].append(worker)
self['workers'].start()
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup()
@ -223,15 +218,13 @@ class Application(web.Application):
await site.stop()
for worker in self['workers']:
worker.stop()
self['workers'].stop()
self.set_signal_handler(False)
self['starttime'] = None
self['running'] = False
self['cleanup_thread'].stop()
self['workers'].clear()
self['database'].disconnect()
self['cache'].close()

View file

@ -61,7 +61,7 @@ class Config:
def __init__(self, path: Path | None = None, load: bool = False):
self.path = Config.get_config_dir(path)
self.path: Path = Config.get_config_dir(path)
self.reset()
if load:

View file

@ -75,6 +75,7 @@ class Connection(SqlConnection):
elif key == 'log-level':
value = logging.LogLevel.parse(value)
logging.set_level(value)
self.app['workers'].set_log_level(value)
elif key in {'approval-required', 'whitelist-enabled'}:
value = boolean(value)

150
relay/workers.py Normal file
View file

@ -0,0 +1,150 @@
import asyncio
import traceback
import typing
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from bsql import Row
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value
from multiprocessing.synchronize import Event as EventType
from pathlib import Path
from queue import Empty, Queue as QueueType
from urllib.parse import urlparse
from . import application, logger as logging
from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, get_app
if typing.TYPE_CHECKING:
from .multiprocessing.synchronize import Syncronized
@dataclass
class QueueItem:
pass
@dataclass
class PostItem(QueueItem):
inbox: str
message: Message
instance: Row | None
@property
def domain(self) -> str:
return urlparse(self.inbox).netloc
class PushWorker(Process):
client: HttpClient
def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None:
Process.__init__(self)
self.queue: QueueType[QueueItem] = queue
self.shutdown: EventType = Event()
self.path: Path = get_app().config.path
self.log_level: "Syncronized[str]" = log_level
self._log_level_changed: EventType = Event()
def stop(self) -> None:
self.shutdown.set()
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
if IS_WINDOWS:
app = application.Application(self.path)
self.client = app.client
self.client.open()
app.database.connect()
app.cache.setup()
else:
self.client = HttpClient()
self.client.open()
logging.verbose("[%i] Starting worker", self.pid)
while not self.shutdown.is_set():
try:
if self._log_level_changed.is_set():
logging.set_level(logging.LogLevel.parse(self.log_level.value))
self._log_level_changed.clear()
item = self.queue.get(block=True, timeout=0.1)
if isinstance(item, PostItem):
asyncio.create_task(self.handle_post(item))
except Empty:
await asyncio.sleep(0)
except Exception:
traceback.print_exc()
if IS_WINDOWS:
app.database.disconnect()
app.cache.close()
await self.client.close()
async def handle_post(self, item: PostItem) -> None:
try:
await self.client.post(item.inbox, item.message, item.instance)
except AsyncTimeoutError:
logging.error('Timeout when pushing to %s', item.domain)
except ClientConnectionError as e:
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
except ClientSSLError as e:
logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
class PushWorkers(list[PushWorker]):
def __init__(self, count: int) -> None:
self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment]
self._log_level: "Syncronized[str]" = Value("i", logging.get_level())
self._count: int = count
def push_item(self, item: QueueItem) -> None:
self.queue.put(item)
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
self.queue.put(PostItem(inbox, message, instance))
def set_log_level(self, value: logging.LogLevel) -> None:
self._log_level.value = value
for worker in self:
worker._log_level_changed.set()
def start(self) -> None:
if len(self) > 0:
return
for _ in range(self._count):
worker = PushWorker(self.queue, self._log_level)
worker.start()
self.append(worker)
def stop(self) -> None:
for worker in self:
worker.stop()
self.clear()