Compare commits

..

No commits in common. "965ac73c6d84caf0fff7d316b8ae67c4a1473140" and "eea7dc81eaba43aeb90b9304e9d4cb12e6b6ef1b" have entirely different histories.

16 changed files with 630 additions and 839 deletions

1
.gitignore vendored
View file

@ -99,4 +99,3 @@ viera.jsonld
# config file # config file
relay.yaml relay.yaml
relay.jsonld

View file

@ -15,7 +15,7 @@ the [official pipx docs](https://pypa.github.io/pipx/installation/) for more in-
Now simply install ActivityRelay directly from git Now simply install ActivityRelay directly from git
pipx install git+https://git.pleroma.social/pleroma/relay@0.2.5 pipx install git+https://git.pleroma.social/pleroma/relay@0.2.4
Or from a cloned git repo. Or from a cloned git repo.
@ -39,7 +39,7 @@ be installed via [pyenv](https://github.com/pyenv/pyenv).
The instructions for installation via pip are very similar to pipx. Installation can be done from The instructions for installation via pip are very similar to pipx. Installation can be done from
git git
python3 -m pip install git+https://git.pleroma.social/pleroma/relay@0.2.5 python3 -m pip install git+https://git.pleroma.social/pleroma/relay@0.2.4
or a cloned git repo. or a cloned git repo.

View file

@ -1,37 +1,3 @@
[build-system] [build-system]
requires = ["setuptools","wheel"] requires = ["setuptools","wheel"]
build-backend = 'setuptools.build_meta' build-backend = 'setuptools.build_meta'
[tool.pylint.main]
jobs = 0
persistent = true
[tool.pylint.design]
max-args = 10
max-attributes = 100
[tool.pylint.format]
indent-str = "\t"
indent-after-paren = 1
max-line-length = 100
single-line-if-stmt = true
[tool.pylint.messages_control]
disable = [
"broad-exception-caught",
"cyclic-import",
"global-statement",
"invalid-name",
"missing-module-docstring",
"too-few-public-methods",
"too-many-public-methods",
"too-many-return-statements",
"wrong-import-order",
"wrong-import-position",
"missing-function-docstring",
"missing-class-docstring"
]

View file

@ -9,7 +9,13 @@ a = Analysis(
pathex=[], pathex=[],
binaries=[], binaries=[],
datas=[], datas=[],
hiddenimports=[], hiddenimports=[
'aputils.enums',
'aputils.errors',
'aputils.misc',
'aputils.objects',
'aputils.signer'
],
hookspath=[], hookspath=[],
hooksconfig={}, hooksconfig={},
runtime_hooks=[], runtime_hooks=[],

View file

@ -1 +1,3 @@
__version__ = '0.2.5' __version__ = '0.2.4'
from . import logger

View file

@ -1,42 +1,31 @@
from __future__ import annotations
import asyncio import asyncio
import logging
import os
import queue import queue
import signal import signal
import threading import threading
import traceback import traceback
import typing
from aiohttp import web from aiohttp import web
from datetime import datetime, timedelta from datetime import datetime, timedelta
from . import logger as logging
from .config import RelayConfig from .config import RelayConfig
from .database import RelayDatabase from .database import RelayDatabase
from .http_client import HttpClient from .http_client import HttpClient
from .misc import check_open_port from .misc import DotDict, check_open_port, set_app
from .views import VIEWS from .views import routes
if typing.TYPE_CHECKING:
from typing import Any
from .misc import Message
# pylint: disable=unsubscriptable-object
class Application(web.Application): class Application(web.Application):
def __init__(self, cfgpath: str): def __init__(self, cfgpath):
web.Application.__init__(self) web.Application.__init__(self)
self['workers'] = [] self['starttime'] = None
self['last_worker'] = 0
self['start_time'] = None
self['running'] = False self['running'] = False
self['config'] = RelayConfig(cfgpath) self['config'] = RelayConfig(cfgpath)
if not self.config.load(): if not self['config'].load():
self.config.save() self['config'].save()
if self.config.is_docker: if self.config.is_docker:
self.config.update({ self.config.update({
@ -45,8 +34,13 @@ class Application(web.Application):
'port': 8080 'port': 8080
}) })
self['database'] = RelayDatabase(self.config) self['workers'] = []
self.database.load() self['last_worker'] = 0
set_app(self)
self['database'] = RelayDatabase(self['config'])
self['database'].load()
self['client'] = HttpClient( self['client'] = HttpClient(
database = self.database, database = self.database,
@ -55,39 +49,37 @@ class Application(web.Application):
cache_size = self.config.json_cache cache_size = self.config.json_cache
) )
for path, view in VIEWS: self.set_signal_handler()
self.router.add_view(path, view)
@property @property
def client(self) -> HttpClient: def client(self):
return self['client'] return self['client']
@property @property
def config(self) -> RelayConfig: def config(self):
return self['config'] return self['config']
@property @property
def database(self) -> RelayDatabase: def database(self):
return self['database'] return self['database']
@property @property
def uptime(self) -> timedelta: def uptime(self):
if not self['start_time']: if not self['starttime']:
return timedelta(seconds=0) return timedelta(seconds=0)
uptime = datetime.now() - self['start_time'] uptime = datetime.now() - self['starttime']
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message) -> None: def push_message(self, inbox, message):
if self.config.workers <= 0: if self.config.workers <= 0:
asyncio.ensure_future(self.client.post(inbox, message)) return asyncio.ensure_future(self.client.post(inbox, message))
return
worker = self['workers'][self['last_worker']] worker = self['workers'][self['last_worker']]
worker.queue.put((inbox, message)) worker.queue.put((inbox, message))
@ -98,45 +90,36 @@ class Application(web.Application):
self['last_worker'] = 0 self['last_worker'] = 0
def set_signal_handler(self, startup: bool) -> None: def set_signal_handler(self):
for sig in ('SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'): for sig in {'SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'}:
try: try:
signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL) signal.signal(getattr(signal, sig), self.stop)
# some signals don't exist in windows, so skip them # some signals don't exist in windows, so skip them
except AttributeError: except AttributeError:
pass pass
def run(self) -> None: def run(self):
if not check_open_port(self.config.listen, self.config.port): if not check_open_port(self.config.listen, self.config.port):
logging.error('A server is already running on port %i', self.config.port) return logging.error(f'A server is already running on port {self.config.port}')
return
for view in VIEWS: for route in routes:
self.router.add_view(*view) self.router.add_route(*route)
logging.info(
'Starting webserver at %s (%s:%i)',
self.config.host,
self.config.listen,
self.config.port
)
logging.info(f'Starting webserver at {self.config.host} ({self.config.listen}:{self.config.port})')
asyncio.run(self.handle_run()) asyncio.run(self.handle_run())
def stop(self, *_: Any) -> None: def stop(self, *_):
self['running'] = False self['running'] = False
async def handle_run(self) -> None: async def handle_run(self):
self['running'] = True self['running'] = True
self.set_signal_handler(True)
if self.config.workers > 0: if self.config.workers > 0:
for _ in range(self.config.workers): for i in range(self.config.workers):
worker = PushWorker(self) worker = PushWorker(self)
worker.start() worker.start()
@ -145,40 +128,33 @@ class Application(web.Application):
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(runner,
runner,
host = self.config.listen, host = self.config.listen,
port = self.config.port, port = self.config.port,
reuse_address = True reuse_address = True
) )
await site.start() await site.start()
self['start_time'] = datetime.now() self['starttime'] = datetime.now()
while self['running']: while self['running']:
await asyncio.sleep(0.25) await asyncio.sleep(0.25)
await site.stop() await site.stop()
await self.client.close()
self['start_time'] = None self['starttime'] = None
self['running'] = False self['running'] = False
self['workers'].clear() self['workers'].clear()
class PushWorker(threading.Thread): class PushWorker(threading.Thread):
def __init__(self, app: Application): def __init__(self, app):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.app = app self.app = app
self.queue = queue.Queue() self.queue = queue.Queue()
self.client = None
def run(self) -> None: def run(self):
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
self.client = HttpClient( self.client = HttpClient(
database = self.app.database, database = self.app.database,
limit = self.app.config.push_limit, limit = self.app.config.push_limit,
@ -186,11 +162,15 @@ class PushWorker(threading.Thread):
cache_size = self.app.config.json_cache cache_size = self.app.config.json_cache
) )
asyncio.run(self.handle_queue())
async def handle_queue(self):
while self.app['running']: while self.app['running']:
try: try:
inbox, message = self.queue.get(block=True, timeout=0.25) inbox, message = self.queue.get(block=True, timeout=0.25)
self.queue.task_done() self.queue.task_done()
logging.verbose('New push from Thread-%i', threading.get_ident()) logging.verbose(f'New push from Thread-{threading.get_ident()}')
await self.client.post(inbox, message) await self.client.post(inbox, message)
except queue.Empty: except queue.Empty:
@ -201,3 +181,36 @@ class PushWorker(threading.Thread):
traceback.print_exc() traceback.print_exc()
await self.client.close() await self.client.close()
## Can't sub-class web.Request, so let's just add some properties
def request_actor(self):
try: return self['actor']
except KeyError: pass
def request_instance(self):
try: return self['instance']
except KeyError: pass
def request_message(self):
try: return self['message']
except KeyError: pass
def request_signature(self):
if 'signature' not in self._state:
try: self['signature'] = DotDict.new_from_signature(self.headers['signature'])
except KeyError: return
return self['signature']
setattr(web.Request, 'actor', property(request_actor))
setattr(web.Request, 'instance', property(request_instance))
setattr(web.Request, 'message', property(request_message))
setattr(web.Request, 'signature', property(request_signature))
setattr(web.Request, 'config', property(lambda self: self.app.config))
setattr(web.Request, 'database', property(lambda self: self.app.database))

