add annotations and fix linter warnings

This commit is contained in:
Izalia Mae 2024-01-14 14:13:06 -05:00
parent fdef2f708c
commit 9bf45a54d1
10 changed files with 391 additions and 304 deletions

View file

@ -1,3 +1 @@
__version__ = '0.2.4'
from . import logger

View file

@ -1,9 +1,11 @@
from __future__ import annotations
import asyncio
import os
import queue
import signal
import threading
import traceback
import typing
from aiohttp import web
from datetime import datetime, timedelta
@ -12,20 +14,29 @@ from . import logger as logging
from .config import RelayConfig
from .database import RelayDatabase
from .http_client import HttpClient
from .misc import DotDict, check_open_port, set_app
from .misc import check_open_port
from .views import VIEWS
if typing.TYPE_CHECKING:
from typing import Any
from .misc import Message
# pylint: disable=unsubscriptable-object
class Application(web.Application):
def __init__(self, cfgpath):
def __init__(self, cfgpath: str):
web.Application.__init__(self)
self['starttime'] = None
self['workers'] = []
self['last_worker'] = 0
self['start_time'] = None
self['running'] = False
self['config'] = RelayConfig(cfgpath)
if not self['config'].load():
self['config'].save()
if not self.config.load():
self.config.save()
if self.config.is_docker:
self.config.update({
@ -34,13 +45,8 @@ class Application(web.Application):
'port': 8080
})
self['workers'] = []
self['last_worker'] = 0
set_app(self)
self['database'] = RelayDatabase(self['config'])
self['database'].load()
self['database'] = RelayDatabase(self.config)
self.database.load()
self['client'] = HttpClient(
database = self.database,
@ -54,33 +60,34 @@ class Application(web.Application):
@property
def client(self):
def client(self) -> HttpClient:
return self['client']
@property
def config(self):
def config(self) -> RelayConfig:
return self['config']
@property
def database(self):
def database(self) -> RelayDatabase:
return self['database']
@property
def uptime(self):
if not self['starttime']:
def uptime(self) -> timedelta:
if not self['start_time']:
return timedelta(seconds=0)
uptime = datetime.now() - self['starttime']
uptime = datetime.now() - self['start_time']
return timedelta(seconds=uptime.seconds)
def push_message(self, inbox, message):
def push_message(self, inbox: str, message: Message) -> None:
if self.config.workers <= 0:
return asyncio.ensure_future(self.client.post(inbox, message))
asyncio.ensure_future(self.client.post(inbox, message))
return
worker = self['workers'][self['last_worker']]
worker.queue.put((inbox, message))
@ -91,8 +98,8 @@ class Application(web.Application):
self['last_worker'] = 0
def set_signal_handler(self, startup):
for sig in {'SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'}:
def set_signal_handler(self, startup: bool) -> None:
for sig in ('SIGHUP', 'SIGINT', 'SIGQUIT', 'SIGTERM'):
try:
signal.signal(getattr(signal, sig), self.stop if startup else signal.SIG_DFL)
@ -101,9 +108,10 @@ class Application(web.Application):
pass
def run(self):
def run(self) -> None:
if not check_open_port(self.config.listen, self.config.port):
return logging.error('A server is already running on port %i', self.config.port)
logging.error('A server is already running on port %i', self.config.port)
return
for view in VIEWS:
self.router.add_view(*view)
@ -118,17 +126,17 @@ class Application(web.Application):
asyncio.run(self.handle_run())
def stop(self, *_):
def stop(self, *_: Any) -> None:
self['running'] = False
async def handle_run(self):
async def handle_run(self) -> None:
self['running'] = True
self.set_signal_handler(True)
if self.config.workers > 0:
for i in range(self.config.workers):
for _ in range(self.config.workers):
worker = PushWorker(self)
worker.start()
@ -137,14 +145,15 @@ class Application(web.Application):
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup()
site = web.TCPSite(runner,
site = web.TCPSite(
runner,
host = self.config.listen,
port = self.config.port,
reuse_address = True
)
await site.start()
self['starttime'] = datetime.now()
self['start_time'] = datetime.now()
while self['running']:
await asyncio.sleep(0.25)
@ -152,23 +161,24 @@ class Application(web.Application):
await site.stop()
await self.client.close()
self['starttime'] = None
self['start_time'] = None
self['running'] = False
self['workers'].clear()
class PushWorker(threading.Thread):
def __init__(self, app):
def __init__(self, app: Application):
threading.Thread.__init__(self)
self.app = app
self.queue = queue.Queue()
self.client = None
def run(self):
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self):
async def handle_queue(self) -> None:
self.client = HttpClient(
database = self.app.database,
limit = self.app.config.push_limit,

View file

@ -1,5 +1,7 @@
import json
from __future__ import annotations
import os
import typing
import yaml
from functools import cached_property
@ -8,6 +10,10 @@ from urllib.parse import urlparse
from .misc import DotDict, boolean
if typing.TYPE_CHECKING:
from typing import Any
from .database import RelayDatabase
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
@ -25,17 +31,19 @@ APKEYS = [
class RelayConfig(DotDict):
def __init__(self, path):
__slots__ = ('path', )
def __init__(self, path: str | Path):
DotDict.__init__(self, {})
if self.is_docker:
path = '/data/config.yaml'
self._path = Path(path).expanduser()
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
@ -51,36 +59,31 @@ class RelayConfig(DotDict):
@property
def db(self):
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
@property
def path(self):
return self._path
@property
def actor(self):
def actor(self) -> str:
return f'https://{self.host}/actor'
@property
def inbox(self):
def inbox(self) -> str:
return f'https://{self.host}/inbox'
@property
def keyid(self):
def keyid(self) -> str:
return f'{self.actor}#main-key'
@cached_property
def is_docker(self):
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self):
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
@ -99,7 +102,7 @@ class RelayConfig(DotDict):
})
def ban_instance(self, instance):
def ban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
@ -110,7 +113,7 @@ class RelayConfig(DotDict):
return True
def unban_instance(self, instance):
def unban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
@ -118,11 +121,11 @@ class RelayConfig(DotDict):
self.blocked_instances.remove(instance)
return True
except:
except ValueError:
return False
def ban_software(self, software):
def ban_software(self, software: str) -> bool:
if self.is_banned_software(software):
return False
@ -130,16 +133,16 @@ class RelayConfig(DotDict):
return True
def unban_software(self, software):
def unban_software(self, software: str) -> bool:
try:
self.blocked_software.remove(software)
return True
except:
except ValueError:
return False
def add_whitelist(self, instance):
def add_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
@ -150,7 +153,7 @@ class RelayConfig(DotDict):
return True
def del_whitelist(self, instance):
def del_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
@ -158,32 +161,32 @@ class RelayConfig(DotDict):
self.whitelist.remove(instance)
return True
except:
except ValueError:
return False
def is_banned(self, instance):
def is_banned(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.blocked_instances
def is_banned_software(self, software):
def is_banned_software(self, software: str) -> bool:
if not software:
return False
return software.lower() in self.blocked_software
def is_whitelisted(self, instance):
def is_whitelisted(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.whitelist
def load(self):
def load(self) -> bool:
self.reset()
options = {}
@ -195,7 +198,7 @@ class RelayConfig(DotDict):
pass
try:
with open(self.path) as fd:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
@ -214,7 +217,7 @@ class RelayConfig(DotDict):
continue
elif key not in self:
if key not in self:
continue
self[key] = value
@ -225,7 +228,7 @@ class RelayConfig(DotDict):
return True
def save(self):
def save(self) -> None:
config = {
# just turning config.db into a string is good enough for now
'db': str(self.db),
@ -239,7 +242,5 @@ class RelayConfig(DotDict):
'ap': {key: self[key] for key in APKEYS}
}
with open(self._path, 'w') as fd:
with self._path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys=False)
return config

View file

@ -1,15 +1,21 @@
from __future__ import annotations
import aputils
import asyncio
import json
import traceback
import typing
from urllib.parse import urlparse
from . import logger as logging
if typing.TYPE_CHECKING:
from typing import Iterator, Optional
from .config import RelayConfig
from .misc import Message
class RelayDatabase(dict):
def __init__(self, config):
def __init__(self, config: RelayConfig):
dict.__init__(self, {
'relay-list': {},
'private-key': None,
@ -22,16 +28,16 @@ class RelayDatabase(dict):
@property
def hostnames(self):
def hostnames(self) -> tuple[str]:
return tuple(self['relay-list'].keys())
@property
def inboxes(self):
def inboxes(self) -> tuple[dict[str, str]]:
return tuple(data['inbox'] for data in self['relay-list'].values())
def load(self):
def load(self) -> bool:
new_db = True
try:
@ -41,7 +47,7 @@ class RelayDatabase(dict):
self['version'] = data.get('version', None)
self['private-key'] = data.get('private-key')
if self['version'] == None:
if self['version'] is None:
self['version'] = 1
if 'actorKeys' in data:
@ -59,7 +65,9 @@ class RelayDatabase(dict):
self['relay-list'] = data.get('relay-list', {})
for domain, instance in self['relay-list'].items():
if self.config.is_banned(domain) or (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
if self.config.is_banned(domain) or \
(self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
self.del_inbox(domain)
continue
@ -87,25 +95,29 @@ class RelayDatabase(dict):
return not new_db
def save(self):
with self.config.db.open('w') as fd:
def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd:
json.dump(self, fd, indent=4)
def get_inbox(self, domain, fail=False):
def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
inbox = self['relay-list'].get(domain)
if inbox:
if (inbox := self['relay-list'].get(domain)):
return inbox
if fail:
raise KeyError(domain)
return None
def add_inbox(self,
inbox: str,
followid: Optional[str] = None,
software: Optional[str] = None) -> dict[str, str]:
def add_inbox(self, inbox, followid=None, software=None):
assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname
instance = self.get_inbox(domain)
@ -130,7 +142,11 @@ class RelayDatabase(dict):
return self['relay-list'][domain]
def del_inbox(self, domain, followid=None, fail=False):
def del_inbox(self,
domain: str,
followid: Optional[str] = None,
fail: Optional[bool] = False) -> bool:
data = self.get_inbox(domain, fail=False)
if not data:
@ -151,7 +167,7 @@ class RelayDatabase(dict):
return False
def get_request(self, domain, fail=True):
def get_request(self, domain: str, fail: bool = True) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
@ -162,8 +178,10 @@ class RelayDatabase(dict):
if fail:
raise e
return None
def add_request(self, actor, inbox, followid):
def add_request(self, actor: str, inbox: str, followid: str) -> None:
domain = urlparse(inbox).hostname
try:
@ -180,14 +198,14 @@ class RelayDatabase(dict):
}
def del_request(self, domain):
def del_request(self, domain: str) -> None:
if domain.startswith('http'):
domain = urlparse(inbox).hostname
domain = urlparse(domain).hostname
del self['follow-requests'][domain]
def distill_inboxes(self, message):
def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = {
message.domain,
urlparse(message.objectid).netloc

View file

@ -1,21 +1,23 @@
from __future__ import annotations
import traceback
import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils import Nodeinfo, WellKnownNodeinfo
from datetime import datetime
from aputils.objects import Nodeinfo, WellKnownNodeinfo
from cachetools import LRUCache
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
from . import __version__
from . import logger as logging
from .misc import (
MIMETYPES,
DotDict,
Message
)
from .misc import MIMETYPES, Message
if typing.TYPE_CHECKING:
from typing import Any, Callable, Optional
from .database import RelayDatabase
HEADERS = {
@ -24,40 +26,31 @@ HEADERS = {
}
class Cache(LRUCache):
def set_maxsize(self, value):
self.__maxsize = int(value)
class HttpClient:
def __init__(self, database, limit=100, timeout=10, cache_size=1024):
def __init__(self,
database: RelayDatabase,
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
self.database = database
self.cache = Cache(cache_size)
self.cfg = {'limit': limit, 'timeout': timeout}
self.cache = LRUCache(cache_size)
self.limit = limit
self.timeout = timeout
self._conn = None
self._session = None
async def __aenter__(self):
async def __aenter__(self) -> HttpClient:
await self.open()
return self
async def __aexit__(self, *_):
async def __aexit__(self, *_: Any) -> None:
await self.close()
@property
def limit(self):
return self.cfg['limit']
@property
def timeout(self):
return self.cfg['timeout']
async def open(self):
async def open(self) -> None:
if self._session:
return
@ -74,7 +67,7 @@ class HttpClient:
)
async def close(self):
async def close(self) -> None:
if not self._session:
return
@ -85,11 +78,19 @@ class HttpClient:
self._session = None
async def get(self, url, sign_headers=False, loads=None, force=False):
async def get(self, # pylint: disable=too-many-branches
url: str,
sign_headers: Optional[bool] = False,
loads: Optional[Callable] = None,
force: Optional[bool] = False) -> Message | dict | None:
await self.open()
try: url, _ = url.split('#', 1)
except: pass
try:
url, _ = url.split('#', 1)
except ValueError:
pass
if not force and url in self.cache:
return self.cache[url]
@ -105,26 +106,26 @@ class HttpClient:
async with self._session.get(url, headers=headers) as resp:
## Not expecting a response with 202s, so just return
if resp.status == 202:
return
return None
elif resp.status != 200:
if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(await resp.read())
return
return None
if loads:
message = await resp.json(loads=loads)
elif resp.content_type == MIMETYPES['activity']:
message = await resp.json(loads=Message.parse)
message = await resp.json(loads = Message.parse)
elif resp.content_type == MIMETYPES['json']:
message = await resp.json(loads=DotDict.parse)
message = await resp.json()
else:
# todo: raise TypeError or something
logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type)
return logging.debug('Response: %s', await resp.read())
logging.debug('Response: %s', await resp.read())
return None
logging.debug('%s >> resp %s', url, message.to_json(4))
@ -140,11 +141,13 @@ class HttpClient:
except (AsyncTimeoutError, ClientConnectionError):
logging.verbose('Failed to connect to %s', urlparse(url).netloc)
except Exception as e:
except Exception:
traceback.print_exc()
return None
async def post(self, url, message):
async def post(self, url: str, message: Message) -> None:
await self.open()
instance = self.database.get_inbox(url)
@ -165,10 +168,12 @@ class HttpClient:
async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
## Not expecting a response, so just return
if resp.status in {200, 202}:
return logging.verbose('Successfully sent "%s" to %s', message.type, url)
logging.verbose('Successfully sent "%s" to %s', message.type, url)
return
logging.verbose('Received error when pushing to %s: %i', url, resp.status)
return logging.verbose(await resp.read()) # change this to debug
logging.debug(await resp.read())
return
except ClientSSLError:
logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
@ -177,12 +182,11 @@ class HttpClient:
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
## prevent workers from being brought down
except Exception as e:
except Exception:
traceback.print_exc()
## Additional methods ##
async def fetch_nodeinfo(self, domain):
async def fetch_nodeinfo(self, domain: str) -> Nodeinfo | None:
nodeinfo_url = None
wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo',
@ -191,7 +195,7 @@ class HttpClient:
if not wk_nodeinfo:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return False
return None
for version in ['20', '21']:
try:
@ -202,21 +206,21 @@ class HttpClient:
if not nodeinfo_url:
logging.verbose('Failed to fetch nodeinfo url for %s', domain)
return False
return None
return await self.get(nodeinfo_url, loads=Nodeinfo.parse) or False
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
async def get(database, *args, **kwargs):
async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient(database) as client:
return await client.get(*args, **kwargs)
async def post(database, *args, **kwargs):
async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None:
async with HttpClient(database) as client:
return await client.post(*args, **kwargs)
async def fetch_nodeinfo(database, *args, **kwargs):
async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient(database) as client:
return await client.fetch_nodeinfo(*args, **kwargs)

View file

@ -1,10 +1,16 @@
from __future__ import annotations
import logging
import os
import typing
from pathlib import Path
if typing.TYPE_CHECKING:
from typing import Any, Callable
LOG_LEVELS = {
LOG_LEVELS: dict[str, int] = {
'DEBUG': logging.DEBUG,
'VERBOSE': 15,
'INFO': logging.INFO,
@ -14,14 +20,14 @@ LOG_LEVELS = {
}
debug = logging.debug
info = logging.info
warning = logging.warning
error = logging.error
critical = logging.critical
debug: Callable = logging.debug
info: Callable = logging.info
warning: Callable = logging.warning
error: Callable = logging.error
critical: Callable = logging.critical
def verbose(message, *args, **kwargs):
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']):
return

View file

@ -1,7 +1,10 @@
from __future__ import annotations
import Crypto
import asyncio
import click
import platform
import typing
from urllib.parse import urlparse
@ -10,6 +13,12 @@ from . import http_client as http
from .application import Application
from .config import RELAY_SOFTWARE
if typing.TYPE_CHECKING:
from typing import Any
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
app = None
CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@ -19,7 +28,7 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config')
@click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context
def cli(ctx, config):
def cli(ctx: click.Context, config: str) -> None:
global app
app = Application(config)
@ -32,11 +41,14 @@ def cli(ctx, config):
@cli.command('setup')
def cli_setup():
def cli_setup() -> None:
'Generate a new config'
while True:
app.config.host = click.prompt('What domain will the relay be hosted on?', default=app.config.host)
app.config.host = click.prompt(
'What domain will the relay be hosted on?',
default = app.config.host
)
if not app.config.host.endswith('example.com'):
break
@ -44,10 +56,18 @@ def cli_setup():
click.echo('The domain must not be example.com')
if not app.config.is_docker:
app.config.listen = click.prompt('Which address should the relay listen on?', default=app.config.listen)
app.config.listen = click.prompt(
'Which address should the relay listen on?',
default = app.config.listen
)
while True:
app.config.port = click.prompt('What TCP port should the relay listen on?', default=app.config.port, type=int)
app.config.port = click.prompt(
'What TCP port should the relay listen on?',
default = app.config.port,
type = int
)
break
app.config.save()
@ -57,39 +77,47 @@ def cli_setup():
@cli.command('run')
def cli_run():
def cli_run() -> None:
'Run the relay'
if app.config.host.endswith('example.com'):
return click.echo('Relay is not set up. Please edit your relay config or run "activityrelay setup".')
click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
)
return
vers_split = platform.python_version().split('.')
pip_command = 'pip3 uninstall pycrypto && pip3 install pycryptodome'
if Crypto.__version__ == '2.6.1':
if int(vers_split[1]) > 7:
click.echo('Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome before running again. Exiting...')
return click.echo(pip_command)
click.echo(
'Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome ' +
'before running again. Exiting...'
)
else:
click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome')
return click.echo(pip_command)
click.echo(pip_command)
return
click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome')
click.echo(pip_command)
return
if not misc.check_open_port(app.config.listen, app.config.port):
return click.echo(f'Error: A server is already running on port {app.config.port}')
click.echo(f'Error: A server is already running on port {app.config.port}')
return
app.run()
# todo: add config default command for resetting config key
@cli.group('config')
def cli_config():
def cli_config() -> None:
'Manage the relay config'
pass
@cli_config.command('list')
def cli_config_list():
def cli_config_list() -> None:
'List the current relay config'
click.echo('Relay Config:')
@ -103,7 +131,7 @@ def cli_config_list():
@cli_config.command('set')
@click.argument('key')
@click.argument('value')
def cli_config_set(key, value):
def cli_config_set(key: str, value: Any) -> None:
'Set a config value'
app.config[key] = value
@ -113,13 +141,12 @@ def cli_config_set(key, value):
@cli.group('inbox')
def cli_inbox():
def cli_inbox() -> None:
'Manage the inboxes in the database'
pass
@cli_inbox.command('list')
def cli_inbox_list():
def cli_inbox_list() -> None:
'List the connected instances or relays'
click.echo('Connected to the following instances or relays:')
@ -130,11 +157,12 @@ def cli_inbox_list():
@cli_inbox.command('follow')
@click.argument('actor')
def cli_inbox_follow(actor):
def cli_inbox_follow(actor: str) -> None:
'Follow an actor (Relay must be running)'
if app.config.is_banned(actor):
return click.echo(f'Error: Refusing to follow banned actor: {actor}')
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
if not actor.startswith('http'):
domain = actor
@ -151,7 +179,8 @@ def cli_inbox_follow(actor):
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
if not actor_data:
return click.echo(f'Failed to fetch actor: {actor}')
click.echo(f'Failed to fetch actor: {actor}')
return
inbox = actor_data.shared_inbox
@ -166,7 +195,7 @@ def cli_inbox_follow(actor):
@cli_inbox.command('unfollow')
@click.argument('actor')
def cli_inbox_unfollow(actor):
def cli_inbox_unfollow(actor: str) -> None:
'Unfollow an actor (Relay must be running)'
if not actor.startswith('http'):
@ -204,17 +233,19 @@ def cli_inbox_unfollow(actor):
@cli_inbox.command('add')
@click.argument('inbox')
def cli_inbox_add(inbox):
def cli_inbox_add(inbox: str) -> None:
'Add an inbox to the database'
if not inbox.startswith('http'):
inbox = f'https://{inbox}/inbox'
if app.config.is_banned(inbox):
return click.echo(f'Error: Refusing to add banned inbox: {inbox}')
click.echo(f'Error: Refusing to add banned inbox: {inbox}')
return
if app.database.get_inbox(inbox):
return click.echo(f'Error: Inbox already in database: {inbox}')
click.echo(f'Error: Inbox already in database: {inbox}')
return
app.database.add_inbox(inbox)
app.database.save()
@ -224,7 +255,7 @@ def cli_inbox_add(inbox):
@cli_inbox.command('remove')
@click.argument('inbox')
def cli_inbox_remove(inbox):
def cli_inbox_remove(inbox: str) -> None:
'Remove an inbox from the database'
try:
@ -241,13 +272,12 @@ def cli_inbox_remove(inbox):
@cli.group('instance')
def cli_instance():
def cli_instance() -> None:
'Manage instance bans'
pass
@cli_instance.command('list')
def cli_instance_list():
def cli_instance_list() -> None:
'List all banned instances'
click.echo('Banned instances or relays:')
@ -258,7 +288,7 @@ def cli_instance_list():
@cli_instance.command('ban')
@click.argument('target')
def cli_instance_ban(target):
def cli_instance_ban(target: str) -> None:
'Ban an instance and remove the associated inbox if it exists'
if target.startswith('http'):
@ -278,7 +308,7 @@ def cli_instance_ban(target):
@cli_instance.command('unban')
@click.argument('target')
def cli_instance_unban(target):
def cli_instance_unban(target: str) -> None:
'Unban an instance'
if app.config.unban_instance(target):
@ -291,13 +321,12 @@ def cli_instance_unban(target):
@cli.group('software')
def cli_software():
def cli_software() -> None:
'Manage banned software'
pass
@cli_software.command('list')
def cli_software_list():
def cli_software_list() -> None:
'List all banned software'
click.echo('Banned software:')
@ -307,19 +336,21 @@ def cli_software_list():
@cli_software.command('ban')
@click.option('--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default=False,
help='Treat NAME like a domain and try to fet the software name from nodeinfo'
@click.option(
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.argument('name')
def cli_software_ban(name, fetch_nodeinfo):
def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays'
if name == 'RELAYS':
for name in RELAY_SOFTWARE:
app.config.ban_software(name)
for software in RELAY_SOFTWARE:
app.config.ban_software(software)
app.config.save()
return click.echo('Banned all relay software')
click.echo('Banned all relay software')
return
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
@ -331,25 +362,28 @@ def cli_software_ban(name, fetch_nodeinfo):
if app.config.ban_software(name):
app.config.save()
return click.echo(f'Banned software: {name}')
click.echo(f'Banned software: {name}')
return
click.echo(f'Software already banned: {name}')
@cli_software.command('unban')
@click.option('--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default=False,
help='Treat NAME like a domain and try to fet the software name from nodeinfo'
@click.option(
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.argument('name')
def cli_software_unban(name, fetch_nodeinfo):
def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays'
if name == 'RELAYS':
for name in RELAY_SOFTWARE:
app.config.unban_software(name)
for software in RELAY_SOFTWARE:
app.config.unban_software(software)
app.config.save()
return click.echo('Unbanned all relay software')
click.echo('Unbanned all relay software')
return
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
@ -361,19 +395,19 @@ def cli_software_unban(name, fetch_nodeinfo):
if app.config.unban_software(name):
app.config.save()
return click.echo(f'Unbanned software: {name}')
click.echo(f'Unbanned software: {name}')
return
click.echo(f'Software wasn\'t banned: {name}')
@cli.group('whitelist')
def cli_whitelist():
def cli_whitelist() -> None:
'Manage the instance whitelist'
pass
@cli_whitelist.command('list')
def cli_whitelist_list():
def cli_whitelist_list() -> None:
'List all the instances in the whitelist'
click.echo('Current whitelisted domains')
@ -384,11 +418,12 @@ def cli_whitelist_list():
@cli_whitelist.command('add')
@click.argument('instance')
def cli_whitelist_add(instance):
def cli_whitelist_add(instance: str) -> None:
'Add an instance to the whitelist'
if not app.config.add_whitelist(instance):
return click.echo(f'Instance already in the whitelist: {instance}')
click.echo(f'Instance already in the whitelist: {instance}')
return
app.config.save()
click.echo(f'Instance added to the whitelist: {instance}')
@ -396,11 +431,12 @@ def cli_whitelist_add(instance):
@cli_whitelist.command('remove')
@click.argument('instance')
def cli_whitelist_remove(instance):
def cli_whitelist_remove(instance: str) -> None:
'Remove an instance from the whitelist'
if not app.config.del_whitelist(instance):
return click.echo(f'Instance not in the whitelist: {instance}')
click.echo(f'Instance not in the whitelist: {instance}')
return
app.config.save()
@ -412,14 +448,15 @@ def cli_whitelist_remove(instance):
@cli_whitelist.command('import')
def cli_whitelist_import():
def cli_whitelist_import() -> None:
'Add all current inboxes to the whitelist'
for domain in app.database.hostnames:
cli_whitelist_add.callback(domain)
def main():
def main() -> None:
# pylint: disable=no-value-for-parameter
cli(prog_name='relay')

View file

@ -12,20 +12,21 @@ from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.message import Message as ApMessage
from datetime import datetime
from functools import cached_property
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
from uuid import uuid4
from . import logger as logging
if typing.TYPE_CHECKING:
from typing import Coroutine, Generator
from typing import Any, Coroutine, Generator, Optional, Type
from aputils.signer import Signer
from .application import Application
from .config import RelayConfig
from .database import RelayDatabase
from .http_client import HttpClient
app = None
MIMETYPES = {
'activity': 'application/activity+json',
'html': 'text/html',
@ -39,94 +40,87 @@ NODEINFO_NS = {
}
def set_app(new_app):
global app
app = new_app
def boolean(value):
def boolean(value: Any) -> bool:
if isinstance(value, str):
if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']:
return True
elif value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']:
if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']:
return False
else:
raise TypeError(f'Cannot parse string "{value}" as a boolean')
raise TypeError(f'Cannot parse string "{value}" as a boolean')
elif isinstance(value, int):
if isinstance(value, int):
if value == 1:
return True
elif value == 0:
if value == 0:
return False
else:
raise ValueError('Integer value must be 1 or 0')
raise ValueError('Integer value must be 1 or 0')
elif value == None:
if value is None:
return False
try:
return value.__bool__()
except AttributeError:
raise TypeError(f'Cannot convert object of type "{clsname(value)}"')
return bool(value)
def check_open_port(host, port):
def check_open_port(host: str, port: int) -> bool:
if host == '0.0.0.0':
host = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
return s.connect_ex((host , port)) != 0
return s.connect_ex((host, port)) != 0
except socket.error as e:
except socket.error:
return False
class DotDict(dict):
def __init__(self, _data, **kwargs):
def __init__(self, _data: dict[str, Any], **kwargs: Any):
dict.__init__(self)
self.update(_data, **kwargs)
def __getattr__(self, k):
def __getattr__(self, key: str) -> str:
try:
return self[k]
return self[key]
except KeyError:
raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
def __setattr__(self, k, v):
if k.startswith('_'):
super().__setattr__(k, v)
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith('_'):
super().__setattr__(key, value)
else:
self[k] = v
self[key] = value
def __setitem__(self, k, v):
if type(v) == dict:
v = DotDict(v)
def __setitem__(self, key: str, value: Any) -> None:
if type(value) is dict: # pylint: disable=unidiomatic-typecheck
value = DotDict(value)
super().__setitem__(k, v)
super().__setitem__(key, value)
def __delattr__(self, k):
def __delattr__(self, key: str) -> None:
try:
dict.__delitem__(self, k)
dict.__delitem__(self, key)
except KeyError:
raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
@classmethod
def new_from_json(cls, data):
def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]:
if not data:
raise JSONDecodeError('Empty body', data, 1)
@ -134,11 +128,11 @@ class DotDict(dict):
return cls(json.loads(data))
except ValueError:
raise JSONDecodeError('Invalid body', data, 1)
raise JSONDecodeError('Invalid body', data, 1) from None
@classmethod
def new_from_signature(cls, sig):
def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]:
data = cls({})
for chunk in sig.strip().split(','):
@ -153,11 +147,11 @@ class DotDict(dict):
return data
def to_json(self, indent=None):
def to_json(self, indent: Optional[int | str] = None) -> str:
return json.dumps(self, indent=indent)
def update(self, _data, **kwargs):
def update(self, _data: dict[str, Any], **kwargs: Any) -> None:
if isinstance(_data, dict):
for key, value in _data.items():
self[key] = value
@ -172,7 +166,11 @@ class DotDict(dict):
class Message(ApMessage):
@classmethod
def new_actor(cls, host, pubkey, description=None):
def new_actor(cls: Type[Message], # pylint: disable=arguments-differ
host: str,
pubkey: str,
description: Optional[str] = None) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/actor',
@ -196,19 +194,19 @@ class Message(ApMessage):
@classmethod
def new_announce(cls, host, object):
def new_announce(cls: Type[Message], host: str, obj: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
'type': 'Announce',
'to': [f'https://{host}/followers'],
'actor': f'https://{host}/actor',
'object': object
'object': obj
})
@classmethod
def new_follow(cls, host, actor):
def new_follow(cls: Type[Message], host: str, actor: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow',
@ -220,7 +218,7 @@ class Message(ApMessage):
@classmethod
def new_unfollow(cls, host, actor, follow):
def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
@ -232,7 +230,12 @@ class Message(ApMessage):
@classmethod
def new_response(cls, host, actor, followid, accept):
def new_response(cls: Type[Message],
host: str,
actor: str,
followid: str,
accept: bool) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
@ -250,7 +253,12 @@ class Message(ApMessage):
class Response(AiohttpResponse):
@classmethod
def new(cls, body='', status=200, headers=None, ctype='text'):
def new(cls: Type[Response],
body: Optional[str | bytes | dict] = '',
status: Optional[int] = 200,
headers: Optional[dict[str, str]] = None,
ctype: Optional[str] = 'text') -> Response:
kwargs = {
'status': status,
'headers': headers,
@ -270,7 +278,11 @@ class Response(AiohttpResponse):
@classmethod
def new_error(cls, status, body, ctype='text'):
def new_error(cls: Type[Response],
status: int,
body: str | bytes | dict,
ctype: str = 'text') -> Response:
if ctype == 'json':
body = json.dumps({'status': status, 'error': body})
@ -278,12 +290,12 @@ class Response(AiohttpResponse):
@property
def location(self):
def location(self) -> str:
return self.headers.get('Location')
@location.setter
def location(self, value):
def location(self, value: str) -> None:
self.headers['Location'] = value
@ -295,6 +307,7 @@ class View(AbstractView):
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
def __await__(self) -> Generator[Response]:
@ -335,7 +348,7 @@ class View(AbstractView):
@property
def client(self) -> Client:
def client(self) -> HttpClient:
return self.app.client
@ -377,9 +390,9 @@ class View(AbstractView):
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
## ld signatures aren't handled atm, so just ignore it
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose(f'Instance sent a delete which cannot be handled')
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
@ -429,5 +442,5 @@ class View(AbstractView):
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.actor.signer._validate_signature(headers, self.signature):
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")

View file

@ -1,10 +1,8 @@
from __future__ import annotations
import asyncio
import typing
from cachetools import LRUCache
from uuid import uuid4
from . import logger as logging
from .misc import Message
@ -16,8 +14,8 @@ if typing.TYPE_CHECKING:
cache = LRUCache(1024)
def person_check(actor, software):
## pleroma and akkoma may use Person for the actor type for some reason
def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False
@ -25,21 +23,19 @@ def person_check(actor, software):
if actor.type != 'Application':
return True
return False
async def handle_relay(view: View) -> None:
if view.message.objectid in cache:
logging.verbose('already relayed %s', view.message.objectid)
return
message = Message.new_announce(
host = view.config.host,
object = view.message.objectid
)
message = Message.new_announce(view.config.host, view.message.objectid)
cache[view.message.objectid] = message.id
logging.debug('>> relay: %s', message)
inboxes = view.database.distill_inboxes(message)
inboxes = view.database.distill_inboxes(view.message)
for inbox in inboxes:
view.app.push_message(inbox, message)
@ -50,15 +46,11 @@ async def handle_forward(view: View) -> None:
logging.verbose('already forwarded %s', view.message.id)
return
message = Message.new_announce(
host = view.config.host,
object = view.message
)
message = Message.new_announce(view.config.host, view.message)
cache[view.message.id] = message.id
logging.debug('>> forward: %s', message)
inboxes = view.database.distill_inboxes(message.message)
inboxes = view.database.distill_inboxes(view.message)
for inbox in inboxes:
view.app.push_message(inbox, message)
@ -162,7 +154,7 @@ processors = {
}
async def run_processor(view: View):
async def run_processor(view: View) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@ -180,4 +172,4 @@ async def run_processor(view: View):
view.database.save()
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
return await processors[view.message.type](view)
await processors[view.message.type](view)

View file

@ -1,15 +1,13 @@
from __future__ import annotations
import aputils
import asyncio
import subprocess
import traceback
import typing
from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo
from pathlib import Path
from . import __version__, misc
from . import __version__
from . import logger as logging
from .misc import Message, Response, View
from .processors import run_processor
@ -35,8 +33,16 @@ HOME_TEMPLATE = """
<body>
<p>This is an Activity Relay for fediverse instances.</p>
<p>{note}</p>
<p>You may subscribe to this relay with the address: <a href="https://{host}/actor">https://{host}/actor</a></p>
<p>To host your own relay, you may download the code at this address: <a href="https://git.pleroma.social/pleroma/relay">https://git.pleroma.social/pleroma/relay</a></p>
<p>
You may subscribe to this relay with the address:
<a href="https://{host}/actor">https://{host}/actor</a>
</p>
<p>
To host your own relay, you may download the code at this address:
<a href="https://git.pleroma.social/pleroma/relay">
https://git.pleroma.social/pleroma/relay
</a>
</p>
<br><p>List of {count} registered instances:<br>{targets}</p>
</body></html>
"""
@ -60,6 +66,8 @@ def register_route(*paths: str) -> Callable:
return wrapper
# pylint: disable=unused-argument
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
@ -140,14 +148,14 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response:
data = dict(
name = 'activityrelay',
version = VERSION,
protocols = ['activitypub'],
open_regs = not self.config.whitelist_enabled,
users = 1,
metadata = {'peers': self.database.hostnames}
)
data = {
'name': 'activityrelay',
'version': VERSION,
'protocols': ['activitypub'],
'open_regs': not self.config.whitelist_enabled,
'users': 1,
'metadata': {'peers': self.database.hostnames}
}
if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay'