mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-08 17:48:00 +00:00
modify workers
* move all worker-related classes and functions to workers.py * change the log level in worker processes * create QueueItem and PostItem classes
This commit is contained in:
parent
c508257981
commit
9a3e3768e7
|
@ -22,7 +22,7 @@ from threading import Event, Thread
|
|||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import logger as logging
|
||||
from . import logger as logging, workers
|
||||
from .cache import Cache, get_cache
|
||||
from .config import Config
|
||||
from .database import Connection, get_database
|
||||
|
@ -78,7 +78,7 @@ class Application(web.Application):
|
|||
self['cache'].setup()
|
||||
self['template'] = Template(self)
|
||||
self['push_queue'] = multiprocessing.Queue()
|
||||
self['workers'] = []
|
||||
self['workers'] = workers.PushWorkers(self.config.workers)
|
||||
|
||||
self.cache.setup()
|
||||
self.on_cleanup.append(handle_cleanup) # type: ignore
|
||||
|
@ -143,7 +143,7 @@ class Application(web.Application):
|
|||
|
||||
|
||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
||||
self['push_queue'].put((inbox, message, instance))
|
||||
self['workers'].push_message(inbox, message, instance)
|
||||
|
||||
|
||||
def register_static_routes(self) -> None:
|
||||
|
@ -198,12 +198,7 @@ class Application(web.Application):
|
|||
self['cache'].setup()
|
||||
self['cleanup_thread'] = CacheCleanupThread(self)
|
||||
self['cleanup_thread'].start()
|
||||
|
||||
for _ in range(self.config.workers):
|
||||
worker = PushWorker(self['push_queue'])
|
||||
worker.start()
|
||||
|
||||
self['workers'].append(worker)
|
||||
self['workers'].start()
|
||||
|
||||
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
|
||||
await runner.setup()
|
||||
|
@ -223,15 +218,13 @@ class Application(web.Application):
|
|||
|
||||
await site.stop()
|
||||
|
||||
for worker in self['workers']:
|
||||
worker.stop()
|
||||
self['workers'].stop()
|
||||
|
||||
self.set_signal_handler(False)
|
||||
|
||||
self['starttime'] = None
|
||||
self['running'] = False
|
||||
self['cleanup_thread'].stop()
|
||||
self['workers'].clear()
|
||||
self['database'].disconnect()
|
||||
self['cache'].close()
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ class Config:
|
|||
|
||||
|
||||
def __init__(self, path: Path | None = None, load: bool = False):
|
||||
self.path = Config.get_config_dir(path)
|
||||
self.path: Path = Config.get_config_dir(path)
|
||||
self.reset()
|
||||
|
||||
if load:
|
||||
|
|
|
@ -75,6 +75,7 @@ class Connection(SqlConnection):
|
|||
elif key == 'log-level':
|
||||
value = logging.LogLevel.parse(value)
|
||||
logging.set_level(value)
|
||||
self.app['workers'].set_log_level(value)
|
||||
|
||||
elif key in {'approval-required', 'whitelist-enabled'}:
|
||||
value = boolean(value)
|
||||
|
|
150
relay/workers.py
Normal file
150
relay/workers.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
import asyncio
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||
from bsql import Row
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Event, Process, Queue, Value
|
||||
from multiprocessing.synchronize import Event as EventType
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue as QueueType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import application, logger as logging
|
||||
from .http_client import HttpClient
|
||||
from .misc import IS_WINDOWS, Message, get_app
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .multiprocessing.synchronize import Syncronized
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueItem:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostItem(QueueItem):
|
||||
inbox: str
|
||||
message: Message
|
||||
instance: Row | None
|
||||
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
return urlparse(self.inbox).netloc
|
||||
|
||||
|
||||
class PushWorker(Process):
|
||||
client: HttpClient
|
||||
|
||||
|
||||
def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None:
|
||||
Process.__init__(self)
|
||||
|
||||
self.queue: QueueType[QueueItem] = queue
|
||||
self.shutdown: EventType = Event()
|
||||
self.path: Path = get_app().config.path
|
||||
self.log_level: "Syncronized[str]" = log_level
|
||||
self._log_level_changed: EventType = Event()
|
||||
|
||||
|
||||
def stop(self) -> None:
|
||||
self.shutdown.set()
|
||||
|
||||
|
||||
def run(self) -> None:
|
||||
asyncio.run(self.handle_queue())
|
||||
|
||||
|
||||
async def handle_queue(self) -> None:
|
||||
if IS_WINDOWS:
|
||||
app = application.Application(self.path)
|
||||
self.client = app.client
|
||||
|
||||
self.client.open()
|
||||
app.database.connect()
|
||||
app.cache.setup()
|
||||
|
||||
else:
|
||||
self.client = HttpClient()
|
||||
self.client.open()
|
||||
|
||||
logging.verbose("[%i] Starting worker", self.pid)
|
||||
|
||||
while not self.shutdown.is_set():
|
||||
try:
|
||||
if self._log_level_changed.is_set():
|
||||
logging.set_level(logging.LogLevel.parse(self.log_level.value))
|
||||
self._log_level_changed.clear()
|
||||
|
||||
item = self.queue.get(block=True, timeout=0.1)
|
||||
|
||||
if isinstance(item, PostItem):
|
||||
asyncio.create_task(self.handle_post(item))
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if IS_WINDOWS:
|
||||
app.database.disconnect()
|
||||
app.cache.close()
|
||||
|
||||
await self.client.close()
|
||||
|
||||
|
||||
async def handle_post(self, item: PostItem) -> None:
|
||||
try:
|
||||
await self.client.post(item.inbox, item.message, item.instance)
|
||||
|
||||
except AsyncTimeoutError:
|
||||
logging.error('Timeout when pushing to %s', item.domain)
|
||||
|
||||
except ClientConnectionError as e:
|
||||
logging.error('Failed to connect to %s for message push: %s', item.domain, str(e))
|
||||
|
||||
except ClientSSLError as e:
|
||||
logging.error('SSL error when pushing to %s: %s', item.domain, str(e))
|
||||
|
||||
|
||||
class PushWorkers(list[PushWorker]):
|
||||
def __init__(self, count: int) -> None:
|
||||
self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment]
|
||||
self._log_level: "Syncronized[str]" = Value("i", logging.get_level())
|
||||
self._count: int = count
|
||||
|
||||
|
||||
def push_item(self, item: QueueItem) -> None:
|
||||
self.queue.put(item)
|
||||
|
||||
|
||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
||||
self.queue.put(PostItem(inbox, message, instance))
|
||||
|
||||
|
||||
def set_log_level(self, value: logging.LogLevel) -> None:
|
||||
self._log_level.value = value
|
||||
|
||||
for worker in self:
|
||||
worker._log_level_changed.set()
|
||||
|
||||
|
||||
def start(self) -> None:
|
||||
if len(self) > 0:
|
||||
return
|
||||
|
||||
for _ in range(self._count):
|
||||
worker = PushWorker(self.queue, self._log_level)
|
||||
worker.start()
|
||||
self.append(worker)
|
||||
|
||||
|
||||
def stop(self) -> None:
|
||||
for worker in self:
|
||||
worker.stop()
|
||||
|
||||
self.clear()
|
Loading…
Reference in a new issue