View file

@ -1,7 +1,5 @@
from __future__ import annotations import json
import os import os
import typing
import yaml import yaml
from functools import cached_property from functools import cached_property
@ -10,10 +8,6 @@ from urllib.parse import urlparse
from .misc import DotDict, boolean from .misc import DotDict, boolean
if typing.TYPE_CHECKING:
from typing import Any
from .database import RelayDatabase
RELAY_SOFTWARE = [ RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay 'activityrelay', # https://git.pleroma.social/pleroma/relay
@ -31,19 +25,17 @@ APKEYS = [
class RelayConfig(DotDict): class RelayConfig(DotDict):
__slots__ = ('path', ) def __init__(self, path):
def __init__(self, path: str | Path):
DotDict.__init__(self, {}) DotDict.__init__(self, {})
if self.is_docker: if self.is_docker:
path = '/data/config.yaml' path = '/data/config.yaml'
self._path = Path(path).expanduser().resolve() self._path = Path(path).expanduser()
self.reset() self.reset()
def __setitem__(self, key: str, value: Any) -> None: def __setitem__(self, key, value):
if key in ['blocked_instances', 'blocked_software', 'whitelist']: if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple)) assert isinstance(value, (list, set, tuple))
@ -59,31 +51,36 @@ class RelayConfig(DotDict):
@property @property
def db(self) -> RelayDatabase: def db(self):
return Path(self['db']).expanduser().resolve() return Path(self['db']).expanduser().resolve()
@property @property
def actor(self) -> str: def path(self):
return self._path
@property
def actor(self):
return f'https://{self.host}/actor' return f'https://{self.host}/actor'
@property @property
def inbox(self) -> str: def inbox(self):
return f'https://{self.host}/inbox' return f'https://{self.host}/inbox'
@property @property
def keyid(self) -> str: def keyid(self):
return f'{self.actor}#main-key' return f'{self.actor}#main-key'
@cached_property @cached_property
def is_docker(self) -> bool: def is_docker(self):
return bool(os.environ.get('DOCKER_RUNNING')) return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None: def reset(self):
self.clear() self.clear()
self.update({ self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')), 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
@ -102,7 +99,7 @@ class RelayConfig(DotDict):
}) })
def ban_instance(self, instance: str) -> bool: def ban_instance(self, instance):
if instance.startswith('http'): if instance.startswith('http'):
instance = urlparse(instance).hostname instance = urlparse(instance).hostname
@ -113,7 +110,7 @@ class RelayConfig(DotDict):
return True return True
def unban_instance(self, instance: str) -> bool: def unban_instance(self, instance):
if instance.startswith('http'): if instance.startswith('http'):
instance = urlparse(instance).hostname instance = urlparse(instance).hostname
@ -121,11 +118,11 @@ class RelayConfig(DotDict):
self.blocked_instances.remove(instance) self.blocked_instances.remove(instance)
return True return True
except ValueError: except:
return False return False
def ban_software(self, software: str) -> bool: def ban_software(self, software):
if self.is_banned_software(software): if self.is_banned_software(software):
return False return False
@ -133,16 +130,16 @@ class RelayConfig(DotDict):
return True return True
def unban_software(self, software: str) -> bool: def unban_software(self, software):
try: try:
self.blocked_software.remove(software) self.blocked_software.remove(software)
return True return True
except ValueError: except:
return False return False
def add_whitelist(self, instance: str) -> bool: def add_whitelist(self, instance):
if instance.startswith('http'): if instance.startswith('http'):
instance = urlparse(instance).hostname instance = urlparse(instance).hostname
@ -153,7 +150,7 @@ class RelayConfig(DotDict):
return True return True
def del_whitelist(self, instance: str) -> bool: def del_whitelist(self, instance):
if instance.startswith('http'): if instance.startswith('http'):
instance = urlparse(instance).hostname instance = urlparse(instance).hostname
@ -161,32 +158,32 @@ class RelayConfig(DotDict):
self.whitelist.remove(instance) self.whitelist.remove(instance)
return True return True
except ValueError: except:
return False return False
def is_banned(self, instance: str) -> bool: def is_banned(self, instance):
if instance.startswith('http'): if instance.startswith('http'):
instance = urlparse(instance).hostname instance = urlparse(instance).hostname
return instance in self.blocked_instances return instance in self.blocked_instances
def is_banned_software(self, software: str) -> bool: def is_banned_software(self, software):
if not software: if not software:
return False return False
return software.lower() in self.blocked_software return software.lower() in self.blocked_software
def is_whitelisted(self, instance: str) -> bool: def is_whitelisted(self, instance):
if instance.startswith('http'): if instance.startswith('http'):
instance = urlparse(instance).hostname instance = urlparse(instance).hostname
return instance in self.whitelist return instance in self.whitelist
def load(self) -> bool: def load(self):
self.reset() self.reset()
options = {} options = {}
@ -198,7 +195,7 @@ class RelayConfig(DotDict):
pass pass
try: try:
with self._path.open('r', encoding = 'UTF-8') as fd: with open(self.path) as fd:
config = yaml.load(fd, **options) config = yaml.load(fd, **options)
except FileNotFoundError: except FileNotFoundError:
@ -217,7 +214,7 @@ class RelayConfig(DotDict):
continue continue
if key not in self: elif key not in self:
continue continue
self[key] = value self[key] = value
@ -228,7 +225,7 @@ class RelayConfig(DotDict):
return True return True
def save(self) -> None: def save(self):
config = { config = {
# just turning config.db into a string is good enough for now # just turning config.db into a string is good enough for now
'db': str(self.db), 'db': str(self.db),
@ -242,5 +239,7 @@ class RelayConfig(DotDict):
'ap': {key: self[key] for key in APKEYS} 'ap': {key: self[key] for key in APKEYS}
} }
with self._path.open('w', encoding = 'utf-8') as fd: with open(self._path, 'w') as fd:
yaml.dump(config, fd, sort_keys=False) yaml.dump(config, fd, sort_keys=False)
return config

View file

