diff --git a/relay/__init__.py b/relay/__init__.py
index 9489f87..5c940ed 100644
--- a/relay/__init__.py
+++ b/relay/__init__.py
@@ -1,8 +1,3 @@
__version__ = '0.2.2'
-from aiohttp.web import Application
-
from . import logger
-
-
-app = Application()
diff --git a/relay/application.py b/relay/application.py
new file mode 100644
index 0000000..9acddbf
--- /dev/null
+++ b/relay/application.py
@@ -0,0 +1,119 @@
+import asyncio
+import logging
+import os
+import signal
+
+from aiohttp import web
+from cachetools import LRUCache
+from datetime import datetime, timedelta
+
+from .config import RelayConfig
+from .database import RelayDatabase
+from .misc import DotDict, check_open_port, set_app
+from .views import routes
+
+
+class Application(web.Application):
+ def __init__(self, cfgpath):
+ web.Application.__init__(self)
+
+ self['starttime'] = None
+ self['running'] = False
+ self['is_docker'] = bool(os.environ.get('DOCKER_RUNNING'))
+ self['config'] = RelayConfig(cfgpath, self['is_docker'])
+
+ if not self['config'].load():
+ self['config'].save()
+
+ self['database'] = RelayDatabase(self['config'])
+ self['database'].load()
+
+ self['cache'] = DotDict({key: LRUCache(maxsize=self['config'][key]) for key in self['config'].cachekeys})
+ self['semaphore'] = asyncio.Semaphore(self['config'].push_limit)
+
+ self.set_signal_handler()
+ set_app(self)
+
+
+ @property
+ def cache(self):
+ return self['cache']
+
+
+ @property
+ def config(self):
+ return self['config']
+
+
+ @property
+ def database(self):
+ return self['database']
+
+
+ @property
+ def is_docker(self):
+ return self['is_docker']
+
+
+ @property
+ def semaphore(self):
+ return self['semaphore']
+
+
+ @property
+ def uptime(self):
+ if not self['starttime']:
+ return timedelta(seconds=0)
+
+ uptime = datetime.now() - self['starttime']
+
+ return timedelta(seconds=uptime.seconds)
+
+
+ def set_signal_handler(self):
+ signal.signal(signal.SIGHUP, self.stop)
+ signal.signal(signal.SIGINT, self.stop)
+ signal.signal(signal.SIGQUIT, self.stop)
+ signal.signal(signal.SIGTERM, self.stop)
+
+
+ def run(self):
+ if not check_open_port(self.config.listen, self.config.port):
+ return logging.error(f'A server is already running on port {self.config.port}')
+
+ for route in routes:
+ if route[1] == '/stats' and logging.DEBUG < logging.root.level:
+ continue
+
+ self.router.add_route(*route)
+
+ logging.info(f'Starting webserver at {self.config.host} ({self.config.listen}:{self.config.port})')
+ asyncio.run(self.handle_run())
+
+
+ def stop(self, *_):
+ self['running'] = False
+
+
+ async def handle_run(self):
+ self['running'] = True
+
+ runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
+ await runner.setup()
+
+ site = web.TCPSite(runner,
+ host = self.config.listen,
+ port = self.config.port,
+ reuse_address = True
+ )
+
+ await site.start()
+ self['starttime'] = datetime.now()
+
+ while self['running']:
+ await asyncio.sleep(0.25)
+
+ await site.stop()
+
+ self['starttime'] = None
+ self['running'] = False
diff --git a/relay/manage.py b/relay/manage.py
index 5dcad3a..838e2f9 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -1,17 +1,15 @@
import Crypto
import asyncio
import click
-import json
import logging
-import os
import platform
-from aiohttp.web import AppRunner, TCPSite
-from cachetools import LRUCache
+from . import misc, __version__
+from .application import Application
+from .config import relay_software_names
-from . import app, misc, views, __version__
-from .config import DotDict, RelayConfig, relay_software_names
-from .database import RelayDatabase
+
+app = None
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@@ -19,23 +17,11 @@ from .database import RelayDatabase
@click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context
def cli(ctx, config):
- app['is_docker'] = bool(os.environ.get('DOCKER_RUNNING'))
- app['config'] = RelayConfig(config, app['is_docker'])
-
- if not app['config'].load():
- app['config'].save()
-
- app['database'] = RelayDatabase(app['config'])
- app['database'].load()
-
- app['cache'] = DotDict()
- app['semaphore'] = asyncio.Semaphore(app['config']['push_limit'])
-
- for key in app['config'].cachekeys:
- app['cache'][key] = LRUCache(app['config'][key])
+ global app
+ app = Application(config)
if not ctx.invoked_subcommand:
- if app['config'].host.endswith('example.com'):
+ if app.config.host.endswith('example.com'):
relay_setup.callback()
else:
@@ -55,7 +41,7 @@ def cli_inbox_list():
click.echo('Connected to the following instances or relays:')
- for inbox in app['database'].inboxes:
+ for inbox in app.database.inboxes:
click.echo(f'- {inbox}')
@@ -64,16 +50,13 @@ def cli_inbox_list():
def cli_inbox_follow(actor):
'Follow an actor (Relay must be running)'
- config = app['config']
- database = app['database']
-
- if config.is_banned(actor):
+ if app.config.is_banned(actor):
return click.echo(f'Error: Refusing to follow banned actor: {actor}')
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
- if database.get_inbox(actor):
+ if app.database.get_inbox(actor):
return click.echo(f'Error: Already following actor: {actor}')
actor_data = asyncio.run(misc.request(actor, sign_headers=True))
@@ -81,8 +64,8 @@ def cli_inbox_follow(actor):
if not actor_data:
return click.echo(f'Error: Failed to fetch actor: {actor}')
- database.add_inbox(actor_data.shared_inbox)
- database.save()
+ app.database.add_inbox(actor_data.shared_inbox)
+ app.database.save()
asyncio.run(misc.follow_remote_actor(actor))
click.echo(f'Sent follow message to actor: {actor}')
@@ -93,13 +76,11 @@ def cli_inbox_follow(actor):
def cli_inbox_unfollow(actor):
'Unfollow an actor (Relay must be running)'
- database = app['database']
-
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
- if database.del_inbox(actor):
- database.save()
+ if app.database.del_inbox(actor):
+ app.database.save()
asyncio.run(misc.unfollow_remote_actor(actor))
return click.echo(f'Sent unfollow message to: {actor}')
@@ -111,17 +92,14 @@ def cli_inbox_unfollow(actor):
def cli_inbox_add(inbox):
'Add an inbox to the database'
- database = app['database']
- config = app['config']
-
if not inbox.startswith('http'):
inbox = f'https://{inbox}/inbox'
- if config.is_banned(inbox):
+ if app.config.is_banned(inbox):
return click.echo(f'Error: Refusing to add banned inbox: {inbox}')
- if database.add_inbox(inbox):
- database.save()
+ if app.database.add_inbox(inbox):
+ app.database.save()
return click.echo(f'Added inbox to the database: {inbox}')
click.echo(f'Error: Inbox already in database: {inbox}')
@@ -132,17 +110,15 @@ def cli_inbox_add(inbox):
def cli_inbox_remove(inbox):
'Remove an inbox from the database'
- database = app['database']
-
try:
- dbinbox = database.get_inbox(inbox, fail=True)
+ dbinbox = app.database.get_inbox(inbox, fail=True)
except KeyError:
click.echo(f'Error: Inbox does not exist: {inbox}')
return
- database.del_inbox(dbinbox['domain'])
- database.save()
+ app.database.del_inbox(dbinbox['domain'])
+ app.database.save()
click.echo(f'Removed inbox from the database: {inbox}')
@@ -159,7 +135,7 @@ def cli_instance_list():
click.echo('Banned instances or relays:')
- for domain in app['config'].blocked_instances:
+ for domain in app.config.blocked_instances:
click.echo(f'- {domain}')
@@ -168,17 +144,14 @@ def cli_instance_list():
def cli_instance_ban(target):
'Ban an instance and remove the associated inbox if it exists'
- config = app['config']
- database = app['database']
-
if target.startswith('http'):
target = urlparse(target).hostname
- if config.ban_instance(target):
- config.save()
+ if app.config.ban_instance(target):
+ app.config.save()
- if database.del_inbox(target):
- database.save()
+ if app.database.del_inbox(target):
+ app.database.save()
click.echo(f'Banned instance: {target}')
return
@@ -191,10 +164,8 @@ def cli_instance_ban(target):
def cli_instance_unban(target):
'Unban an instance'
- config = app['config']
-
- if config.unban_instance(target):
- config.save()
+ if app.config.unban_instance(target):
+ app.config.save()
click.echo(f'Unbanned instance: {target}')
return
@@ -214,7 +185,7 @@ def cli_software_list():
click.echo('Banned software:')
- for software in app['config'].blocked_software:
+ for software in app.config.blocked_software:
click.echo(f'- {software}')
@@ -226,13 +197,11 @@ def cli_software_list():
def cli_software_ban(name, fetch_nodeinfo):
'Ban software. Use RELAYS for NAME to ban relays'
- config = app['config']
-
if name == 'RELAYS':
for name in relay_software_names:
- config.ban_software(name)
+ app.config.ban_software(name)
- config.save()
+ app.config.save()
return click.echo('Banned all relay software')
if fetch_nodeinfo:
@@ -244,7 +213,7 @@ def cli_software_ban(name, fetch_nodeinfo):
name = software
if config.ban_software(name):
- config.save()
+ app.config.save()
return click.echo(f'Banned software: {name}')
click.echo(f'Software already banned: {name}')
@@ -258,11 +227,9 @@ def cli_software_ban(name, fetch_nodeinfo):
def cli_software_unban(name, fetch_nodeinfo):
'Ban software. Use RELAYS for NAME to unban relays'
- config = app['config']
-
if name == 'RELAYS':
for name in relay_software_names:
- config.unban_software(name)
+ app.config.unban_software(name)
config.save()
return click.echo('Unbanned all relay software')
@@ -275,8 +242,8 @@ def cli_software_unban(name, fetch_nodeinfo):
name = software
- if config.unban_software(name):
- config.save()
+ if app.config.unban_software(name):
+ app.config.save()
return click.echo(f'Unbanned software: {name}')
click.echo(f'Software wasn\'t banned: {name}')
@@ -293,7 +260,7 @@ def cli_whitelist():
def cli_whitelist_list():
click.echo('Current whitelisted domains')
- for domain in app['config'].whitelist:
+ for domain in app.config.whitelist:
click.echo(f'- {domain}')
@@ -302,12 +269,10 @@ def cli_whitelist_list():
def cli_whitelist_add(instance):
'Add an instance to the whitelist'
- config = app['config']
-
- if not config.add_whitelist(instance):
+ if not app.config.add_whitelist(instance):
return click.echo(f'Instance already in the whitelist: {instance}')
- config.save()
+ app.config.save()
click.echo(f'Instance added to the whitelist: {instance}')
@@ -316,17 +281,14 @@ def cli_whitelist_add(instance):
def cli_whitelist_remove(instance):
'Remove an instance from the whitelist'
- config = app['config']
- database = app['database']
-
- if not config.del_whitelist(instance):
+ if not app.config.del_whitelist(instance):
return click.echo(f'Instance not in the whitelist: {instance}')
- config.save()
+ app.config.save()
- if config.whitelist_enabled:
- if database.del_inbox(inbox):
- database.save()
+ if app.config.whitelist_enabled:
+ if app.database.del_inbox(inbox):
+ app.database.save()
click.echo(f'Removed instance from the whitelist: {instance}')
@@ -335,23 +297,21 @@ def cli_whitelist_remove(instance):
def relay_setup():
'Generate a new config'
- config = app['config']
-
while True:
- config.host = click.prompt('What domain will the relay be hosted on?', default=config.host)
+ app.config.host = click.prompt('What domain will the relay be hosted on?', default=app.config.host)
if not config.host.endswith('example.com'):
break
click.echo('The domain must not be example.com')
- config.listen = click.prompt('Which address should the relay listen on?', default=config.listen)
+ app.config.listen = click.prompt('Which address should the relay listen on?', default=app.config.listen)
while True:
- config.port = click.prompt('What TCP port should the relay listen on?', default=config.port, type=int)
+ app.config.port = click.prompt('What TCP port should the relay listen on?', default=app.config.port, type=int)
break
- config.save()
+ app.config.save()
if not app['is_docker'] and click.confirm('Relay all setup! Would you like to run it now?'):
relay_run.callback()
@@ -361,9 +321,7 @@ def relay_setup():
def relay_run():
'Run the relay'
- config = app['config']
-
- if config.host.endswith('example.com'):
+ if app.config.host.endswith('example.com'):
return click.echo('Relay is not set up. Please edit your relay config or run "activityrelay setup".')
vers_split = platform.python_version().split('.')
@@ -378,38 +336,10 @@ def relay_run():
click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome')
return click.echo(pip_command)
- if not misc.check_open_port(config.listen, config.port):
- return click.echo(f'Error: A server is already running on port {config.port}')
+ if not misc.check_open_port(app.config.listen, app.config.port):
+ return click.echo(f'Error: A server is already running on port {app.config.port}')
- # web pages
- app.router.add_get('/', views.home)
-
- # endpoints
- app.router.add_post('/actor', views.inbox)
- app.router.add_post('/inbox', views.inbox)
- app.router.add_get('/actor', views.actor)
- app.router.add_get('/nodeinfo/2.0.json', views.nodeinfo_2_0)
- app.router.add_get('/.well-known/nodeinfo', views.nodeinfo_wellknown)
- app.router.add_get('/.well-known/webfinger', views.webfinger)
-
- if logging.DEBUG >= logging.root.level:
- app.router.add_get('/stats', views.stats)
-
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- asyncio.ensure_future(handle_start_webserver(), loop=loop)
- loop.run_forever()
-
-
-async def handle_start_webserver():
- config = app['config']
- runner = AppRunner(app, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{Referer}i" "%{User-Agent}i"')
-
- logging.info(f'Starting webserver at {config.host} ({config.listen}:{config.port})')
- await runner.setup()
-
- site = TCPSite(runner, config.listen, config.port)
- await site.start()
+ app.run()
def main():
diff --git a/relay/misc.py b/relay/misc.py
index 8b58478..1fc109c 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -15,10 +15,10 @@ from json.decoder import JSONDecodeError
from urllib.parse import urlparse
from uuid import uuid4
-from . import app
from .http_debug import http_debug
+app = None
HASHES = {
'sha1': SHA,
'sha256': SHA256,
@@ -26,6 +26,11 @@ HASHES = {
}
+def set_app(new_app):
+ global app
+ app = new_app
+
+
def build_signing_string(headers, used_headers):
return '\n'.join(map(lambda x: ': '.join([x.lower(), headers[x]]), used_headers))
@@ -62,7 +67,7 @@ def distill_inboxes(actor, object_id):
database = app['database']
for inbox in database.inboxes:
- if inbox != actor.shared_inbox or urlparse(inbox).hostname != urlparse(object_id).hostname:
+ if inbox != actor.shared_inbox and urlparse(inbox).hostname != urlparse(object_id).hostname:
yield inbox
diff --git a/relay/processors.py b/relay/processors.py
index e38902d..b8a19aa 100644
--- a/relay/processors.py
+++ b/relay/processors.py
@@ -3,21 +3,18 @@ import logging
from uuid import uuid4
-from . import app, misc
+from . import misc
async def handle_relay(request, actor, data, software):
- cache = app['cache'].objects
- config = app['config']
-
- if data.objectid in cache:
- logging.verbose(f'already relayed {data.objectid} as {cache[data.objectid]}')
+ if data.objectid in request.app.cache.objects:
+ logging.verbose(f'already relayed {data.objectid}')
return
logging.verbose(f'Relaying post from {data.actorid}')
message = misc.Message.new_announce(
- host = config.host,
+ host = request.app.config.host,
object = data.objectid
)
@@ -27,19 +24,16 @@ async def handle_relay(request, actor, data, software):
futures = [misc.request(inbox, data=message) for inbox in inboxes]
asyncio.ensure_future(asyncio.gather(*futures))
- cache[data.objectid] = message.id
+ request.app.cache.objects[data.objectid] = message.id
async def handle_forward(request, actor, data, software):
- cache = app['cache'].objects
- config = app['config']
-
- if data.id in cache:
+ if data.id in request.app.cache.objects:
logging.verbose(f'already forwarded {data.id}')
return
message = misc.Message.new_announce(
- host = config.host,
+ host = request.app.config.host,
object = data
)
@@ -50,22 +44,19 @@ async def handle_forward(request, actor, data, software):
futures = [misc.request(inbox, data=message) for inbox in inboxes]
asyncio.ensure_future(asyncio.gather(*futures))
- cache[data.id] = message.id
+ request.app.cache.objects[data.id] = message.id
async def handle_follow(request, actor, data, software):
- config = app['config']
- database = app['database']
+ if request.app.database.add_inbox(inbox, data.id):
+ request.app.database.set_followid(actor.id, data.id)
- if database.add_inbox(inbox, data.id):
- database.set_followid(actor.id, data.id)
-
- database.save()
+ request.app.database.save()
await misc.request(
actor.shared_inbox,
misc.Message.new_response(
- host = config.host,
+ host = request.app.config.host,
actor = actor.id,
followid = data.id,
accept = True
@@ -78,7 +69,7 @@ async def handle_follow(request, actor, data, software):
misc.request(
actor.shared_inbox,
misc.Message.new_follow(
- host = config.host,
+ host = request.app.config.host,
actor = actor.id
)
)
@@ -89,15 +80,13 @@ async def handle_undo(request, actor, data, software):
if data['object']['type'] != 'Follow':
return await handle_forward(request, actor, data, software)
- database = app['database']
-
- if not database.del_inbox(actor.domain, data.id):
+ if not request.app.database.del_inbox(actor.domain, data.id):
return
- database.save()
+ request.app.database.save()
message = misc.Message.new_unfollow(
- host = config.host,
+ host = request.app.config.host,
actor = actor.id,
follow = data
)
diff --git a/relay/views.py b/relay/views.py
index b493378..676024a 100644
--- a/relay/views.py
+++ b/relay/views.py
@@ -2,14 +2,25 @@ import logging
import subprocess
import traceback
-from aiohttp.web import HTTPForbidden, HTTPUnauthorized, Response, json_response
+from aiohttp.web import HTTPForbidden, HTTPUnauthorized, Response, json_response, route
-from . import __version__, app, misc
+from . import __version__, misc
from .http_debug import STATS
from .misc import Message
from .processors import run_processor
+routes = []
+
+
+def register_route(method, path):
+ def wrapper(func):
+ routes.append([method, path, func])
+ return func
+
+ return wrapper
+
+
try:
commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii')
version = f'{__version__} {commit_label}'
@@ -18,9 +29,14 @@ except:
version = __version__
+@register_route('GET', '/')
async def home(request):
- targets = '
'.join(app['database'].hostnames)
- text = """
+ targets = '
'.join(request.app.database.hostnames)
+ note = request.app.config.note
+ count = len(request.app.database.hostnames)
+ host = request.app.config.host
+
+ text = f"""