mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-23 23:17:58 +00:00
Compare commits
1 commit
06f35541e5
...
ed5d45396f
Author | SHA1 | Date | |
---|---|---|---|
ed5d45396f |
|
@ -1,17 +1,15 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import queue
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import threading
|
||||||
import sys
|
import traceback
|
||||||
import time
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aputils.signer import Signer
|
from aputils.signer import Signer
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from gunicorn.app.wsgiapp import WSGIApplication
|
|
||||||
|
|
||||||
from . import logger as logging
|
from . import logger as logging
|
||||||
from .cache import get_cache
|
from .cache import get_cache
|
||||||
|
@ -22,7 +20,6 @@ from .misc import check_open_port
|
||||||
from .views import VIEWS
|
from .views import VIEWS
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from collections.abc import Awaitable
|
|
||||||
from tinysql import Database, Row
|
from tinysql import Database, Row
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from .cache import Cache
|
from .cache import Cache
|
||||||
|
@ -34,24 +31,21 @@ if typing.TYPE_CHECKING:
|
||||||
class Application(web.Application):
|
class Application(web.Application):
|
||||||
DEFAULT: Application = None
|
DEFAULT: Application = None
|
||||||
|
|
||||||
def __init__(self, cfgpath: str, gunicorn: bool = False):
|
def __init__(self, cfgpath: str):
|
||||||
web.Application.__init__(self)
|
web.Application.__init__(self)
|
||||||
|
|
||||||
Application.DEFAULT = self
|
Application.DEFAULT = self
|
||||||
|
|
||||||
self['proc'] = None
|
|
||||||
self['signer'] = None
|
self['signer'] = None
|
||||||
self['start_time'] = None
|
|
||||||
|
|
||||||
self['config'] = Config(cfgpath, load = True)
|
self['config'] = Config(cfgpath, load = True)
|
||||||
self['database'] = get_database(self.config)
|
self['database'] = get_database(self.config)
|
||||||
self['client'] = HttpClient()
|
self['client'] = HttpClient()
|
||||||
self['cache'] = get_cache(self)
|
self['cache'] = get_cache(self)
|
||||||
|
|
||||||
if not gunicorn:
|
self['workers'] = []
|
||||||
return
|
self['last_worker'] = 0
|
||||||
|
self['start_time'] = None
|
||||||
self.on_response_prepare.append(handle_access_log)
|
self['running'] = False
|
||||||
|
|
||||||
for path, view in VIEWS:
|
for path, view in VIEWS:
|
||||||
self.router.add_view(path, view)
|
self.router.add_view(path, view)
|
||||||
|
@ -102,16 +96,17 @@ class Application(web.Application):
|
||||||
|
|
||||||
|
|
||||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
||||||
|
if self.config.workers <= 0:
|
||||||
asyncio.ensure_future(self.client.post(inbox, message, instance))
|
asyncio.ensure_future(self.client.post(inbox, message, instance))
|
||||||
|
return
|
||||||
|
|
||||||
|
worker = self['workers'][self['last_worker']]
|
||||||
|
worker.queue.put((inbox, message, instance))
|
||||||
|
|
||||||
def run(self, dev: bool = False) -> None:
|
self['last_worker'] += 1
|
||||||
self.start(dev)
|
|
||||||
|
|
||||||
while self['proc'] and self['proc'].poll() is None:
|
if self['last_worker'] >= len(self['workers']):
|
||||||
time.sleep(0.1)
|
self['last_worker'] = 0
|
||||||
|
|
||||||
self.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def set_signal_handler(self, startup: bool) -> None:
|
def set_signal_handler(self, startup: bool) -> None:
|
||||||
|
@ -124,101 +119,91 @@ class Application(web.Application):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
def start(self, dev: bool = False) -> None:
|
|
||||||
if self['proc']:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not check_open_port(self.config.listen, self.config.port):
|
if not check_open_port(self.config.listen, self.config.port):
|
||||||
logging.error('Server already running on %s:%s', self.config.listen, self.config.port)
|
logging.error('A server is already running on port %i', self.config.port)
|
||||||
return
|
return
|
||||||
|
|
||||||
cmd = [
|
for view in VIEWS:
|
||||||
sys.executable, '-m', 'gunicorn',
|
self.router.add_view(*view)
|
||||||
'relay.application:main_gunicorn',
|
|
||||||
'--bind', f'{self.config.listen}:{self.config.port}',
|
|
||||||
'--worker-class', 'aiohttp.GunicornWebWorker',
|
|
||||||
'--workers', str(self.config.workers),
|
|
||||||
'--env', f'CONFIG_FILE={self.config.path}'
|
|
||||||
]
|
|
||||||
|
|
||||||
if dev:
|
|
||||||
cmd.append('--reload')
|
|
||||||
|
|
||||||
self.set_signal_handler(True)
|
|
||||||
self['proc'] = subprocess.Popen(cmd) # pylint: disable=consider-using-with
|
|
||||||
|
|
||||||
|
|
||||||
def stop(self, *_) -> None:
|
|
||||||
if not self['proc']:
|
|
||||||
return
|
|
||||||
|
|
||||||
self['proc'].terminate()
|
|
||||||
time_wait = 0.0
|
|
||||||
|
|
||||||
while self['proc'].poll() is None:
|
|
||||||
time.sleep(0.1)
|
|
||||||
time_wait += 0.1
|
|
||||||
|
|
||||||
if time_wait >= 5.0:
|
|
||||||
self['proc'].kill()
|
|
||||||
break
|
|
||||||
|
|
||||||
self.set_signal_handler(False)
|
|
||||||
self['proc'] = None
|
|
||||||
|
|
||||||
|
|
||||||
# not used, but keeping just in case
|
|
||||||
class GunicornRunner(WSGIApplication):
|
|
||||||
def __init__(self, app: Application):
|
|
||||||
self.app = app
|
|
||||||
self.app_uri = 'relay.application:main_gunicorn'
|
|
||||||
self.options = {
|
|
||||||
'bind': f'{app.config.listen}:{app.config.port}',
|
|
||||||
'worker_class': 'aiohttp.GunicornWebWorker',
|
|
||||||
'workers': app.config.workers,
|
|
||||||
'raw_env': f'CONFIG_FILE={app.config.path}'
|
|
||||||
}
|
|
||||||
|
|
||||||
WSGIApplication.__init__(self)
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(self):
|
|
||||||
for key, value in self.options.items():
|
|
||||||
self.cfg.set(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
logging.info('Starting webserver for %s', self.app.config.domain)
|
|
||||||
WSGIApplication.run(self)
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_access_log(request: web.Request, response: web.Response) -> None:
|
|
||||||
address = request.headers.get(
|
|
||||||
'X-Forwarded-For',
|
|
||||||
request.headers.get(
|
|
||||||
'X-Real-Ip',
|
|
||||||
request.remote
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
'%s "%s %s" %i %i "%s"',
|
'Starting webserver at %s (%s:%i)',
|
||||||
address,
|
self.config.domain,
|
||||||
request.method,
|
self.config.listen,
|
||||||
request.path,
|
self.config.port
|
||||||
response.status,
|
|
||||||
len(response.body),
|
|
||||||
request.headers.get('User-Agent', 'n/a')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
asyncio.run(self.handle_run())
|
||||||
|
|
||||||
async def main_gunicorn():
|
|
||||||
|
def stop(self, *_: Any) -> None:
|
||||||
|
self['running'] = False
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_run(self) -> None:
|
||||||
|
self['running'] = True
|
||||||
|
|
||||||
|
self.set_signal_handler(True)
|
||||||
|
|
||||||
|
if self.config.workers > 0:
|
||||||
|
for _ in range(self.config.workers):
|
||||||
|
worker = PushWorker(self)
|
||||||
|
worker.start()
|
||||||
|
|
||||||
|
self['workers'].append(worker)
|
||||||
|
|
||||||
|
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
|
||||||
|
await runner.setup()
|
||||||
|
|
||||||
|
site = web.TCPSite(
|
||||||
|
runner,
|
||||||
|
host = self.config.listen,
|
||||||
|
port = self.config.port,
|
||||||
|
reuse_address = True
|
||||||
|
)
|
||||||
|
|
||||||
|
await site.start()
|
||||||
|
self['start_time'] = datetime.now()
|
||||||
|
|
||||||
|
while self['running']:
|
||||||
|
await asyncio.sleep(0.25)
|
||||||
|
|
||||||
|
await site.stop()
|
||||||
|
await self.client.close()
|
||||||
|
|
||||||
|
self['start_time'] = None
|
||||||
|
self['running'] = False
|
||||||
|
self['workers'].clear()
|
||||||
|
|
||||||
|
|
||||||
|
class PushWorker(threading.Thread):
|
||||||
|
def __init__(self, app: Application):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.app = app
|
||||||
|
self.queue = queue.Queue()
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
asyncio.run(self.handle_queue())
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_queue(self) -> None:
|
||||||
|
self.client = HttpClient()
|
||||||
|
|
||||||
|
while self.app['running']:
|
||||||
try:
|
try:
|
||||||
app = Application(os.environ['CONFIG_FILE'], gunicorn = True)
|
inbox, message, instance = self.queue.get(block=True, timeout=0.25)
|
||||||
|
self.queue.task_done()
|
||||||
|
logging.verbose('New push from Thread-%i', threading.get_ident())
|
||||||
|
await self.client.post(inbox, message, instance)
|
||||||
|
|
||||||
except KeyError:
|
except queue.Empty:
|
||||||
logging.error('Failed to set "CONFIG_FILE" environment. Trying to run without gunicorn?')
|
pass
|
||||||
raise
|
|
||||||
|
|
||||||
return app
|
## make sure an exception doesn't bring down the worker
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
await self.client.close()
|
||||||
|
|
|
@ -70,6 +70,7 @@ error: Callable = logging.error
|
||||||
critical: Callable = logging.critical
|
critical: Callable = logging.critical
|
||||||
|
|
||||||
|
|
||||||
|
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
|
||||||
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
|
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -78,15 +79,22 @@ try:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
env_log_file = None
|
env_log_file = None
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
log_level = LogLevel[env_log_level]
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
print('Invalid log level:', env_log_level)
|
||||||
|
log_level = LogLevel['INFO']
|
||||||
|
|
||||||
|
|
||||||
handlers = [logging.StreamHandler()]
|
handlers = [logging.StreamHandler()]
|
||||||
|
|
||||||
if env_log_file:
|
if env_log_file:
|
||||||
handlers.append(logging.FileHandler(env_log_file))
|
handlers.append(logging.FileHandler(env_log_file))
|
||||||
|
|
||||||
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level = LogLevel.INFO,
|
level = log_level,
|
||||||
format = '[%(asctime)s] %(levelname)s: %(message)s',
|
format = '[%(asctime)s] %(levelname)s: %(message)s',
|
||||||
datefmt = '%Y-%m-%d %H:%M:%S',
|
|
||||||
handlers = handlers
|
handlers = handlers
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ from .application import Application
|
||||||
from .compat import RelayConfig, RelayDatabase
|
from .compat import RelayConfig, RelayDatabase
|
||||||
from .database import get_database
|
from .database import get_database
|
||||||
from .database.connection import RELAY_SOFTWARE
|
from .database.connection import RELAY_SOFTWARE
|
||||||
from .misc import IS_DOCKER, Message
|
from .misc import IS_DOCKER, Message, check_open_port
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from tinysql import Row
|
from tinysql import Row
|
||||||
|
@ -70,11 +70,6 @@ def cli(ctx: click.Context, config: str) -> None:
|
||||||
cli_setup.callback()
|
cli_setup.callback()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(
|
|
||||||
'[DEPRECATED] Running the relay without the "run" command will be removed in the ' +
|
|
||||||
'future.'
|
|
||||||
)
|
|
||||||
|
|
||||||
cli_run.callback()
|
cli_run.callback()
|
||||||
|
|
||||||
|
|
||||||
|
@ -205,9 +200,8 @@ def cli_setup(ctx: click.Context) -> None:
|
||||||
|
|
||||||
|
|
||||||
@cli.command('run')
|
@cli.command('run')
|
||||||
@click.option('--dev', '-d', is_flag = True, help = 'Enable worker reloading on code change')
|
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli_run(ctx: click.Context, dev: bool = False) -> None:
|
def cli_run(ctx: click.Context) -> None:
|
||||||
'Run the relay'
|
'Run the relay'
|
||||||
|
|
||||||
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
|
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
|
||||||
|
@ -234,7 +228,11 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
|
||||||
click.echo(pip_command)
|
click.echo(pip_command)
|
||||||
return
|
return
|
||||||
|
|
||||||
ctx.obj.run(dev)
|
if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port):
|
||||||
|
click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
|
||||||
|
return
|
||||||
|
|
||||||
|
ctx.obj.run()
|
||||||
|
|
||||||
|
|
||||||
@cli.command('convert')
|
@cli.command('convert')
|
||||||
|
|
|
@ -14,9 +14,9 @@ from functools import cached_property
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from collections.abc import Awaitable, Coroutine, Generator
|
from collections.abc import Coroutine, Generator
|
||||||
from tinysql import Connection
|
from tinysql import Connection
|
||||||
from typing import Any
|
from typing import Any, Awaitable
|
||||||
from .application import Application
|
from .application import Application
|
||||||
from .cache import Cache
|
from .cache import Cache
|
||||||
from .config import Config
|
from .config import Config
|
||||||
|
@ -236,7 +236,7 @@ class Response(AiohttpResponse):
|
||||||
|
|
||||||
class View(AbstractView):
|
class View(AbstractView):
|
||||||
def __await__(self) -> Generator[Response]:
|
def __await__(self) -> Generator[Response]:
|
||||||
if self.request.method not in METHODS:
|
if (self.request.method) not in METHODS:
|
||||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
||||||
|
|
||||||
if not (handler := self.handlers.get(self.request.method)):
|
if not (handler := self.handlers.get(self.request.method)):
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
aiohttp>=3.9.1
|
aiohttp>=3.9.1
|
||||||
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
|
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
|
||||||
click>=8.1.2
|
click>=8.1.2
|
||||||
gunicorn==21.1.0
|
|
||||||
hiredis==2.3.2
|
hiredis==2.3.2
|
||||||
pyyaml>=6.0
|
pyyaml>=6.0
|
||||||
redis==5.0.1
|
redis==5.0.1
|
||||||
|
|
Loading…
Reference in a new issue