@ -1,21 +1,14 @@
from __future__ import annotations import aputils
import asyncio
import json import json
import typing import logging
import traceback
from aputils.signer import Signer
from urllib.parse import urlparse from urllib.parse import urlparse
from . import logger as logging
if typing.TYPE_CHECKING:
from typing import Iterator, Optional
from .config import RelayConfig
from .misc import Message
class RelayDatabase(dict): class RelayDatabase(dict):
def __init__(self, config: RelayConfig): def __init__(self, config):
dict.__init__(self, { dict.__init__(self, {
'relay-list': {}, 'relay-list': {},
'private-key': None, 'private-key': None,
@ -28,16 +21,16 @@ class RelayDatabase(dict):
@property @property
def hostnames(self) -> tuple[str]: def hostnames(self):
return tuple(self['relay-list'].keys()) return tuple(self['relay-list'].keys())
@property @property
def inboxes(self) -> tuple[dict[str, str]]: def inboxes(self):
return tuple(data['inbox'] for data in self['relay-list'].values()) return tuple(data['inbox'] for data in self['relay-list'].values())
def load(self) -> bool: def load(self):
new_db = True new_db = True
try: try:
@ -47,7 +40,7 @@ class RelayDatabase(dict):
self['version'] = data.get('version', None) self['version'] = data.get('version', None)
self['private-key'] = data.get('private-key') self['private-key'] = data.get('private-key')
if self['version'] is None: if self['version'] == None:
self['version'] = 1 self['version'] = 1
if 'actorKeys' in data: if 'actorKeys' in data:
@ -65,9 +58,7 @@ class RelayDatabase(dict):
self['relay-list'] = data.get('relay-list', {}) self['relay-list'] = data.get('relay-list', {})
for domain, instance in self['relay-list'].items(): for domain, instance in self['relay-list'].items():
if self.config.is_banned(domain) or \ if self.config.is_banned(domain) or (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
(self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
self.del_inbox(domain) self.del_inbox(domain)
continue continue
@ -84,40 +75,36 @@ class RelayDatabase(dict):
raise e from None raise e from None
if not self['private-key']: if not self['private-key']:
logging.info('No actor keys present, generating 4096-bit RSA keypair.') logging.info("No actor keys present, generating 4096-bit RSA keypair.")
self.signer = Signer.new(self.config.keyid, size=4096) self.signer = aputils.Signer.new(self.config.keyid, size=4096)
self['private-key'] = self.signer.export() self['private-key'] = self.signer.export()
else: else:
self.signer = Signer(self['private-key'], self.config.keyid) self.signer = aputils.Signer(self['private-key'], self.config.keyid)
self.save() self.save()
return not new_db return not new_db
def save(self) -> None: def save(self):
with self.config.db.open('w', encoding = 'UTF-8') as fd: with self.config.db.open('w') as fd:
json.dump(self, fd, indent=4) json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None: def get_inbox(self, domain, fail=False):
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).hostname domain = urlparse(domain).hostname
if (inbox := self['relay-list'].get(domain)): inbox = self['relay-list'].get(domain)
if inbox:
return inbox return inbox
if fail: if fail:
raise KeyError(domain) raise KeyError(domain)
return None
def add_inbox(self,
inbox: str,
followid: Optional[str] = None,
software: Optional[str] = None) -> dict[str, str]:
def add_inbox(self, inbox, followid=None, software=None):
assert inbox.startswith('https'), 'Inbox must be a url' assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname domain = urlparse(inbox).hostname
instance = self.get_inbox(domain) instance = self.get_inbox(domain)
@ -138,15 +125,11 @@ class RelayDatabase(dict):
'software': software 'software': software
} }
logging.verbose('Added inbox to database: %s', inbox) logging.verbose(f'Added inbox to database: {inbox}')
return self['relay-list'][domain] return self['relay-list'][domain]
def del_inbox(self, def del_inbox(self, domain, followid=None, fail=False):
domain: str,
followid: Optional[str] = None,
fail: Optional[bool] = False) -> bool:
data = self.get_inbox(domain, fail=False) data = self.get_inbox(domain, fail=False)
if not data: if not data:
@ -157,17 +140,17 @@ class RelayDatabase(dict):
if not data['followid'] or not followid or data['followid'] == followid: if not data['followid'] or not followid or data['followid'] == followid:
del self['relay-list'][data['domain']] del self['relay-list'][data['domain']]
logging.verbose('Removed inbox from database: %s', data['inbox']) logging.verbose(f'Removed inbox from database: {data["inbox"]}')
return True return True
if fail: if fail:
raise ValueError('Follow IDs do not match') raise ValueError('Follow IDs do not match')
logging.debug('Follow ID does not match: db = %s, object = %s', data['followid'], followid) logging.debug(f'Follow ID does not match: db = {data["followid"]}, object = {followid}')
return False return False
def get_request(self, domain: str, fail: bool = True) -> dict[str, str] | None: def get_request(self, domain, fail=True):
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).hostname domain = urlparse(domain).hostname
@ -178,10 +161,8 @@ class RelayDatabase(dict):
if fail: if fail:
raise e raise e
return None
def add_request(self, actor, inbox, followid):
def add_request(self, actor: str, inbox: str, followid: str) -> None:
domain = urlparse(inbox).hostname domain = urlparse(inbox).hostname
try: try:
@ -198,17 +179,17 @@ class RelayDatabase(dict):
} }
def del_request(self, domain: str) -> None: def del_request(self, domain):
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).hostname domain = urlparse(inbox).hostname
del self['follow-requests'][domain] del self['follow-requests'][domain]
def distill_inboxes(self, message: Message) -> Iterator[str]: def distill_inboxes(self, message):
src_domains = { src_domains = {
message.domain, message.domain,
urlparse(message.object_id).netloc urlparse(message.objectid).netloc
} }
for domain, instance in self['relay-list'].items(): for domain, instance in self['relay-list'].items():

View file

@ -1,23 +1,21 @@
from __future__ import annotations import logging
import traceback import traceback
import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils.objects import Nodeinfo, WellKnownNodeinfo from aputils import Nodeinfo, WellKnownNodeinfo
from datetime import datetime
from cachetools import LRUCache from cachetools import LRUCache
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from urllib.parse import urlparse from urllib.parse import urlparse
from . import __version__ from . import __version__
from . import logger as logging from .misc import (
from .misc import MIMETYPES, Message MIMETYPES,
DotDict,
if typing.TYPE_CHECKING: Message
from typing import Any, Callable, Optional )
from .database import RelayDatabase
HEADERS = { HEADERS = {
@ -26,31 +24,40 @@ HEADERS = {
} }
class HttpClient: class Cache(LRUCache):
def __init__(self, def set_maxsize(self, value):
database: RelayDatabase, self.__maxsize = int(value)
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
class HttpClient:
def __init__(self, database, limit=100, timeout=10, cache_size=1024):
self.database = database self.database = database
self.cache = LRUCache(cache_size) self.cache = Cache(cache_size)
self.limit = limit self.cfg = {'limit': limit, 'timeout': timeout}
self.timeout = timeout
self._conn = None self._conn = None
self._session = None self._session = None
async def __aenter__(self) -> HttpClient: async def __aenter__(self):
await self.open() await self.open()
return self return self
async def __aexit__(self, *_: Any) -> None: async def __aexit__(self, *_):
await self.close() await self.close()
async def open(self) -> None: @property
def limit(self):
return self.cfg['limit']
@property
def timeout(self):
return self.cfg['timeout']
async def open(self):
if self._session: if self._session:
return return
@ -67,7 +74,7 @@ class HttpClient:
) )
async def close(self) -> None: async def close(self):
if not self._session: if not self._session:
return return
@ -78,19 +85,11 @@ class HttpClient:
self._session = None self._session = None
async def get(self, # pylint: disable=too-many-branches async def get(self, url, sign_headers=False, loads=None, force=False):
url: str,
sign_headers: Optional[bool] = False,
loads: Optional[Callable] = None,
force: Optional[bool] = False) -> Message | dict | None:
await self.open() await self.open()
try: try: url, _ = url.split('#', 1)
url, _ = url.split('#', 1) except: pass
except ValueError:
pass
if not force and url in self.cache: if not force and url in self.cache:
return self.cache[url] return self.cache[url]
@ -101,53 +100,51 @@ class HttpClient:
headers.update(self.database.signer.sign_headers('GET', url, algorithm='original')) headers.update(self.database.signer.sign_headers('GET', url, algorithm='original'))
try: try:
logging.debug('Fetching resource: %s', url) logging.verbose(f'Fetching resource: {url}')
async with self._session.get(url, headers=headers) as resp: async with self._session.get(url, headers=headers) as resp:
## Not expecting a response with 202s, so just return ## Not expecting a response with 202s, so just return
if resp.status == 202: if resp.status == 202:
return None return
if resp.status != 200: elif resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status) logging.verbose(f'Received error when requesting {url}: {resp.status}')
logging.debug(await resp.read()) logging.verbose(await resp.read()) # change this to debug
return None return
if loads: if loads:
message = await resp.json(loads=loads) message = await resp.json(loads=loads)
elif resp.content_type == MIMETYPES['activity']: elif resp.content_type == MIMETYPES['activity']:
message = await resp.json(loads = Message.parse) message = await resp.json(loads=Message.new_from_json)
elif resp.content_type == MIMETYPES['json']: elif resp.content_type == MIMETYPES['json']:
message = await resp.json() message = await resp.json(loads=DotDict.new_from_json)
else: else:
logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type) # todo: raise TypeError or something
logging.debug('Response: %s', await resp.read()) logging.verbose(f'Invalid Content-Type for "{url}": {resp.content_type}')
return None return logging.debug(f'Response: {resp.read()}')
logging.debug('%s >> resp %s', url, message.to_json(4)) logging.debug(f'{url} >> resp {message.to_json(4)}')
self.cache[url] = message self.cache[url] = message
return message return message
except JSONDecodeError: except JSONDecodeError:
logging.verbose('Failed to parse JSON') logging.verbose(f'Failed to parse JSON')
except ClientSSLError: except ClientSSLError:
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc) logging.verbose(f'SSL error when connecting to {urlparse(url).netloc}')
except (AsyncTimeoutError, ClientConnectionError): except (AsyncTimeoutError, ClientConnectionError):
logging.verbose('Failed to connect to %s', urlparse(url).netloc) logging.verbose(f'Failed to connect to {urlparse(url).netloc}')
except Exception: except Exception as e:
traceback.print_exc() traceback.print_exc()
return None
async def post(self, url, message):
async def post(self, url: str, message: Message) -> None:
await self.open() await self.open()
instance = self.database.get_inbox(url) instance = self.database.get_inbox(url)
@ -163,39 +160,38 @@ class HttpClient:
headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm)) headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm))
try: try:
logging.verbose('Sending "%s" to %s', message.type, url) logging.verbose(f'Sending "{message.type}" to {url}')
async with self._session.post(url, headers=headers, data=message.to_json()) as resp: async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
## Not expecting a response, so just return ## Not expecting a response, so just return
if resp.status in {200, 202}: if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url) return logging.verbose(f'Successfully sent "{message.type}" to {url}')
return
logging.verbose('Received error when pushing to %s: %i', url, resp.status) logging.verbose(f'Received error when pushing to {url}: {resp.status}')
logging.debug(await resp.read()) return logging.verbose(await resp.read()) # change this to debug
return
except ClientSSLError: except ClientSSLError:
logging.warning('SSL error when pushing to %s', urlparse(url).netloc) logging.warning(f'SSL error when pushing to {urlparse(url).netloc}')
except (AsyncTimeoutError, ClientConnectionError): except (AsyncTimeoutError, ClientConnectionError):
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) logging.warning(f'Failed to connect to {urlparse(url).netloc} for message push')
## prevent workers from being brought down ## prevent workers from being brought down
except Exception: except Exception as e:
traceback.print_exc() traceback.print_exc()
async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None: ## Additional methods ##
async def fetch_nodeinfo(self, domain):
nodeinfo_url = None nodeinfo_url = None
wk_nodeinfo = await self.get( wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo', f'https://{domain}/.well-known/nodeinfo',
loads = WellKnownNodeinfo.parse loads = WellKnownNodeinfo.new_from_json
) )
if not wk_nodeinfo: if not wk_nodeinfo:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) logging.verbose(f'Failed to fetch well-known nodeinfo url for domain: {domain}')
return None return False
for version in ['20', '21']: for version in ['20', '21']:
try: try:
@ -205,22 +201,22 @@ class HttpClient:
pass pass
if not nodeinfo_url: if not nodeinfo_url:
logging.verbose('Failed to fetch nodeinfo url for %s', domain) logging.verbose(f'Failed to fetch nodeinfo url for domain: {domain}')
return None return False
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None return await self.get(nodeinfo_url, loads=Nodeinfo.new_from_json) or False
async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None: async def get(database, *args, **kwargs):
async with HttpClient(database) as client: async with HttpClient(database) as client:
return await client.get(*args, **kwargs) return await client.get(*args, **kwargs)
async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None: async def post(database, *args, **kwargs):
async with HttpClient(database) as client: async with HttpClient(database) as client:
return await client.post(*args, **kwargs) return await client.post(*args, **kwargs)
async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None: async def fetch_nodeinfo(database, *args, **kwargs):
async with HttpClient(database) as client: async with HttpClient(database) as client:
return await client.fetch_nodeinfo(*args, **kwargs) return await client.fetch_nodeinfo(*args, **kwargs)

