From c41cd6e015e789ea81079511820cf64f2ee6dfd0 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 13 Dec 2022 08:27:09 -0500 Subject: [PATCH] first draft --- relay.yaml.example | 66 +++--- relay/application.py | 111 ++++++---- relay/config.py | 297 +++++++++++--------------- relay/database.py | 483 +++++++++++++++++++++++++++---------------- relay/http_client.py | 63 ++++-- relay/manage.py | 449 +++++++++++++++++++++++++++------------- relay/misc.py | 59 +++--- relay/processors.py | 102 +++++---- relay/views.py | 59 +++--- requirements.txt | 2 + 10 files changed, 1007 insertions(+), 684 deletions(-) diff --git a/relay.yaml.example b/relay.yaml.example index 4e35697..7607cab 100644 --- a/relay.yaml.example +++ b/relay.yaml.example @@ -1,43 +1,29 @@ -# this is the path that the object graph will get dumped to (in JSON-LD format), -# you probably shouldn't change it, but you can if you want. -db: relay.jsonld +general: + # Address the relay will listen on. Set to "0.0.0.0" for any address + listen: 0.0.0.0 + # Port the relay will listen on + port: 3621 + # Domain the relay will advertise as + host: relay.example.com -# Listener -listen: 0.0.0.0 -port: 8080 +database: + # SQL backend to use. Available options: "sqlite", "postgresql", "mysql". + type: sqlite + # Minimum number of database connections to keep open + min_connections: 0 + # Maximum number of database connections to open + max_connections: 10 -# Note -note: "Make a note about your instance here." +postgres: + database: activityrelay + hostname: null + port: null + username: null + password: null -# Number of worker threads to start. If 0, use asyncio futures instead of threads. -workers: 0 - -# Maximum number of inbox posts to do at once -# If workers is set to 1 or above, this is the max for each worker -push_limit: 512 - -# The amount of json objects to cache from GET requests -json_cache: 1024 - -ap: - # This is used for generating activitypub messages, as well as instructions for - # linking AP identities. It should be an SSL-enabled domain reachable by https. - host: 'relay.example.com' - - blocked_instances: - - 'bad-instance.example.com' - - 'another-bad-instance.example.com' - - whitelist_enabled: false - - whitelist: - - 'good-instance.example.com' - - 'another.good-instance.example.com' - - # uncomment the lines below to prevent certain activitypub software from posting - # to the relay (all known relays by default). this uses the software name in nodeinfo - #blocked_software: - #- 'activityrelay' - #- 'aoderelay' - #- 'social.seattle.wa.us-relay' - #- 'unciarelay' +mysql: + database: activityrelay + hostname: null + port: null + username: null + password: null diff --git a/relay/application.py b/relay/application.py index dbe464f..5395dca 100644 --- a/relay/application.py +++ b/relay/application.py @@ -1,4 +1,5 @@ import asyncio +import inspect import logging import os import queue @@ -7,10 +8,11 @@ import threading import traceback from aiohttp import web +from aputils import Signer from datetime import datetime, timedelta -from .config import RelayConfig -from .database import RelayDatabase +from .config import Config +from .database import Database from .http_client import HttpClient from .misc import DotDict, check_open_port, set_app from .views import routes @@ -18,37 +20,25 @@ from .views import routes class Application(web.Application): def __init__(self, cfgpath): - web.Application.__init__(self) - - self['starttime'] = None - self['running'] = False - self['config'] = RelayConfig(cfgpath) - - if not self['config'].load(): - self['config'].save() - - if self.config.is_docker: - self.config.update({ - 'db': '/data/relay.jsonld', - 'listen': '0.0.0.0', - 'port': 8080 - }) - - self['workers'] = [] - self['last_worker'] = 0 + web.Application.__init__(self, + middlewares = [ + server_middleware + ] + ) set_app(self) - self['database'] = RelayDatabase(self['config']) - self['database'].load() + self['config'] = Config(cfgpath) + self['database'] = Database(**self.config.dbconfig) + self['client'] = HttpClient() - self['client'] = HttpClient( - database = self.database, - limit = self.config.push_limit, - timeout = self.config.timeout, - cache_size = self.config.json_cache - ) + self['starttime'] = None + self['signer'] = None + self['running'] = False + self['workers'] = [] + self['last_worker'] = 0 + self.database.create() self.set_signal_handler() @@ -67,18 +57,32 @@ class Application(web.Application): return self['database'] + @property + def signer(self): + if not self['signer']: + with self.database.session as s: + privkey = s.get_config('privkey') + + if not privkey: + self['signer'] = Signer.new(self.config.keyid) + s.put_config('privkey', self['signer'].export()) + + else: + self['signer'] = Signer(privkey, self.config.keyid) + + return self['signer'] + + @property def uptime(self): if not self['starttime']: return timedelta(seconds=0) - uptime = datetime.now() - self['starttime'] - - return timedelta(seconds=uptime.seconds) + return datetime.now() - self['starttime'] def push_message(self, inbox, message): - if self.config.workers <= 0: + if len(self['workers']) <= 0: return asyncio.ensure_future(self.client.post(inbox, message)) worker = self['workers'][self['last_worker']] @@ -115,15 +119,22 @@ class Application(web.Application): self['running'] = False + def setup(self): + self.client.setup() + + async def handle_run(self): self['running'] = True - if self.config.workers > 0: - for i in range(self.config.workers): - worker = PushWorker(self) - worker.start() + with self.database.session as s: + workers = s.get_config('workers') - self['workers'].append(worker) + if workers > 0: + for i in range(workers): + worker = PushWorker(self) + worker.start() + + self['workers'].append(worker) runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') await runner.setup() @@ -155,12 +166,8 @@ class PushWorker(threading.Thread): def run(self): - self.client = HttpClient( - database = self.app.database, - limit = self.app.config.push_limit, - timeout = self.app.config.timeout, - cache_size = self.app.config.json_cache - ) + self.client = HttpClient() + self.client.setup() asyncio.run(self.handle_queue()) @@ -183,6 +190,24 @@ class PushWorker(threading.Thread): await self.client.close() +@web.middleware +async def server_middleware(request, handler): + if len(inspect.signature(handler).parameters) == 1: + response = await handler(request) + + else: + with request.database.session as s: + response = await handler(request, s) + + ## make sure there's some sort of response + if response == None: + logging.error(f'No response for handler: {handler}') + response = Response.new_error(500, 'No response') + + response.headers['Server'] = 'ActivityRelay' + return response + + ## Can't sub-class web.Request, so let's just add some properties def request_actor(self): try: return self['actor'] diff --git a/relay/config.py b/relay/config.py index 996fa9f..400e8e2 100644 --- a/relay/config.py +++ b/relay/config.py @@ -1,58 +1,113 @@ -import json +import appdirs import os +import sys import yaml from functools import cached_property from pathlib import Path -from urllib.parse import urlparse -from .misc import DotDict, boolean +from .misc import AppBase, DotDict -RELAY_SOFTWARE = [ - 'activityrelay', # https://git.pleroma.social/pleroma/relay - 'aoderelay', # https://git.asonix.dog/asonix/relay - 'feditools-relay' # https://git.ptzo.gdn/feditools/relay +DEFAULTS = { + 'general_listen': '0.0.0.0', + 'general_port': 8080, + 'general_host': 'relay.example.com', + 'database_type': 'sqlite', + 'database_min_connections': 0, + 'database_max_connections': 10, + 'postgres_database': 'activityrelay', + 'postgres_hostname': None, + 'postgres_port': None, + 'postgres_username': None, + 'postgres_password': None, + 'mysql_database': 'activityrelay', + 'mysql_hostname': None, + 'mysql_port': None, + 'mysql_username': None, + 'mysql_password': None +} + +CATEGORY_NAMES = [ + 'general', + 'database', + 'postgres', + 'mysql' ] -APKEYS = [ - 'host', - 'whitelist_enabled', - 'blocked_software', - 'blocked_instances', - 'whitelist' +CONFIG_DIRS = [ + Path.cwd(), + Path(appdirs.user_config_dir('activityrelay')) ] -class RelayConfig(DotDict): - def __init__(self, path): - DotDict.__init__(self, {}) +def get_config_dir(): + for path in CONFIG_DIRS: + cfgpath = path.joinpath('config.yaml') + + if cfgpath.exists(): + return cfgpath + + if sys.platform == 'linux': + etcpath = Path('/etc/activityrelay/config.yaml') + + if etcpath.exists(): + return cfgpath + + return Path.cwd().joinpath('config.yaml') + + +class Config(AppBase, dict): + def __init__(self, path=None): + DotDict.__init__(self, DEFAULTS) if self.is_docker: - path = '/data/config.yaml' + path = Path('/data/config.yaml') - self._path = Path(path).expanduser() - self.reset() + elif not path: + path = get_config_dir() + + else: + path = Path(path).expanduser().resolve() + + self._path = path + self.load() def __setitem__(self, key, value): - if key in ['blocked_instances', 'blocked_software', 'whitelist']: - assert isinstance(value, (list, set, tuple)) + if (self.is_docker and key in {'general_host', 'general_port'}) or value == '__DEFAULT__': + value = DEFAULTS[key] - elif key in ['port', 'workers', 'json_cache', 'timeout']: - if not isinstance(value, int): - value = int(value) + elif key in {'general_port', 'database_min_connections', 'database_max_connections'}: + value = int(value) - elif key == 'whitelist_enabled': - if not isinstance(value, bool): - value = boolean(value) - - super().__setitem__(key, value) + dict.__setitem__(self, key, value) @property - def db(self): - return Path(self['db']).expanduser().resolve() + def dbconfig(self): + config = { + 'type': self['database_type'], + 'min_conn': self['database_min_connections'], + 'max_conn': self['database_max_connections'] + } + + if self.dbtype == 'sqlite': + config['database'] = self.path.with_name('relay.sqlite3') + + else: + for key, value in self.items(): + cat, name = key.split('_', 1) + + if self.dbtype == cat: + config[name] = value + + return config + + + @cached_property + def is_docker(self): + return bool(os.getenv('DOCKER_RUNNING')) @property @@ -60,6 +115,29 @@ class RelayConfig(DotDict): return self._path + ## General config + @property + def host(self): + return self['general_host'] + + + @property + def listen(self): + return self['general_listen'] + + + @property + def port(self): + return self['general_port'] + + + ## Database config + @property + def dbtype(self): + return self['database_type'] + + + ## AP URLs @property def actor(self): return f'https://{self.host}/actor' @@ -75,117 +153,12 @@ class RelayConfig(DotDict): return f'{self.actor}#main-key' - @cached_property - def is_docker(self): - return bool(os.environ.get('DOCKER_RUNNING')) - - def reset(self): self.clear() - self.update({ - 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')), - 'listen': '0.0.0.0', - 'port': 8080, - 'note': 'Make a note about your instance here.', - 'push_limit': 512, - 'json_cache': 1024, - 'timeout': 10, - 'workers': 0, - 'host': 'relay.example.com', - 'whitelist_enabled': False, - 'blocked_software': [], - 'blocked_instances': [], - 'whitelist': [] - }) - - - def ban_instance(self, instance): - if instance.startswith('http'): - instance = urlparse(instance).hostname - - if self.is_banned(instance): - return False - - self.blocked_instances.append(instance) - return True - - - def unban_instance(self, instance): - if instance.startswith('http'): - instance = urlparse(instance).hostname - - try: - self.blocked_instances.remove(instance) - return True - - except: - return False - - - def ban_software(self, software): - if self.is_banned_software(software): - return False - - self.blocked_software.append(software) - return True - - - def unban_software(self, software): - try: - self.blocked_software.remove(software) - return True - - except: - return False - - - def add_whitelist(self, instance): - if instance.startswith('http'): - instance = urlparse(instance).hostname - - if self.is_whitelisted(instance): - return False - - self.whitelist.append(instance) - return True - - - def del_whitelist(self, instance): - if instance.startswith('http'): - instance = urlparse(instance).hostname - - try: - self.whitelist.remove(instance) - return True - - except: - return False - - - def is_banned(self, instance): - if instance.startswith('http'): - instance = urlparse(instance).hostname - - return instance in self.blocked_instances - - - def is_banned_software(self, software): - if not software: - return False - - return software.lower() in self.blocked_software - - - def is_whitelisted(self, instance): - if instance.startswith('http'): - instance = urlparse(instance).hostname - - return instance in self.whitelist + self.update(DEFAULTS) def load(self): - self.reset() - options = {} try: @@ -201,45 +174,21 @@ class RelayConfig(DotDict): except FileNotFoundError: return False - if not config: - return False - - for key, value in config.items(): - if key in ['ap']: - for k, v in value.items(): - if k not in self: - continue - - self[k] = v - - continue - - elif key not in self: - continue - - self[key] = value - - if self.host.endswith('example.com'): - return False - - return True + for key, value in DEFAULTS.items(): + cat, name = key.split('_', 1) + self[key] = config.get(cat, {}).get(name, DEFAULTS[key]) def save(self): - config = { - # just turning config.db into a string is good enough for now - 'db': str(self.db), - 'listen': self.listen, - 'port': self.port, - 'note': self.note, - 'push_limit': self.push_limit, - 'workers': self.workers, - 'json_cache': self.json_cache, - 'timeout': self.timeout, - 'ap': {key: self[key] for key in APKEYS} - } + config = {key: {} for key in CATEGORY_NAMES} - with open(self._path, 'w') as fd: + for key, value in self.items(): + cat, name = key.split('_', 1) + + if isinstance(value, Path): + value = str(value) + + config[cat][name] = value + + with open(self.path, 'w') as fd: yaml.dump(config, fd, sort_keys=False) - - return config diff --git a/relay/database.py b/relay/database.py index ad093cd..7bfdd93 100644 --- a/relay/database.py +++ b/relay/database.py @@ -1,189 +1,102 @@ -import aputils -import asyncio -import json -import logging -import traceback +import tinysql +from datetime import datetime +from tinysql import Column, Table from urllib.parse import urlparse - -class RelayDatabase(dict): - def __init__(self, config): - dict.__init__(self, { - 'relay-list': {}, - 'private-key': None, - 'follow-requests': {}, - 'version': 1 - }) - - self.config = config - self.signer = None +from .misc import AppBase, DotDict, boolean - @property - def hostnames(self): - return tuple(self['relay-list'].keys()) +TABLES = [ + Table('config', + Column('key', 'text', unique=True, nullable=False, primary_key=True), + Column('value', 'text') + ), + Table('instances', + Column('id', 'serial'), + Column('domain', 'text', unique=True, nullable=False), + Column('actor', 'text'), + Column('inbox', 'text', nullable=False), + Column('followid', 'text'), + Column('software', 'text'), + Column('actor_data', 'json'), + Column('note', 'text'), + Column('joined', 'datetime', nullable=False), + Column('updated', 'datetime') + ), + Table('whitelist', + Column('id', 'serial'), + Column('domain', 'text', unique=True), + Column('created', 'datetime', nullable=False) + ), + Table('bans', + Column('id', 'serial'), + Column('name', 'text', unique=True), + Column('note', 'text'), + Column('type', 'text', nullable=False), + Column('created', 'datetime', nullable=False) + ), + Table('users', + Column('id', 'serial'), + Column('handle', 'text', unique=True, nullable=False), + Column('domain', 'text', nullable=False), + Column('api_token', 'text'), + Column('created', 'datetime', nullable=False), + Column('updated', 'datetime') + ), + Table('tokens', + Column('id', 'text', unique=True, nullable=False, primary_key=True), + Column('userid', 'integer', nullable=False), + Column('created', 'datetime', nullable=False), + Column('updated', 'datetime') + ) +] + +DEFAULT_CONFIG = { + 'description': ('str', 'Make a note about your relay here'), + 'http_timeout': ('int', 10), + 'json_cache': ('int', 1024), + 'log_level': ('str', 'INFO'), + 'name': ('str', 'ActivityRelay'), + 'privkey': ('str', ''), + 'push_limit': ('int', 512), + 'require_approval': ('bool', False), + 'version': ('int', 20221211), + 'whitelist': ('bool', False), + 'workers': ('int', 8) +} + +RELAY_SOFTWARE = [ + 'activityrelay', # https://git.pleroma.social/pleroma/relay + 'aoderelay', # https://git.asonix.dog/asonix/relay + 'feditools-relay' # https://git.ptzo.gdn/feditools/relay +] - @property - def inboxes(self): - return tuple(data['inbox'] for data in self['relay-list'].values()) +class Database(AppBase, tinysql.Database): + def __init__(self, **config): + tinysql.Database.__init__(self, **config, + connection_class = Connection, + row_classes = [ + ConfigRow + ] + ) - def load(self): - new_db = True - - try: - with self.config.db.open() as fd: - data = json.load(fd) - - self['version'] = data.get('version', None) - self['private-key'] = data.get('private-key') - - if self['version'] == None: - self['version'] = 1 - - if 'actorKeys' in data: - self['private-key'] = data['actorKeys']['privateKey'] - - for item in data.get('relay-list', []): - domain = urlparse(item).hostname - self['relay-list'][domain] = { - 'domain': domain, - 'inbox': item, - 'followid': None - } - - else: - self['relay-list'] = data.get('relay-list', {}) - - for domain, instance in self['relay-list'].items(): - if self.config.is_banned(domain) or (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)): - self.del_inbox(domain) - continue - - if not instance.get('domain'): - instance['domain'] = domain - - new_db = False - - except FileNotFoundError: - pass - - except json.decoder.JSONDecodeError as e: - if self.config.db.stat().st_size > 0: - raise e from None - - if not self['private-key']: - logging.info("No actor keys present, generating 4096-bit RSA keypair.") - self.signer = aputils.Signer.new(self.config.keyid, size=4096) - self['private-key'] = self.signer.export() - - else: - self.signer = aputils.Signer(self['private-key'], self.config.keyid) - - self.save() - return not new_db + def create(self): + self.create_database(TABLES) - def save(self): - with self.config.db.open('w') as fd: - json.dump(self, fd, indent=4) +class Connection(tinysql.ConnectionMixin): + ## Misc methods + def accept_request(self, domain): + row = self.get_request(domain) - - def get_inbox(self, domain, fail=False): - if domain.startswith('http'): - domain = urlparse(domain).hostname - - inbox = self['relay-list'].get(domain) - - if inbox: - return inbox - - if fail: + if not row: raise KeyError(domain) - - def add_inbox(self, inbox, followid=None, software=None): - assert inbox.startswith('https'), 'Inbox must be a url' - domain = urlparse(inbox).hostname - instance = self.get_inbox(domain) - - if instance: - if followid: - instance['followid'] = followid - - if software: - instance['software'] = software - - return instance - - self['relay-list'][domain] = { - 'domain': domain, - 'inbox': inbox, - 'followid': followid, - 'software': software - } - - logging.verbose(f'Added inbox to database: {inbox}') - return self['relay-list'][domain] - - - def del_inbox(self, domain, followid=None, fail=False): - data = self.get_inbox(domain, fail=False) - - if not data: - if fail: - raise KeyError(domain) - - return False - - if not data['followid'] or not followid or data['followid'] == followid: - del self['relay-list'][data['domain']] - logging.verbose(f'Removed inbox from database: {data["inbox"]}') - return True - - if fail: - raise ValueError('Follow IDs do not match') - - logging.debug(f'Follow ID does not match: db = {data["followid"]}, object = {followid}') - return False - - - def get_request(self, domain, fail=True): - if domain.startswith('http'): - domain = urlparse(domain).hostname - - try: - return self['follow-requests'][domain] - - except KeyError as e: - if fail: - raise e - - - def add_request(self, actor, inbox, followid): - domain = urlparse(inbox).hostname - - try: - request = self.get_request(domain) - request['followid'] = followid - - except KeyError: - pass - - self['follow-requests'][domain] = { - 'actor': actor, - 'inbox': inbox, - 'followid': followid - } - - - def del_request(self, domain): - if domain.startswith('http'): - domain = urlparse(inbox).hostname - - del self['follow-requests'][domain] + data = {'joined': datetime.now()} + self.update('instances', data, id=row.id) def distill_inboxes(self, message): @@ -192,6 +105,226 @@ class RelayDatabase(dict): urlparse(message.objectid).netloc } - for domain, instance in self['relay-list'].items(): - if domain not in src_domains: - yield instance['inbox'] + for instance in self.get_instances(): + if instance.domain not in src_domains: + yield instance.inbox + + + ## Delete methods + def delete_ban(self, type, name): + row = self.get_ban(type, name) + + if not row: + raise KeyError(name) + + self.delete('bans', id=row.id) + + + def delete_instance(self, domain): + row = self.get_instance(domain) + + if not row: + raise KeyError(name) + + self.delete('instances', id=row.id) + + + def delete_whitelist(self, domain): + row = self.get_whitelist_domain(domain) + + if not row: + raise KeyError(domain) + + self.delete('whitelist', id=row.id) + + + ## Get methods + def get_ban(self, type, name): + if type not in {'software', 'domain'}: + raise ValueError('Ban type must be "software" or "domain"') + + return self.select('bans', name=name, type=type).one() + + + def get_bans(self, type): + if type not in {'software', 'domain'}: + raise ValueError('Ban type must be "software" or "domain"') + + return self.select('bans', type=type).all() + + + def get_config(self, key): + if key not in DEFAULT_CONFIG: + raise KeyError(key) + + row = self.select('config', key=key).one() + + if not row: + return DEFAULT_CONFIG[key][1] + + return row.get_value() + + + def get_config_all(self): + rows = self.select('config').all() + config = DotDict({row.key: row.get_value() for row in rows}) + + for key, data in DEFAULT_CONFIG.items(): + if key not in config: + config[key] = data[1] + + return config + + + def get_hostnames(self): + return tuple(row.domain for row in self.select('instances')) + + + def get_instance(self, data): + if data.startswith('http') and '#' in data: + data = data.split('#', 1)[0] + + query = 'SELECT * FROM instances WHERE domain = :data OR actor = :data OR inbox = :data' + return self.execute(query, dict(data=data), table='instances').one() + + + def get_instances(self): + query = 'SELECT * FROM instances WHERE joined IS NOT NULL' + return self.execute(query, table='instances').all() + + + def get_request(self, domain): + return self.select('instances', domain=domain, joined=None).one() + + + def get_requests(self): + self.select('instances', joined=None).all() + + + def get_whitelist(self): + return self.select('whitelist').all() + + + def get_whitelist_domain(self, domain): + return self.select('whitelist', domain=domain).one() + + + ## Put methods + def put_ban(self, type, name, note=None): + if type not in {'software', 'domain'}: + raise ValueError('Ban type must be "software" or "domain"') + + row = self.select('bans', name=name, type=type).one() + + if row: + if note == None: + raise KeyError(name) + + data = {'note': note} + self.update('bans', data, id=row.id) + return + + self.insert('bans', { + 'name': name, + 'type': type, + 'note': note, + 'created': datetime.now() + }) + + + def put_config(self, key, value='__DEFAULT__'): + if key not in DEFAULT_CONFIG: + raise KeyError(key) + + if value == '__DEFAULT__': + value = DEFAULT_CONFIG[key][1] + + row = self.select('config', key=key).one() + + if row: + self.update('config', {'value': value}, key=key) + return + + self.insert('config', { + 'key': key, + 'value': value + }) + + + def put_instance(self, domain, actor=None, inbox=None, followid=None, software=None, actor_data=None, note=None, accept=True): + new_data = { + 'actor': actor, + 'inbox': inbox, + 'followid': followid, + 'software': software, + 'note': note + } + + if actor_data: + new_data['actor_data'] = dict(actor_data) + + new_data = {key: value for key, value in new_data.items() if value != None} + instance = self.get_instance(domain) + + if instance: + if not new_data: + raise KeyError(domain) + + instance.update(new_data) + self.update('instances', new_data, id=instance.id) + return instance + + if not inbox: + raise ValueError('Inbox must be included in instance data') + + if accept: + new_data['joined'] = datetime.now() + + new_data['domain'] = domain + + self.insert('instances', new_data) + return self.get_instance(domain) + + + def put_instance_actor(self, actor, nodeinfo=None, accept=True): + data = { + 'domain': actor.domain, + 'actor': actor.id, + 'inbox': actor.shared_inbox, + 'actor_data': actor, + 'accept': accept, + 'software': nodeinfo.sw_name if nodeinfo else None + } + + return self.put_instance(**data) + + + def put_whitelist(self, domain): + if self.get_whitelist_domain(domain): + raise KeyError(domain) + + self.insert('whitelist', { + 'domain': domain, + 'created': datetime.now() + }) + + +class ConfigRow(tinysql.Row): + __table__ = 'config' + + def get_value(self): + type = DEFAULT_CONFIG[self.key][0] + + if type == 'int': + return int(self.value) + + elif type == 'bool': + return boolean(self.value.encode('utf-8')) + + elif type == 'list': + return json.loads(value) + + elif type == 'json': + return DotDict.parse(value) + + return self.value diff --git a/relay/http_client.py b/relay/http_client.py index 8802471..797eec4 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -7,11 +7,14 @@ from aputils import Nodeinfo, WellKnownNodeinfo from datetime import datetime from cachetools import LRUCache from json.decoder import JSONDecodeError +from urllib.error import HTTPError from urllib.parse import urlparse +from urllib.request import Request, urlopen from . import __version__ from .misc import ( MIMETYPES, + AppBase, DotDict, Message ) @@ -28,15 +31,23 @@ class Cache(LRUCache): self.__maxsize = int(value) -class HttpClient: - def __init__(self, database, limit=100, timeout=10, cache_size=1024): - self.database = database +class HttpClient(AppBase): + def __init__(self, limit=100, timeout=10, cache_size=1024): self.cache = Cache(cache_size) self.cfg = {'limit': limit, 'timeout': timeout} self._conn = None self._session = None + async def __aenter__(self): + await self.open() + return self + + + async def __aexit__(self, *_): + await self.close() + + @property def limit(self): return self.cfg['limit'] @@ -47,8 +58,16 @@ class HttpClient: return self.cfg['timeout'] + def setup(self): + with self.database.session as s: + config = s.get_config_all() + self.client.cfg['limit'] = config.push_limit + self.client.cfg['timeout'] = config.http_timeout + self.client.cache.set_maxsize(config.json_cache) + + async def open(self): - if self._session: + if self._session and self._session._loop.is_running(): return self._conn = TCPConnector( @@ -87,7 +106,7 @@ class HttpClient: headers = {} if sign_headers: - headers.update(self.database.signer.sign_headers('GET', url, algorithm='original')) + headers.update(self.signer.sign_headers('GET', url, algorithm='original')) try: logging.verbose(f'Fetching resource: {url}') @@ -132,34 +151,35 @@ class HttpClient: raise e - async def post(self, url, message): + async def post(self, inbox, message): await self.open() - instance = self.database.get_inbox(url) + with self.database.session as s: + instance = s.get_instance(inbox) ## Using the old algo by default is probably a better idea right now - if instance and instance.get('software') in {'mastodon'}: + if instance and instance['software'] in {'mastodon'}: algorithm = 'hs2019' else: algorithm = 'original' headers = {'Content-Type': 'application/activity+json'} - headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm)) + headers.update(self.signer.sign_headers('POST', inbox, message, algorithm=algorithm)) try: - logging.verbose(f'Sending "{message.type}" to {url}') + logging.verbose(f'Sending "{message.type}" to {inbox}') - async with self._session.post(url, headers=headers, data=message.to_json()) as resp: + async with self._session.post(inbox, headers=headers, data=message.to_json()) as resp: ## Not expecting a response, so just return if resp.status in {200, 202}: - return logging.verbose(f'Successfully sent "{message.type}" to {url}') + return logging.verbose(f'Successfully sent "{message.type}" to {inbox}') - logging.verbose(f'Received error when pushing to {url}: {resp.status}') + logging.verbose(f'Received error when pushing to {inbox}: {resp.status}') return logging.verbose(await resp.read()) # change this to debug except (ClientConnectorError, ServerTimeoutError): - logging.verbose(f'Failed to connect to {url}') + logging.verbose(f'Failed to connect to {inbox}') ## prevent workers from being brought down except Exception as e: @@ -190,3 +210,18 @@ class HttpClient: return False return await self.get(nodeinfo_url, loads=Nodeinfo.new_from_json) or False + + +async def get(*args, **kwargs): + async with HttpClient() as client: + return await client.get(*args, **kwargs) + + +async def post(*args, **kwargs): + async with HttpClient() as client: + return await client.post(*args, **kwargs) + + +async def fetch_nodeinfo(*args, **kwargs): + async with HttpClient() as client: + return await client.fetch_nodeinfo(*args, **kwargs) diff --git a/relay/manage.py b/relay/manage.py index 0d7decc..851b658 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -1,22 +1,29 @@ import Crypto import asyncio import click +import json import logging import platform +import yaml from urllib.parse import urlparse -from . import misc, __version__ +from . import __version__ from .application import Application -from .config import RELAY_SOFTWARE +from .database import DEFAULT_CONFIG, RELAY_SOFTWARE +from .http_client import get, post, fetch_nodeinfo +from .misc import Message, boolean, check_open_port app = None -CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} +CONFIG_IGNORE = { + 'privkey', + 'version' +} @click.group('cli', context_settings={'show_default': True}, invoke_without_command=True) -@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config') +@click.option('--config', '-c', help='path to the relay\'s config') @click.version_option(version=__version__, prog_name='ActivityRelay') @click.pass_context def cli(ctx, config): @@ -30,13 +37,75 @@ def cli(ctx, config): else: cli_run.callback() + if ctx.invoked_subcommand != 'convert': + app.setup() + + +@cli.command('convert') +@click.option('--old-config', '-o', help='path to the old relay config') +def cli_convert(old_config): + with open(old_config or 'relay.yaml') as fd: + config = yaml.load(fd.read(), Loader=yaml.SafeLoader) + ap = config.get('ap', {}) + + with open(config.get('db', 'relay.jsonld')) as fd: + db = json.load(fd) + + app.config['general_host'] = ap.get('host', '__DEFAULT__') + app.config['general_listen'] = config.get('listen', '__DEFAULT__') + app.config['general_port'] = config.get('port', '__DEFAULT__') + + with app.database.session as s: + s.put_config('description', config.get('note', '__DEFAULT__')) + s.put_config('push_limit', config.get('push_limit', '__DEFAULT__')) + s.put_config('json_cache', config.get('json_cache', '__DEFAULT__')) + s.put_config('workers', config.get('workers', '__DEFAULT__')) + s.put_config('http_timeout', config.get('timeout', '__DEFAULT__')) + s.put_config('privkey', db.get('private-key')) + + for name in ap.get('blocked_software', []): + try: s.put_ban('software', name) + except KeyError: print(f'Already banned software: {name}') + + for name in ap.get('blocked_instances', []): + try: s.put_ban('domain', name) + except KeyError: print(f'Already banned instance: {name}') + + for name in ap.get('whitelist', []): + try: s.put_whitelist(name) + except KeyError: print(f'Already whitelisted domain: {name}') + + for instance in db.get('relay-list', {}).values(): + domain = instance['domain'] + software = instance.get('software') + actor = None + + if software == 'mastodon': + actor = f'https://{domain}/actor' + + elif software in {'pleroma', 'akkoma'}: + actor = f'https://{domain}/relay' + + s.put_instance( + domain = domain, + inbox = instance.get('inbox'), + software = software, + actor = actor, + followid = instance.get('followid'), + accept = True + ) + + app.config.save() + + print('Config and database converted :3') + @cli.command('setup') def cli_setup(): 'Generate a new config' while True: - app.config.host = click.prompt('What domain will the relay be hosted on?', default=app.config.host) + app.config['general_host'] = click.prompt('What domain will the relay be hosted on?', default=app.config.host) if not app.config.host.endswith('example.com'): break @@ -44,14 +113,37 @@ def cli_setup(): click.echo('The domain must not be example.com') if not app.config.is_docker: - app.config.listen = click.prompt('Which address should the relay listen on?', default=app.config.listen) + app.config['general_listen'] = click.prompt('Which address should the relay listen on?', default=app.config.listen) while True: - app.config.port = click.prompt('What TCP port should the relay listen on?', default=app.config.port, type=int) + app.config['general_port'] = click.prompt('What TCP port should the relay listen on?', default=app.config.port, type=int) break app.config.save() + with app.database.session as s: + s.put_config('name', click.prompt( + 'What do you want to name your relay?', + default = s.get_config('name') + )) + + s.put_config('description', click.prompt( + 'Provide a small description of your relay. This will be on the front page', + default = s.get_config('description') + )) + + s.put_config('whitelist', click.prompt( + 'Enable the whitelist?', + default = s.get_config('whitelist'), + type = boolean + )) + + s.put_config('require_approval', click.prompt( + 'Require instances to be approved when following?', + default = s.get_config('require_approval'), + type = boolean + )) + if not app.config.is_docker and click.confirm('Relay all setup! Would you like to run it now?'): cli_run.callback() @@ -60,8 +152,9 @@ def cli_setup(): def cli_run(): 'Run the relay' - if app.config.host.endswith('example.com'): - return click.echo('Relay is not set up. Please edit your relay config or run "activityrelay setup".') + with app.database.session as s: + if not s.get_config('privkey') or app.config.host.endswith('example.com'): + return click.echo('Relay is not set up. Please run "activityrelay setup".') vers_split = platform.python_version().split('.') pip_command = 'pip3 uninstall pycrypto && pip3 install pycryptodome' @@ -75,7 +168,7 @@ def cli_run(): click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome') return click.echo(pip_command) - if not misc.check_open_port(app.config.listen, app.config.port): + if not check_open_port(app.config.listen, app.config.port): return click.echo(f'Error: A server is already running on port {app.config.port}') app.run() @@ -94,22 +187,28 @@ def cli_config_list(): click.echo('Relay Config:') - for key, value in app.config.items(): - if key not in CONFIG_IGNORE: - key = f'{key}:'.ljust(20) - click.echo(f'- {key} {value}') + with app.database.session as s: + config = s.get_config_all() + + for key in DEFAULT_CONFIG.keys(): + if key in CONFIG_IGNORE: + continue + + keystr = f'{key}:'.ljust(20) + click.echo(f'- {keystr} {config[key]}') @cli_config.command('set') @click.argument('key') -@click.argument('value') +@click.argument('value', nargs=-1) def cli_config_set(key, value): 'Set a config value' - app.config[key] = value - app.config.save() + with app.database.session as s: + s.put_config(key, ' '.join(value)) + value = s.get_config(key) - print(f'{key}: {app.config[key]}') + print(f'{key}: {value}') @cli.group('inbox') @@ -124,8 +223,9 @@ def cli_inbox_list(): click.echo('Connected to the following instances or relays:') - for inbox in app.database.inboxes: - click.echo(f'- {inbox}') + with app.database.session as s: + for instance in s.get_instances(): + click.echo(f'- {instance.inbox}') @cli_inbox.command('follow') @@ -133,9 +233,6 @@ def cli_inbox_list(): def cli_inbox_follow(actor): 'Follow an actor (Relay must be running)' - if app.config.is_banned(actor): - return click.echo(f'Error: Refusing to follow banned actor: {actor}') - if not actor.startswith('http'): domain = actor actor = f'https://{actor}/actor' @@ -143,24 +240,32 @@ def cli_inbox_follow(actor): else: domain = urlparse(actor).hostname - try: - inbox_data = app.database['relay-list'][domain] - inbox = inbox_data['inbox'] + with app.database.session as s: + if s.get_ban('domain', domain): + return click.echo(f'Error: Refusing to follow banned actor: {actor}') - except KeyError: - actor_data = asyncio.run(app.client.get(actor, sign_headers=True)) + instance = s.get_instance(domain) - if not actor_data: - return click.echo(f'Failed to fetch actor: {actor}') + if not instance: + actor_data = asyncio.run(get(actor, sign_headers=True)) - inbox = actor_data.shared_inbox + if not actor_data: + return click.echo(f'Failed to fetch actor: {actor}') - message = misc.Message.new_follow( + inbox = actor_data.shared_inbox + + else: + inbox = instance.inbox + + if instance.actor: + actor = instance.actor + + message = Message.new_follow( host = app.config.host, actor = actor ) - asyncio.run(app.client.post(inbox, message)) + asyncio.run(post(inbox, message)) click.echo(f'Sent follow message to actor: {actor}') @@ -169,6 +274,8 @@ def cli_inbox_follow(actor): def cli_inbox_unfollow(actor): 'Unfollow an actor (Relay must be running)' + followid = None + if not actor.startswith('http'): domain = actor actor = f'https://{actor}/actor' @@ -176,68 +283,105 @@ def cli_inbox_unfollow(actor): else: domain = urlparse(actor).hostname - try: - inbox_data = app.database['relay-list'][domain] - inbox = inbox_data['inbox'] - message = misc.Message.new_unfollow( + with app.database.session as s: + instance = s.get_instance(domain) + + if not instance: + actor_data = asyncio.run(get(actor, sign_headers=True)) + + if not actor_data: + return click.echo(f'Failed to fetch actor: {actor}') + + inbox = actor_data.shared_inbox + + else: + inbox = instance.inbox + followid = instance.followid + + if instance.actor: + actor = instance.actor + + if followid: + message = Message.new_unfollow( host = app.config.host, actor = actor, - follow = inbox_data['followid'] + follow = followid ) - except KeyError: - actor_data = asyncio.run(app.client.get(actor, sign_headers=True)) - inbox = actor_data.shared_inbox + else: message = misc.Message.new_unfollow( host = app.config.host, actor = actor, follow = { 'type': 'Follow', 'object': actor, - 'actor': f'https://{app.config.host}/actor' + 'actor': app.config.actor } ) - asyncio.run(app.client.post(inbox, message)) + asyncio.run(post(inbox, message)) click.echo(f'Sent unfollow message to: {actor}') @cli_inbox.command('add') -@click.argument('inbox') -def cli_inbox_add(inbox): - 'Add an inbox to the database' +@click.argument('actor') +def cli_inbox_add(actor): + 'Add an instance to the database' - if not inbox.startswith('http'): - inbox = f'https://{inbox}/inbox' + if not actor.startswith('http'): + domain = actor + actor = f'https://{actor}/inbox' - if app.config.is_banned(inbox): - return click.echo(f'Error: Refusing to add banned inbox: {inbox}') + else: + domain = urlparse(actor).hostname - if app.database.get_inbox(inbox): - return click.echo(f'Error: Inbox already in database: {inbox}') + with app.database.session as s: + data = { + 'domain': domain, + 'actor': actor, + 'inbox': f'https://{domain}/inbox' + } - app.database.add_inbox(inbox) - app.database.save() + if s.get_instance(domain): + return click.echo(f'Error: Instance already in database: {domain}') - click.echo(f'Added inbox to the database: {inbox}') + if s.get_ban('domain', domain): + return click.echo(f'Error: Refusing to add banned domain: {domain}') + + nodeinfo = asyncio.run(fetch_nodeinfo(domain)) + + if nodeinfo: + if s.get_ban('software', nodeinfo.sw_name): + return click.echo(f'Error: Refusing to add banned software: {nodeinfo.sw_name}') + + data['software'] = nodeinfo.sw_name + + actor_data = asyncio.run(get(actor, sign_headers=True)) + + if actor_data: + instance = s.put_instance_actor(actor, nodeinfo) + + else: + instance = s.put_instance(**data) + + click.echo(f'Added instance to the database: {instance.domain}') @cli_inbox.command('remove') -@click.argument('inbox') -def cli_inbox_remove(inbox): +@click.argument('domain') +def cli_inbox_remove(domain): 'Remove an inbox from the database' - try: - dbinbox = app.database.get_inbox(inbox, fail=True) + if domain.startswith('http'): + domain = urlparse(domain).hostname - except KeyError: - click.echo(f'Error: Inbox does not exist: {inbox}') - return + with app.database.session as s: + try: + s.delete_instance(domain) + click.echo(f'Removed inbox from the database: {domain}') - app.database.del_inbox(dbinbox['domain']) - app.database.save() - - click.echo(f'Removed inbox from the database: {inbox}') + except KeyError: + return click.echo(f'Error: Inbox does not exist: {domain}') @cli.group('instance') @@ -252,42 +396,50 @@ def cli_instance_list(): click.echo('Banned instances or relays:') - for domain in app.config.blocked_instances: - click.echo(f'- {domain}') + with app.database.session as s: + for row in s.get_bans('domain'): + click.echo(f'- {row.name}') @cli_instance.command('ban') -@click.argument('target') -def cli_instance_ban(target): +@click.argument('domain') +def cli_instance_ban(domain): 'Ban an instance and remove the associated inbox if it exists' - if target.startswith('http'): - target = urlparse(target).hostname + if domain.startswith('http'): + domain = urlparse(domain).hostname - if app.config.ban_instance(target): - app.config.save() + with app.database.session as s: + try: + s.put_ban('domain', domain) - if app.database.del_inbox(target): - app.database.save() + except KeyError: + return click.echo(f'Instance already banned: {domain}') - click.echo(f'Banned instance: {target}') - return + try: + s.delete_instance(domain) - click.echo(f'Instance already banned: {target}') + except KeyError: + pass + + click.echo(f'Banned instance: {domain}') @cli_instance.command('unban') -@click.argument('target') -def cli_instance_unban(target): +@click.argument('domain') +def cli_instance_unban(domain): 'Unban an instance' - if app.config.unban_instance(target): - app.config.save() + if domain.startswith('http'): + domain = urlparse(domain).hostname - click.echo(f'Unbanned instance: {target}') - return + with app.database.session as s: + try: + s.delete_ban('domain', domain) + click.echo(f'Unbanned instance: {domain}') - click.echo(f'Instance wasn\'t banned: {target}') + except KeyError: + click.echo(f'Instance wasn\'t banned: {domain}') @cli.group('software') @@ -302,8 +454,9 @@ def cli_software_list(): click.echo('Banned software:') - for software in app.config.blocked_software: - click.echo(f'- {software}') + with app.database.session as s: + for row in s.get_bans('software'): + click.echo(f'- {row.name}') @cli_software.command('ban') @@ -314,26 +467,27 @@ def cli_software_list(): def cli_software_ban(name, fetch_nodeinfo): 'Ban software. Use RELAYS for NAME to ban relays' - if name == 'RELAYS': - for name in RELAY_SOFTWARE: - app.config.ban_software(name) + with app.database.session as s: + if name == 'RELAYS': + for name in RELAY_SOFTWARE: + s.put_ban('software', name) - app.config.save() - return click.echo('Banned all relay software') + return click.echo('Banned all relay software') - if fetch_nodeinfo: - nodeinfo = asyncio.run(app.client.fetch_nodeinfo(name)) + if fetch_nodeinfo: + nodeinfo = asyncio.run(fetch_nodeinfo(name)) - if not nodeinfo: - click.echo(f'Failed to fetch software name from domain: {name}') + if not nodeinfo: + return click.echo(f'Failed to fetch software name from domain: {name}') - name = nodeinfo.sw_name + name = nodeinfo.sw_name - if app.config.ban_software(name): - app.config.save() - return click.echo(f'Banned software: {name}') + try: + s.put_ban('software', name) + click.echo(f'Banned software: {name}') - click.echo(f'Software already banned: {name}') + except KeyError: + click.echo(f'Software already banned: {name}') @cli_software.command('unban') @@ -344,26 +498,27 @@ def cli_software_ban(name, fetch_nodeinfo): def cli_software_unban(name, fetch_nodeinfo): 'Ban software. Use RELAYS for NAME to unban relays' - if name == 'RELAYS': - for name in RELAY_SOFTWARE: - app.config.unban_software(name) + with app.database.session as s: + if name == 'RELAYS': + for name in RELAY_SOFTWARE: + s.put_ban('software', name) - app.config.save() - return click.echo('Unbanned all relay software') + return click.echo('Unbanned all relay software') - if fetch_nodeinfo: - nodeinfo = asyncio.run(app.client.fetch_nodeinfo(name)) + if fetch_nodeinfo: + nodeinfo = asyncio.run(fetch_nodeinfo(name)) - if not nodeinfo: - click.echo(f'Failed to fetch software name from domain: {name}') + if not nodeinfo: + return click.echo(f'Failed to fetch software name from domain: {name}') - name = nodeinfo.sw_name + name = nodeinfo.sw_name - if app.config.unban_software(name): - app.config.save() - return click.echo(f'Unbanned software: {name}') + try: + s.put_ban('software', name) + click.echo(f'Unbanned software: {name}') - click.echo(f'Software wasn\'t banned: {name}') + except KeyError: + click.echo(f'Software wasn\'t banned: {name}') @cli.group('whitelist') @@ -376,47 +531,61 @@ def cli_whitelist(): def cli_whitelist_list(): 'List all the instances in the whitelist' - click.echo('Current whitelisted domains') + click.echo('Current whitelisted domains:') - for domain in app.config.whitelist: - click.echo(f'- {domain}') + with app.database.session as s: + for row in s.get_whitelist(): + click.echo(f'- {row.domain}') @cli_whitelist.command('add') -@click.argument('instance') -def cli_whitelist_add(instance): - 'Add an instance to the whitelist' +@click.argument('domain') +def cli_whitelist_add(domain): + 'Add a domain to the whitelist' - if not app.config.add_whitelist(instance): - return click.echo(f'Instance already in the whitelist: {instance}') + with app.database.session as s: + try: + s.put_whitelist(domain) + click.echo(f'Instance added to the whitelist: {domain}') - app.config.save() - click.echo(f'Instance added to the whitelist: {instance}') + except KeyError: + return click.echo(f'Instance already in the whitelist: {domain}') @cli_whitelist.command('remove') -@click.argument('instance') -def cli_whitelist_remove(instance): - 'Remove an instance from the whitelist' +@click.argument('domain') +def cli_whitelist_remove(domain): + 'Remove a domain from the whitelist' - if not app.config.del_whitelist(instance): - return click.echo(f'Instance not in the whitelist: {instance}') + with app.database.session as s: + try: + s.delete_whitelist(domain) + click.echo(f'Removed instance from the whitelist: {domain}') - app.config.save() - - if app.config.whitelist_enabled: - if app.database.del_inbox(instance): - app.database.save() - - click.echo(f'Removed instance from the whitelist: {instance}') + except KeyError: + click.echo(f'Instance not in the whitelist: {domain}') @cli_whitelist.command('import') def cli_whitelist_import(): 'Add all current inboxes to the whitelist' - for domain in app.database.hostnames: - cli_whitelist_add.callback(domain) + with app.database.session as s: + for row in s.get_instances(): + try: + s.put_whitelist(row.domain) + click.echo(f'Instance added to the whitelist: {row.domain}') + + except KeyError: + click.echo(f'Instance already in the whitelist: {row.domain}') + + +@cli_whitelist.command('clear') +def cli_whitelist_clear(): + 'Clear all items out of the whitelist' + + with app.database.session as s: + s.delete('whitelist') def main(): diff --git a/relay/misc.py b/relay/misc.py index a98088f..9f5a96b 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -36,6 +36,9 @@ def set_app(new_app): def boolean(value): + if isinstance(value, bytes): + value = str(value, 'utf-8') + if isinstance(value, str): if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']: return True @@ -63,7 +66,7 @@ def boolean(value): return value.__bool__() except AttributeError: - raise TypeError(f'Cannot convert object of type "{clsname(value)}"') + raise TypeError(f'Cannot convert object of type "{type(value).__name__}"') def check_open_port(host, port): @@ -78,6 +81,32 @@ def check_open_port(host, port): return False +class AppBase: + @property + def app(self): + return app + + + @property + def client(self): + return app.client + + + @property + def config(self): + return app.config + + + @property + def database(self): + return app.database + + + @property + def signer(self): + return app.signer + + class DotDict(dict): def __init__(self, _data, **kwargs): dict.__init__(self) @@ -310,31 +339,3 @@ class Response(AiohttpResponse): @location.setter def location(self, value): self.headers['Location'] = value - - -class View(AiohttpView): - async def _iter(self): - if self.request.method not in METHODS: - self._raise_allowed_methods() - - method = getattr(self, self.request.method.lower(), None) - - if method is None: - self._raise_allowed_methods() - - return await method(**self.request.match_info) - - - @property - def app(self): - return self._request.app - - - @property - def config(self): - return self.app.config - - - @property - def database(self): - return self.app.database diff --git a/relay/processors.py b/relay/processors.py index 2d34246..7184171 100644 --- a/relay/processors.py +++ b/relay/processors.py @@ -4,6 +4,7 @@ import logging from cachetools import LRUCache from uuid import uuid4 +from .database import RELAY_SOFTWARE from .misc import Message @@ -20,7 +21,7 @@ def person_check(actor, software): return True -async def handle_relay(request): +async def handle_relay(request, s): if request.message.objectid in cache: logging.verbose(f'already relayed {request.message.objectid}') return @@ -33,13 +34,13 @@ async def handle_relay(request): cache[request.message.objectid] = message.id logging.debug(f'>> relay: {message}') - inboxes = request.database.distill_inboxes(request.message) + inboxes = s.distill_inboxes(request.message) for inbox in inboxes: request.app.push_message(inbox, message) -async def handle_forward(request): +async def handle_forward(request, s): if request.message.id in cache: logging.verbose(f'already forwarded {request.message.id}') return @@ -52,46 +53,53 @@ async def handle_forward(request): cache[request.message.id] = message.id logging.debug(f'>> forward: {message}') - inboxes = request.database.distill_inboxes(request.message) + inboxes = s.distill_inboxes(request.message) for inbox in inboxes: request.app.push_message(inbox, message) -async def handle_follow(request): +async def handle_follow(request, s): + approve = True + nodeinfo = await request.app.client.fetch_nodeinfo(request.actor.domain) software = nodeinfo.sw_name if nodeinfo else None - ## reject if software used by actor is banned - if request.config.is_banned_software(software): - request.app.push_message( - request.actor.shared_inbox, - Message.new_response( - host = request.config.host, - actor = request.actor.id, - followid = request.message.id, - accept = False - ) - ) + ## reject if the actor isn't whitelisted while the whiltelist is enabled + if s.get_config('whitelist') and not s.get_whitelist(request.actor.domain): + logging.verbose(f'Rejected actor for not being in the whitelist: {request.actor.id}') + accept = False - return logging.verbose(f'Rejected follow from actor for using specific software: actor={request.actor.id}, software={software}') + ## reject if software used by actor is banned + if s.get_banned_software(software): + logging.verbose(f'Rejected follow from actor for using specific software: actor={request.actor.id}, software={software}') + accept = False ## reject if the actor is not an instance actor if person_check(request.actor, software): - request.app.push_message( - request.actor.shared_inbox, - Message.new_response( - host = request.config.host, + logging.verbose(f'Non-application actor tried to follow: {request.actor.id}') + accept = False + + if approve: + if not request.instance: + s.put_instance( + domain = request.actor.domain, actor = request.actor.id, + inbox = request.actor.shared_inbox, + actor_data = request.actor, + software = software, followid = request.message.id, - accept = False + accept = s.get_config('require_approval') ) - ) - return logging.verbose(f'Non-application actor tried to follow: {request.actor.id}') + if s.get_config('require_approval'): + return - request.database.add_inbox(request.actor.shared_inbox, request.message.id, software) - request.database.save() + else: + s.put_instance( + domain = request.actor.domain, + followid = request.message.id + ) request.app.push_message( request.actor.shared_inbox, @@ -99,10 +107,18 @@ async def handle_follow(request): host = request.config.host, actor = request.actor.id, followid = request.message.id, - accept = True + accept = approve ) ) + ## Don't send a follow if the the follow has been rejected + if not approve: + return + + ## Make sure two relays aren't continuously following each other + if software in RELAY_SOFTWARE and not request.instance: + return + # Are Akkoma and Pleroma the only two that expect a follow back? # Ignoring only Mastodon for now if software != 'mastodon': @@ -115,15 +131,12 @@ async def handle_follow(request): ) -async def handle_undo(request): +async def handle_undo(request, s): ## If the object is not a Follow, forward it if request.message.object.type != 'Follow': return await handle_forward(request) - if not request.database.del_inbox(request.actor.domain, request.message.id): - return - - request.database.save() + s.delete('instances', id=request.instance.id) request.app.push_message( request.actor.shared_inbox, @@ -149,12 +162,23 @@ async def run_processor(request): if request.message.type not in processors: return - if request.instance and not request.instance.get('software'): - nodeinfo = await request.app.client.fetch_nodeinfo(request.instance['domain']) + with request.database.session as s: + new_data = {} - if nodeinfo: - request.instance['software'] = nodeinfo.sw_name - request.database.save() + if request.instance and not request.instance.software: + nodeinfo = await request.app.client.fetch_nodeinfo(request.instance.domain) - logging.verbose(f'New "{request.message.type}" from actor: {request.actor.id}') - return await processors[request.message.type](request) + if nodeinfo: + new_data['software'] = nodeinfo.sw_name + + if not request.instance.actor: + new_data['actor'] = request.signature.keyid.split('#', 1)[0] + + if not request.instance.actor_data: + new_data['actor_data'] = request.actor + + if new_data: + s.put_instance(request.actor.domain, **new_data) + + logging.verbose(f'New "{request.message.type}" from actor: {request.actor.id}') + return await processors[request.message.type](request, s) diff --git a/relay/views.py b/relay/views.py index 9cea1ef..15c4317 100644 --- a/relay/views.py +++ b/relay/views.py @@ -5,6 +5,7 @@ import subprocess import traceback from pathlib import Path +from urllib.parse import urlparse from . import __version__, misc from .misc import DotDict, Message, Response @@ -24,19 +25,24 @@ if Path(__file__).parent.parent.joinpath('.git').exists(): pass -def register_route(method, path): +def register_route(method, *paths): def wrapper(func): - routes.append([method, path, func]) + for path in paths: + routes.append([method, path, func]) + return func return wrapper @register_route('GET', '/') -async def home(request): - targets = '
'.join(request.database.hostnames) - note = request.config.note - count = len(request.database.hostnames) +async def home(request, s): + hostnames = s.get_hostnames() + config = s.get_config_all() + + targets = '
'.join(hostnames) + note = config.description + count = len(hostnames) host = request.config.host text = f""" @@ -61,28 +67,30 @@ a:hover {{ color: #8AF; }} return Response.new(text, ctype='html') -@register_route('GET', '/inbox') -@register_route('GET', '/actor') +@register_route('GET', '/actor', '/inbox') async def actor(request): data = Message.new_actor( host = request.config.host, - pubkey = request.database.signer.pubkey + pubkey = request.app.signer.pubkey ) return Response.new(data, ctype='activity') -@register_route('POST', '/inbox') -@register_route('POST', '/actor') -async def inbox(request): - config = request.config - database = request.database - +@register_route('POST', '/actor', '/inbox') +async def inbox(request, s): ## reject if missing signature header if not request.signature: logging.verbose('Actor missing signature header') raise HTTPUnauthorized(body='missing signature') + domain = urlparse(request.signature.keyid).hostname + + ## reject if actor is banned + if s.get_ban('domain', domain): + logging.verbose(f'Ignored request from banned actor: {domain}') + return Response.new_error(403, 'access denied', 'json') + try: request['message'] = await request.json(loads=Message.new_from_json) @@ -114,17 +122,8 @@ async def inbox(request): logging.verbose(f'Failed to fetch actor: {request.signature.keyid}') return Response.new_error(400, 'failed to fetch actor', 'json') - request['instance'] = request.database.get_inbox(request['actor'].inbox) - - ## reject if the actor isn't whitelisted while the whiltelist is enabled - if config.whitelist_enabled and not config.is_whitelisted(request.actor.domain): - logging.verbose(f'Rejected actor for not being in the whitelist: {request.actor.id}') - return Response.new_error(403, 'access denied', 'json') - - ## reject if actor is banned - if request.config.is_banned(request.actor.domain): - logging.verbose(f'Ignored request from banned actor: {actor.id}') - return Response.new_error(403, 'access denied', 'json') + request['instance'] = s.get_instance(request.actor.shared_inbox) + config = s.get_config_all() ## reject if the signature is invalid try: @@ -136,7 +135,7 @@ async def inbox(request): return Response.new_error(401, str(e), 'json') ## reject if activity type isn't 'Follow' and the actor isn't following - if request.message.type != 'Follow' and not database.get_inbox(request.actor.domain): + if request.message.type != 'Follow' and not request.instance: logging.verbose(f'Rejected actor for trying to post while not following: {request.actor.id}') return Response.new_error(401, 'access denied', 'json') @@ -167,16 +166,16 @@ async def webfinger(request): @register_route('GET', '/nodeinfo/{version:\d.\d\.json}') -async def nodeinfo(request): +async def nodeinfo(request, s): niversion = request.match_info['version'][:3] data = dict( name = 'activityrelay', version = version, protocols = ['activitypub'], - open_regs = not request.config.whitelist_enabled, + open_regs = not s.get_config('whitelist'), users = 1, - metadata = {'peers': request.database.hostnames} + metadata = {'peers': s.get_hostnames()} ) if niversion == '2.1': diff --git a/requirements.txt b/requirements.txt index 9199741..382d935 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ aiohttp>=3.8.0 +appdirs>=1.4.4 aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.3.tar.gz cachetools>=5.2.0 click>=8.1.2 pyyaml>=6.0 +tinysql[postgres,mysql]@https:/git.barkshark.xyz/barkshark/tinysql/archive/0.1.0.tar.gz