View file

@ -1,57 +1,40 @@
from __future__ import annotations
import logging import logging
import os import os
import typing
from pathlib import Path from pathlib import Path
if typing.TYPE_CHECKING:
from typing import Any, Callable
## Add the verbose logging level
LOG_LEVELS: dict[str, int] = { def verbose(message, *args, **kwargs):
'DEBUG': logging.DEBUG, if not logging.root.isEnabledFor(logging.VERBOSE):
'VERBOSE': 15,
'INFO': logging.INFO,
'WARNING': logging.WARNING,
'ERROR': logging.ERROR,
'CRITICAL': logging.CRITICAL
}
debug: Callable = logging.debug
info: Callable = logging.info
warning: Callable = logging.warning
error: Callable = logging.error
critical: Callable = logging.critical
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']):
return return
logging.log(LOG_LEVELS['VERBOSE'], message, *args, **kwargs) logging.log(logging.VERBOSE, message, *args, **kwargs)
setattr(logging, 'verbose', verbose)
setattr(logging, 'VERBOSE', 15)
logging.addLevelName(15, 'VERBOSE')
logging.addLevelName(LOG_LEVELS['VERBOSE'], 'VERBOSE') ## Get log level and file from environment if possible
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper() env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
try: try:
env_log_file = Path(os.environ['LOG_FILE']).expanduser().resolve() env_log_file = Path(os.environ.get('LOG_FILE')).expanduser().resolve()
except KeyError: except TypeError:
env_log_file = None env_log_file = None
## Make sure the level from the environment is valid
try: try:
log_level = LOG_LEVELS[env_log_level] log_level = getattr(logging, env_log_level)
except KeyError: except AttributeError:
logging.warning('Invalid log level: %s', env_log_level)
log_level = logging.INFO log_level = logging.INFO
## Set logging config
handlers = [logging.StreamHandler()] handlers = [logging.StreamHandler()]
if env_log_file: if env_log_file:
@ -59,6 +42,6 @@ if env_log_file:
logging.basicConfig( logging.basicConfig(
level = log_level, level = log_level,
format = '[%(asctime)s] %(levelname)s: %(message)s', format = "[%(asctime)s] %(levelname)s: %(message)s",
handlers = handlers handlers = handlers
) )

View file

@ -1,10 +1,8 @@
from __future__ import annotations
import Crypto import Crypto
import asyncio import asyncio
import click import click
import logging
import platform import platform
import typing
from urllib.parse import urlparse from urllib.parse import urlparse
@ -13,12 +11,6 @@ from . import http_client as http
from .application import Application from .application import Application
from .config import RELAY_SOFTWARE from .config import RELAY_SOFTWARE
if typing.TYPE_CHECKING:
from typing import Any
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
app = None app = None
CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@ -28,7 +20,7 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config') @click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config')
@click.version_option(version=__version__, prog_name='ActivityRelay') @click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context @click.pass_context
def cli(ctx: click.Context, config: str) -> None: def cli(ctx, config):
global app global app
app = Application(config) app = Application(config)
@ -41,14 +33,11 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup') @cli.command('setup')
def cli_setup() -> None: def cli_setup():
'Generate a new config' 'Generate a new config'
while True: while True:
app.config.host = click.prompt( app.config.host = click.prompt('What domain will the relay be hosted on?', default=app.config.host)
'What domain will the relay be hosted on?',
default = app.config.host
)
if not app.config.host.endswith('example.com'): if not app.config.host.endswith('example.com'):
break break
@ -56,18 +45,10 @@ def cli_setup() -> None:
click.echo('The domain must not be example.com') click.echo('The domain must not be example.com')
if not app.config.is_docker: if not app.config.is_docker:
app.config.listen = click.prompt( app.config.listen = click.prompt('Which address should the relay listen on?', default=app.config.listen)
'Which address should the relay listen on?',
default = app.config.listen
)
while True: while True:
app.config.port = click.prompt( app.config.port = click.prompt('What TCP port should the relay listen on?', default=app.config.port, type=int)
'What TCP port should the relay listen on?',
default = app.config.port,
type = int
)
break break
app.config.save() app.config.save()
@ -77,47 +58,39 @@ def cli_setup() -> None:
@cli.command('run') @cli.command('run')
def cli_run() -> None: def cli_run():
'Run the relay' 'Run the relay'
if app.config.host.endswith('example.com'): if app.config.host.endswith('example.com'):
click.echo( return click.echo('Relay is not set up. Please edit your relay config or run "activityrelay setup".')
'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
)
return
vers_split = platform.python_version().split('.') vers_split = platform.python_version().split('.')
pip_command = 'pip3 uninstall pycrypto && pip3 install pycryptodome' pip_command = 'pip3 uninstall pycrypto && pip3 install pycryptodome'
if Crypto.__version__ == '2.6.1': if Crypto.__version__ == '2.6.1':
if int(vers_split[1]) > 7: if int(vers_split[1]) > 7:
click.echo( click.echo('Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome before running again. Exiting...')
'Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome ' + return click.echo(pip_command)
'before running again. Exiting...'
)
click.echo(pip_command)
return
else:
click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome') click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome')
click.echo(pip_command) return click.echo(pip_command)
return
if not misc.check_open_port(app.config.listen, app.config.port): if not misc.check_open_port(app.config.listen, app.config.port):
click.echo(f'Error: A server is already running on port {app.config.port}') return click.echo(f'Error: A server is already running on port {app.config.port}')
return
app.run() app.run()
# todo: add config default command for resetting config key
@cli.group('config') @cli.group('config')
def cli_config() -> None: def cli_config():
'Manage the relay config' 'Manage the relay config'
pass
@cli_config.command('list') @cli_config.command('list')
def cli_config_list() -> None: def cli_config_list():
'List the current relay config' 'List the current relay config'
click.echo('Relay Config:') click.echo('Relay Config:')
@ -131,7 +104,7 @@ def cli_config_list() -> None:
@cli_config.command('set') @cli_config.command('set')
@click.argument('key') @click.argument('key')
@click.argument('value') @click.argument('value')
def cli_config_set(key: str, value: Any) -> None: def cli_config_set(key, value):
'Set a config value' 'Set a config value'
app.config[key] = value app.config[key] = value
@ -141,12 +114,13 @@ def cli_config_set(key: str, value: Any) -> None:
@cli.group('inbox') @cli.group('inbox')
def cli_inbox() -> None: def cli_inbox():
'Manage the inboxes in the database' 'Manage the inboxes in the database'
pass
@cli_inbox.command('list') @cli_inbox.command('list')
def cli_inbox_list() -> None: def cli_inbox_list():
'List the connected instances or relays' 'List the connected instances or relays'
click.echo('Connected to the following instances or relays:') click.echo('Connected to the following instances or relays:')
@ -157,12 +131,11 @@ def cli_inbox_list() -> None:
@cli_inbox.command('follow') @cli_inbox.command('follow')
@click.argument('actor') @click.argument('actor')
def cli_inbox_follow(actor: str) -> None: def cli_inbox_follow(actor):
'Follow an actor (Relay must be running)' 'Follow an actor (Relay must be running)'
if app.config.is_banned(actor): if app.config.is_banned(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') return click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
if not actor.startswith('http'): if not actor.startswith('http'):
domain = actor domain = actor
@ -179,8 +152,7 @@ def cli_inbox_follow(actor: str) -> None:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True)) actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
if not actor_data: if not actor_data:
click.echo(f'Failed to fetch actor: {actor}') return click.echo(f'Failed to fetch actor: {actor}')
return
inbox = actor_data.shared_inbox inbox = actor_data.shared_inbox
@ -195,7 +167,7 @@ def cli_inbox_follow(actor: str) -> None:
@cli_inbox.command('unfollow') @cli_inbox.command('unfollow')
@click.argument('actor') @click.argument('actor')
def cli_inbox_unfollow(actor: str) -> None: def cli_inbox_unfollow(actor):
'Unfollow an actor (Relay must be running)' 'Unfollow an actor (Relay must be running)'
if not actor.startswith('http'): if not actor.startswith('http'):
@ -233,19 +205,17 @@ def cli_inbox_unfollow(actor: str) -> None:
@cli_inbox.command('add') @cli_inbox.command('add')
@click.argument('inbox') @click.argument('inbox')
def cli_inbox_add(inbox: str) -> None: def cli_inbox_add(inbox):
'Add an inbox to the database' 'Add an inbox to the database'
if not inbox.startswith('http'): if not inbox.startswith('http'):
inbox = f'https://{inbox}/inbox' inbox = f'https://{inbox}/inbox'
if app.config.is_banned(inbox): if app.config.is_banned(inbox):
click.echo(f'Error: Refusing to add banned inbox: {inbox}') return click.echo(f'Error: Refusing to add banned inbox: {inbox}')
return
if app.database.get_inbox(inbox): if app.database.get_inbox(inbox):
click.echo(f'Error: Inbox already in database: {inbox}') return click.echo(f'Error: Inbox already in database: {inbox}')
return
app.database.add_inbox(inbox) app.database.add_inbox(inbox)
app.database.save() app.database.save()
@ -255,7 +225,7 @@ def cli_inbox_add(inbox: str) -> None:
@cli_inbox.command('remove') @cli_inbox.command('remove')
@click.argument('inbox') @click.argument('inbox')
def cli_inbox_remove(inbox: str) -> None: def cli_inbox_remove(inbox):
'Remove an inbox from the database' 'Remove an inbox from the database'
try: try:
@ -272,12 +242,13 @@ def cli_inbox_remove(inbox: str) -> None:
@cli.group('instance') @cli.group('instance')
def cli_instance() -> None: def cli_instance():
'Manage instance bans' 'Manage instance bans'
pass
@cli_instance.command('list') @cli_instance.command('list')
def cli_instance_list() -> None: def cli_instance_list():
'List all banned instances' 'List all banned instances'
click.echo('Banned instances or relays:') click.echo('Banned instances or relays:')
@ -288,7 +259,7 @@ def cli_instance_list() -> None:
@cli_instance.command('ban') @cli_instance.command('ban')
@click.argument('target') @click.argument('target')
def cli_instance_ban(target: str) -> None: def cli_instance_ban(target):
'Ban an instance and remove the associated inbox if it exists' 'Ban an instance and remove the associated inbox if it exists'
if target.startswith('http'): if target.startswith('http'):
@ -308,7 +279,7 @@ def cli_instance_ban(target: str) -> None:
@cli_instance.command('unban') @cli_instance.command('unban')
@click.argument('target') @click.argument('target')
def cli_instance_unban(target: str) -> None: def cli_instance_unban(target):
'Unban an instance' 'Unban an instance'
if app.config.unban_instance(target): if app.config.unban_instance(target):
@ -321,12 +292,13 @@ def cli_instance_unban(target: str) -> None:
@cli.group('software') @cli.group('software')
def cli_software() -> None: def cli_software():
'Manage banned software' 'Manage banned software'
pass
@cli_software.command('list') @cli_software.command('list')
def cli_software_list() -> None: def cli_software_list():
'List all banned software' 'List all banned software'
click.echo('Banned software:') click.echo('Banned software:')
@ -336,21 +308,19 @@ def cli_software_list() -> None:
@cli_software.command('ban') @cli_software.command('ban')
@click.option( @click.option('--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default=False,
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False, help='Treat NAME like a domain and try to fet the software name from nodeinfo'
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
) )
@click.argument('name') @click.argument('name')
def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None: def cli_software_ban(name, fetch_nodeinfo):
'Ban software. Use RELAYS for NAME to ban relays' 'Ban software. Use RELAYS for NAME to ban relays'
if name == 'RELAYS': if name == 'RELAYS':
for software in RELAY_SOFTWARE: for name in RELAY_SOFTWARE:
app.config.ban_software(software) app.config.ban_software(name)
app.config.save() app.config.save()
click.echo('Banned all relay software') return click.echo('Banned all relay software')
return
if fetch_nodeinfo: if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
@ -362,28 +332,25 @@ def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None:
if app.config.ban_software(name): if app.config.ban_software(name):
app.config.save() app.config.save()
click.echo(f'Banned software: {name}') return click.echo(f'Banned software: {name}')
return
click.echo(f'Software already banned: {name}') click.echo(f'Software already banned: {name}')
@cli_software.command('unban') @cli_software.command('unban')
@click.option( @click.option('--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default=False,
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False, help='Treat NAME like a domain and try to fet the software name from nodeinfo'
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
) )
@click.argument('name') @click.argument('name')
def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None: def cli_software_unban(name, fetch_nodeinfo):
'Ban software. Use RELAYS for NAME to unban relays' 'Ban software. Use RELAYS for NAME to unban relays'
if name == 'RELAYS': if name == 'RELAYS':
for software in RELAY_SOFTWARE: for name in RELAY_SOFTWARE:
app.config.unban_software(software) app.config.unban_software(name)
app.config.save() app.config.save()
click.echo('Unbanned all relay software') return click.echo('Unbanned all relay software')
return
if fetch_nodeinfo: if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
@ -395,19 +362,19 @@ def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None:
if app.config.unban_software(name): if app.config.unban_software(name):
app.config.save() app.config.save()
click.echo(f'Unbanned software: {name}') return click.echo(f'Unbanned software: {name}')
return
click.echo(f'Software wasn\'t banned: {name}') click.echo(f'Software wasn\'t banned: {name}')
@cli.group('whitelist') @cli.group('whitelist')
def cli_whitelist() -> None: def cli_whitelist():
'Manage the instance whitelist' 'Manage the instance whitelist'
pass
@cli_whitelist.command('list') @cli_whitelist.command('list')
def cli_whitelist_list() -> None: def cli_whitelist_list():
'List all the instances in the whitelist' 'List all the instances in the whitelist'
click.echo('Current whitelisted domains') click.echo('Current whitelisted domains')
@ -418,12 +385,11 @@ def cli_whitelist_list() -> None:
@cli_whitelist.command('add') @cli_whitelist.command('add')
@click.argument('instance') @click.argument('instance')
def cli_whitelist_add(instance: str) -> None: def cli_whitelist_add(instance):
'Add an instance to the whitelist' 'Add an instance to the whitelist'
if not app.config.add_whitelist(instance): if not app.config.add_whitelist(instance):
click.echo(f'Instance already in the whitelist: {instance}') return click.echo(f'Instance already in the whitelist: {instance}')
return
app.config.save() app.config.save()
click.echo(f'Instance added to the whitelist: {instance}') click.echo(f'Instance added to the whitelist: {instance}')
@ -431,12 +397,11 @@ def cli_whitelist_add(instance: str) -> None:
@cli_whitelist.command('remove') @cli_whitelist.command('remove')
@click.argument('instance') @click.argument('instance')
def cli_whitelist_remove(instance: str) -> None: def cli_whitelist_remove(instance):
'Remove an instance from the whitelist' 'Remove an instance from the whitelist'
if not app.config.del_whitelist(instance): if not app.config.del_whitelist(instance):
click.echo(f'Instance not in the whitelist: {instance}') return click.echo(f'Instance not in the whitelist: {instance}')
return
app.config.save() app.config.save()
@ -448,15 +413,14 @@ def cli_whitelist_remove(instance: str) -> None:
@cli_whitelist.command('import') @cli_whitelist.command('import')
def cli_whitelist_import() -> None: def cli_whitelist_import():
'Add all current inboxes to the whitelist' 'Add all current inboxes to the whitelist'
for domain in app.database.hostnames: for domain in app.database.hostnames:
cli_whitelist_add.callback(domain) cli_whitelist_add.callback(domain)
def main() -> None: def main():
# pylint: disable=no-value-for-parameter
cli(prog_name='relay') cli(prog_name='relay')

View file

@ -1,31 +1,21 @@
from __future__ import annotations import aputils
import asyncio
import base64
import json import json
import logging
import socket import socket
import traceback import traceback
import typing import uuid
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse, View as AiohttpView
from aiohttp.web_exceptions import HTTPMethodNotAllowed from datetime import datetime
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.message import Message as ApMessage
from functools import cached_property
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from . import logger as logging
if typing.TYPE_CHECKING:
from typing import Any, Coroutine, Generator, Optional, Type
from aputils.signer import Signer
from .application import Application
from .config import RelayConfig
from .database import RelayDatabase
from .http_client import HttpClient
app = None
MIMETYPES = { MIMETYPES = {
'activity': 'application/activity+json', 'activity': 'application/activity+json',
@ -40,87 +30,94 @@ NODEINFO_NS = {
} }
def boolean(value: Any) -> bool: def set_app(new_app):
global app
app = new_app
def boolean(value):
if isinstance(value, str): if isinstance(value, str):
if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']: if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']:
return True return True
if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']: elif value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']:
return False return False
else:
raise TypeError(f'Cannot parse string "{value}" as a boolean') raise TypeError(f'Cannot parse string "{value}" as a boolean')
if isinstance(value, int): elif isinstance(value, int):
if value == 1: if value == 1:
return True return True
if value == 0: elif value == 0:
return False return False
else:
raise ValueError('Integer value must be 1 or 0') raise ValueError('Integer value must be 1 or 0')
if value is None: elif value == None:
return False return False
return bool(value) try:
return value.__bool__()
except AttributeError:
raise TypeError(f'Cannot convert object of type "{clsname(value)}"')
def check_open_port(host: str, port: int) -> bool: def check_open_port(host, port):
if host == '0.0.0.0': if host == '0.0.0.0':
host = '127.0.0.1' host = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try: try:
return s.connect_ex((host, port)) != 0 return s.connect_ex((host , port)) != 0
except socket.error: except socket.error as e:
return False return False
class DotDict(dict): class DotDict(dict):
def __init__(self, _data: dict[str, Any], **kwargs: Any): def __init__(self, _data, **kwargs):
dict.__init__(self) dict.__init__(self)
self.update(_data, **kwargs) self.update(_data, **kwargs)
def __getattr__(self, key: str) -> str: def __getattr__(self, k):
try: try:
return self[key] return self[k]
except KeyError: except KeyError:
raise AttributeError( raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None
f'{self.__class__.__name__} object has no attribute {key}'
) from None
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, k, v):
if key.startswith('_'): if k.startswith('_'):
super().__setattr__(key, value) super().__setattr__(k, v)
else: else:
self[key] = value self[k] = v
def __setitem__(self, key: str, value: Any) -> None: def __setitem__(self, k, v):
if type(value) is dict: # pylint: disable=unidiomatic-typecheck if type(v) == dict:
value = DotDict(value) v = DotDict(v)
super().__setitem__(key, value) super().__setitem__(k, v)
def __delattr__(self, key: str) -> None: def __delattr__(self, k):
try: try:
dict.__delitem__(self, key) dict.__delitem__(self, k)
except KeyError: except KeyError:
raise AttributeError( raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None
f'{self.__class__.__name__} object has no attribute {key}'
) from None
@classmethod @classmethod
def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]: def new_from_json(cls, data):
if not data: if not data:
raise JSONDecodeError('Empty body', data, 1) raise JSONDecodeError('Empty body', data, 1)
@ -128,11 +125,11 @@ class DotDict(dict):
return cls(json.loads(data)) return cls(json.loads(data))
except ValueError: except ValueError:
raise JSONDecodeError('Invalid body', data, 1) from None raise JSONDecodeError('Invalid body', data, 1)
@classmethod @classmethod
def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]: def new_from_signature(cls, sig):
data = cls({}) data = cls({})
for chunk in sig.strip().split(','): for chunk in sig.strip().split(','):
@ -147,11 +144,11 @@ class DotDict(dict):
return data return data
def to_json(self, indent: Optional[int | str] = None) -> str: def to_json(self, indent=None):
return json.dumps(self, indent=indent) return json.dumps(self, indent=indent)
def update(self, _data: dict[str, Any], **kwargs: Any) -> None: def update(self, _data, **kwargs):
if isinstance(_data, dict): if isinstance(_data, dict):
for key, value in _data.items(): for key, value in _data.items():
self[key] = value self[key] = value
@ -164,13 +161,9 @@ class DotDict(dict):
self[key] = value self[key] = value
class Message(ApMessage): class Message(DotDict):
@classmethod @classmethod
def new_actor(cls: Type[Message], # pylint: disable=arguments-differ def new_actor(cls, host, pubkey, description=None):
host: str,
pubkey: str,
description: Optional[str] = None) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/actor', 'id': f'https://{host}/actor',
@ -194,34 +187,34 @@ class Message(ApMessage):
@classmethod @classmethod
def new_announce(cls: Type[Message], host: str, obj: str) -> Message: def new_announce(cls, host, object):
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid.uuid4()}',
'type': 'Announce', 'type': 'Announce',
'to': [f'https://{host}/followers'], 'to': [f'https://{host}/followers'],
'actor': f'https://{host}/actor', 'actor': f'https://{host}/actor',
'object': obj 'object': object
}) })
@classmethod @classmethod
def new_follow(cls: Type[Message], host: str, actor: str) -> Message: def new_follow(cls, host, actor):
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow', 'type': 'Follow',
'to': [actor], 'to': [actor],
'object': actor, 'object': actor,
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid.uuid4()}',
'actor': f'https://{host}/actor' 'actor': f'https://{host}/actor'
}) })
@classmethod @classmethod
def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message: def new_unfollow(cls, host, actor, follow):
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid.uuid4()}',
'type': 'Undo', 'type': 'Undo',
'to': [actor], 'to': [actor],
'actor': f'https://{host}/actor', 'actor': f'https://{host}/actor',
@ -230,15 +223,10 @@ class Message(ApMessage):
@classmethod @classmethod
def new_response(cls: Type[Message], def new_response(cls, host, actor, followid, accept):
host: str,
actor: str,
followid: str,
accept: bool) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid.uuid4()}',
'type': 'Accept' if accept else 'Reject', 'type': 'Accept' if accept else 'Reject',
'to': [actor], 'to': [actor],
'actor': f'https://{host}/actor', 'actor': f'https://{host}/actor',
@ -251,24 +239,43 @@ class Message(ApMessage):
}) })
# todo: remove when fixed in aputils # misc properties
@property @property
def object_id(self) -> str: def domain(self):
try: return urlparse(self.id).hostname
return self["object"]["id"]
except (KeyError, TypeError):
return self["object"] # actor properties
@property
def shared_inbox(self):
return self.get('endpoints', {}).get('sharedInbox', self.inbox)
# activity properties
@property
def actorid(self):
if isinstance(self.actor, dict):
return self.actor.id
return self.actor
@property
def objectid(self):
if isinstance(self.object, dict):
return self.object.id
return self.object
@property
def signer(self):
return aputils.Signer.new_from_actor(self)
class Response(AiohttpResponse): class Response(AiohttpResponse):
@classmethod @classmethod
def new(cls: Type[Response], def new(cls, body='', status=200, headers=None, ctype='text'):
body: Optional[str | bytes | dict] = '',
status: Optional[int] = 200,
headers: Optional[dict[str, str]] = None,
ctype: Optional[str] = 'text') -> Response:
kwargs = { kwargs = {
'status': status, 'status': status,
'headers': headers, 'headers': headers,
@ -288,11 +295,7 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_error(cls: Type[Response], def new_error(cls, status, body, ctype='text'):
status: int,
body: str | bytes | dict,
ctype: str = 'text') -> Response:
if ctype == 'json': if ctype == 'json':
body = json.dumps({'status': status, 'error': body}) body = json.dumps({'status': status, 'error': body})
@ -300,157 +303,38 @@ class Response(AiohttpResponse):
@property @property
def location(self) -> str: def location(self):
return self.headers.get('Location') return self.headers.get('Location')
@location.setter @location.setter
def location(self, value: str) -> None: def location(self, value):
self.headers['Location'] = value self.headers['Location'] = value
class View(AbstractView): class View(AiohttpView):
def __init__(self, request: AiohttpRequest): async def _iter(self):
AbstractView.__init__(self, request) if self.request.method not in METHODS:
self._raise_allowed_methods()
self.signature: Signature = None method = getattr(self, self.request.method.lower(), None)
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
if method is None:
self._raise_allowed_methods()
def __await__(self) -> Generator[Response]: return await method(**self.request.match_info)
method = self.request.method.upper()
if method not in METHODS:
raise HTTPMethodNotAllowed(method, self.allowed_methods)
if not (handler := self.handlers.get(method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__()
@cached_property
def allowed_methods(self) -> tuple[str]:
return tuple(self.handlers.keys())
@cached_property
def handlers(self) -> dict[str, Coroutine]:
data = {}
for method in METHODS:
try:
data[method] = getattr(self, method.lower())
except AttributeError:
continue
return data
# app components
@property
def app(self) -> Application:
return self.request.app
@property @property
def client(self) -> HttpClient: def app(self):
return self.app.client return self._request.app
@property @property
def config(self) -> RelayConfig: def config(self):
return self.app.config return self.app.config
@property @property
def database(self) -> RelayDatabase: def database(self):
return self.app.database return self.app.database
# todo: move to views.ActorView
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
self.instance = self.database.get_inbox(self.actor.inbox)
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")

View file

@ -1,21 +1,17 @@
from __future__ import annotations import asyncio
import logging
import typing
from cachetools import LRUCache from cachetools import LRUCache
from uuid import uuid4
from . import logger as logging
from .misc import Message from .misc import Message
if typing.TYPE_CHECKING:
from .misc import View
cache = LRUCache(1024) cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool: def person_check(actor, software):
# pleroma and akkoma may use Person for the actor type for some reason ## pleroma and akkoma may use Person for the actor type for some reason
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False return False
@ -23,85 +19,86 @@ def person_check(actor: str, software: str) -> bool:
if actor.type != 'Application': if actor.type != 'Application':
return True return True
return False
async def handle_relay(request):
async def handle_relay(view: View) -> None: if request.message.objectid in cache:
if view.message.object_id in cache: logging.verbose(f'already relayed {request.message.objectid}')
logging.verbose('already relayed %s', view.message.object_id)
return return
message = Message.new_announce(view.config.host, view.message.object_id) message = Message.new_announce(
cache[view.message.object_id] = message.id host = request.config.host,
logging.debug('>> relay: %s', message) object = request.message.objectid
)
inboxes = view.database.distill_inboxes(view.message) cache[request.message.objectid] = message.id
logging.debug(f'>> relay: {message}')
inboxes = request.database.distill_inboxes(request.message)
for inbox in inboxes: for inbox in inboxes:
view.app.push_message(inbox, message) request.app.push_message(inbox, message)
async def handle_forward(view: View) -> None: async def handle_forward(request):
if view.message.id in cache: if request.message.id in cache:
logging.verbose('already forwarded %s', view.message.id) logging.verbose(f'already forwarded {request.message.id}')
return return
message = Message.new_announce(view.config.host, view.message) message = Message.new_announce(
cache[view.message.id] = message.id host = request.config.host,
logging.debug('>> forward: %s', message) object = request.message
)
inboxes = view.database.distill_inboxes(view.message) cache[request.message.id] = message.id
logging.debug(f'>> forward: {message}')
inboxes = request.database.distill_inboxes(request.message)
for inbox in inboxes: for inbox in inboxes:
view.app.push_message(inbox, message) request.app.push_message(inbox, message)
async def handle_follow(view: View) -> None: async def handle_follow(request):
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) nodeinfo = await request.app.client.fetch_nodeinfo(request.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None software = nodeinfo.sw_name if nodeinfo else None
## reject if software used by actor is banned ## reject if software used by actor is banned
if view.config.is_banned_software(software): if request.config.is_banned_software(software):
view.app.push_message( request.app.push_message(
view.actor.shared_inbox, request.actor.shared_inbox,
Message.new_response( Message.new_response(
host = view.config.host, host = request.config.host,
actor = view.actor.id, actor = request.actor.id,
followid = view.message.id, followid = request.message.id,
accept = False accept = False
) )
) )
return logging.verbose( return logging.verbose(f'Rejected follow from actor for using specific software: actor={request.actor.id}, software={software}')
'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id,
software
)
## reject if the actor is not an instance actor ## reject if the actor is not an instance actor
if person_check(view.actor, software): if person_check(request.actor, software):
view.app.push_message( request.app.push_message(
view.actor.shared_inbox, request.actor.shared_inbox,
Message.new_response( Message.new_response(
host = view.config.host, host = request.config.host,
actor = view.actor.id, actor = request.actor.id,
followid = view.message.id, followid = request.message.id,
accept = False accept = False
) )
) )
logging.verbose('Non-application actor tried to follow: %s', view.actor.id) return logging.verbose(f'Non-application actor tried to follow: {request.actor.id}')
return
view.database.add_inbox(view.actor.shared_inbox, view.message.id, software) request.database.add_inbox(request.actor.shared_inbox, request.message.id, software)
view.database.save() request.database.save()
view.app.push_message( request.app.push_message(
view.actor.shared_inbox, request.actor.shared_inbox,
Message.new_response( Message.new_response(
host = view.config.host, host = request.config.host,
actor = view.actor.id, actor = request.actor.id,
followid = view.message.id, followid = request.message.id,
accept = True accept = True
) )
) )
@ -109,37 +106,31 @@ async def handle_follow(view: View) -> None:
# Are Akkoma and Pleroma the only two that expect a follow back? # Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now # Ignoring only Mastodon for now
if software != 'mastodon': if software != 'mastodon':
view.app.push_message( request.app.push_message(
view.actor.shared_inbox, request.actor.shared_inbox,
Message.new_follow( Message.new_follow(
host = view.config.host, host = request.config.host,
actor = view.actor.id actor = request.actor.id
) )
) )
async def handle_undo(view: View) -> None: async def handle_undo(request):
## If the object is not a Follow, forward it ## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow': if request.message.object.type != 'Follow':
return await handle_forward(view) return await handle_forward(request)
if not view.database.del_inbox(view.actor.domain, view.message.object['id']):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
view.message.object['id']
)
if not request.database.del_inbox(request.actor.domain, request.message.id):
return return
view.database.save() request.database.save()
view.app.push_message( request.app.push_message(
view.actor.shared_inbox, request.actor.shared_inbox,
Message.new_unfollow( Message.new_unfollow(
host = view.config.host, host = request.config.host,
actor = view.actor.id, actor = request.actor.id,
follow = view.message follow = request.message
) )
) )
@ -154,22 +145,16 @@ processors = {
} }
async def run_processor(view: View) -> None: async def run_processor(request):
if view.message.type not in processors: if request.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
view.message.type,
view.actor.id
)
return return
if view.instance and not view.instance.get('software'): if request.instance and not request.instance.get('software'):
nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain']) nodeinfo = await request.app.client.fetch_nodeinfo(request.instance['domain'])
if nodeinfo: if nodeinfo:
view.instance['software'] = nodeinfo.sw_name request.instance['software'] = nodeinfo.sw_name
view.database.save() request.database.save()
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) logging.verbose(f'New "{request.message.type}" from actor: {request.actor.id}')
await processors[view.message.type](view) return await processors[request.message.type](request)

View file

@ -1,170 +1,191 @@
from __future__ import annotations import aputils
import asyncio import asyncio
import logging
import subprocess import subprocess
import typing import traceback
from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo
from pathlib import Path from pathlib import Path
from . import __version__ from . import __version__, misc
from . import logger as logging from .misc import DotDict, Message, Response
from .misc import Message, Response, View
from .processors import run_processor from .processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from typing import Callable
routes = []
VIEWS = [] version = __version__
VERSION = __version__
HOME_TEMPLATE = """
<html><head>
<title>ActivityPub Relay at {host}</title>
<style>
p {{ color: #FFFFFF; font-family: monospace, arial; font-size: 100%; }}
body {{ background-color: #000000; }}
a {{ color: #26F; }}
a:visited {{ color: #46C; }}
a:hover {{ color: #8AF; }}
</style>
</head>
<body>
<p>This is an Activity Relay for fediverse instances.</p>
<p>{note}</p>
<p>
You may subscribe to this relay with the address:
<a href="https://{host}/actor">https://{host}/actor</a>
</p>
<p>
To host your own relay, you may download the code at this address:
<a href="https://git.pleroma.social/pleroma/relay">
https://git.pleroma.social/pleroma/relay
</a>
</p>
<br><p>List of {count} registered instances:<br>{targets}</p>
</body></html>
"""
if Path(__file__).parent.parent.joinpath('.git').exists(): if Path(__file__).parent.parent.joinpath('.git').exists():
try: try:
commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii') commit_label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode('ascii')
VERSION = f'{__version__} {commit_label}' version = f'{__version__} {commit_label}'
except Exception: except:
pass pass
def register_route(*paths: str) -> Callable: def register_route(method, path):
def wrapper(view: View) -> View: def wrapper(func):
for path in paths: routes.append([method, path, func])
VIEWS.append([path, view]) return func
return View
return wrapper return wrapper
# pylint: disable=unused-argument @register_route('GET', '/')
async def home(request):
targets = '<br>'.join(request.database.hostnames)
note = request.config.note
count = len(request.database.hostnames)
host = request.config.host
@register_route('/') text = f"""
class HomeView(View): <html><head>
async def get(self, request: Request) -> Response: <title>ActivityPub Relay at {host}</title>
text = HOME_TEMPLATE.format( <style>
host = self.config.host, p {{ color: #FFFFFF; font-family: monospace, arial; font-size: 100%; }}
note = self.config.note, body {{ background-color: #000000; }}
count = len(self.database.hostnames), a {{ color: #26F; }}
targets = '<br>'.join(self.database.hostnames) a:visited {{ color: #46C; }}
) a:hover {{ color: #8AF; }}
</style>
</head>
<body>
<p>This is an Activity Relay for fediverse instances.</p>
<p>{note}</p>
<p>You may subscribe to this relay with the address: <a href="https://{host}/actor">https://{host}/actor</a></p>
<p>To host your own relay, you may download the code at this address: <a href="https://git.pleroma.social/pleroma/relay">https://git.pleroma.social/pleroma/relay</a></p>
<br><p>List of {count} registered instances:<br>{targets}</p>
</body></html>"""
return Response.new(text, ctype='html') return Response.new(text, ctype='html')
@register_route('GET', '/inbox')
@register_route('/actor', '/inbox') @register_route('GET', '/actor')
class ActorView(View): async def actor(request):
async def get(self, request: Request) -> Response:
data = Message.new_actor( data = Message.new_actor(
host = self.config.host, host = request.config.host,
pubkey = self.database.signer.pubkey pubkey = request.database.signer.pubkey
) )
return Response.new(data, ctype='activity') return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response: @register_route('POST', '/inbox')
response = await self.get_post_data() @register_route('POST', '/actor')
async def inbox(request):
config = request.config
database = request.database
if response is not None: ## reject if missing signature header
return response if not request.signature:
logging.verbose('Actor missing signature header')
raise HTTPUnauthorized(body='missing signature')
try:
request['message'] = await request.json(loads=Message.new_from_json)
## reject if there is no message
if not request.message:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
## reject if there is no actor in the message
if 'actor' not in request.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
except:
## this code should hopefully never get called
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
request['actor'] = await request.app.client.get(request.signature.keyid, sign_headers=True)
## reject if actor is empty
if not request.actor:
## ld signatures aren't handled atm, so just ignore it
if request['message'].type == 'Delete':
logging.verbose(f'Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {request.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
request['instance'] = request.database.get_inbox(request['actor'].inbox)
## reject if the actor isn't whitelisted while the whiltelist is enabled ## reject if the actor isn't whitelisted while the whiltelist is enabled
if self.config.whitelist_enabled and not self.config.is_whitelisted(self.actor.domain): if config.whitelist_enabled and not config.is_whitelisted(request.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) logging.verbose(f'Rejected actor for not being in the whitelist: {request.actor.id}')
return Response.new_error(403, 'access denied', 'json') return Response.new_error(403, 'access denied', 'json')
## reject if actor is banned ## reject if actor is banned
if self.config.is_banned(self.actor.domain): if request.config.is_banned(request.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id) logging.verbose(f'Ignored request from banned actor: {actor.id}')
return Response.new_error(403, 'access denied', 'json') return Response.new_error(403, 'access denied', 'json')
## reject if activity type isn't 'Follow' and the actor isn't following ## reject if the signature is invalid
if self.message.type != 'Follow' and not self.database.get_inbox(self.actor.domain): try:
logging.verbose( await request.actor.signer.validate_aiohttp_request(request)
'Rejected actor for trying to post while not following: %s',
self.actor.id
)
except aputils.SignatureValidationError as e:
logging.verbose(f'signature validation failed for: {actor.id}')
logging.debug(str(e))
return Response.new_error(401, str(e), 'json')
## reject if activity type isn't 'Follow' and the actor isn't following
if request.message.type != 'Follow' and not database.get_inbox(request.actor.domain):
logging.verbose(f'Rejected actor for trying to post while not following: {request.actor.id}')
return Response.new_error(401, 'access denied', 'json') return Response.new_error(401, 'access denied', 'json')
logging.debug('>> payload %s', self.message.to_json(4)) logging.debug(f">> payload {request.message.to_json(4)}")
asyncio.ensure_future(run_processor(self)) asyncio.ensure_future(run_processor(request))
return Response.new(status = 202) return Response.new(status=202)
@register_route('/.well-known/webfinger') @register_route('GET', '/.well-known/webfinger')
class WebfingerView(View): async def webfinger(request):
async def get(self, request: Request) -> Response:
try: try:
subject = request.query['resource'] subject = request.query['resource']
except KeyError: except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json') return Response.new_error(400, 'missing \'resource\' query key', 'json')
if subject != f'acct:relay@{self.config.host}': if subject != f'acct:relay@{request.config.host}':
return Response.new_error(404, 'user not found', 'json') return Response.new_error(404, 'user not found', 'json')
data = Webfinger.new( data = aputils.Webfinger.new(
handle = 'relay', handle = 'relay',
domain = self.config.host, domain = request.config.host,
actor = self.config.actor actor = request.config.actor
) )
return Response.new(data, ctype = 'json') return Response.new(data, ctype='json')
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') @register_route('GET', '/nodeinfo/{version:\d.\d\.json}')
class NodeinfoView(View): async def nodeinfo(request):
async def get(self, request: Request, niversion: str) -> Response: niversion = request.match_info['version'][:3]
data = {
'name': 'activityrelay', data = dict(
'version': VERSION, name = 'activityrelay',
'protocols': ['activitypub'], version = version,
'open_regs': not self.config.whitelist_enabled, protocols = ['activitypub'],
'users': 1, open_regs = not request.config.whitelist_enabled,
'metadata': {'peers': self.database.hostnames} users = 1,
} metadata = {'peers': request.database.hostnames}
)
if niversion == '2.1': if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay' data['repo'] = 'https://git.pleroma.social/pleroma/relay'
return Response.new(Nodeinfo.new(**data), ctype = 'json') return Response.new(aputils.Nodeinfo.new(**data), ctype='json')
@register_route('/.well-known/nodeinfo') @register_route('GET', '/.well-known/nodeinfo')
class WellknownNodeinfoView(View): async def nodeinfo_wellknown(request):
async def get(self, request: Request) -> Response: data = aputils.WellKnownNodeinfo.new_template(request.config.host)
data = WellKnownNodeinfo.new_template(self.config.host) return Response.new(data, ctype='json')
return Response.new(data, ctype = 'json')

View file

@ -1,5 +1,5 @@
aiohttp>=3.9.1 aiohttp>=3.8.0
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.4.tar.gz
cachetools>=5.2.0 cachetools>=5.2.0
click>=8.1.2 click>=8.1.2
pyyaml>=6.0 pyyaml>=6.0

View file

@ -23,21 +23,13 @@ project_urls =
zip_safe = False zip_safe = False
packages = find: packages = find:
install_requires = file: requirements.txt install_requires = file: requirements.txt
python_requires = >=3.8 python_requires = >=3.7
[options.extras_require] [options.extras_require]
dev = dev =
flake8 = 3.1.0 pyinstaller >= 5.6.0
pyinstaller = 6.3.0
pylint = 3.0
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
activityrelay = relay.manage:main activityrelay = relay.manage:main
[flake8]
extend-ignore = ANN101,ANN204,E128,E251,E261,E266,E301,E303,W191
extend-exclude = docs, test*.py
max-line-length = 100
indent-size = 4