use postgresql/sqlite for database backend

This commit is contained in:
Izalia Mae 2024-01-22 05:32:16 -05:00
parent 9808674b98
commit 5f6aef1871
16 changed files with 1574 additions and 813 deletions

10
.gitignore vendored
View file

@ -94,9 +94,7 @@ ENV/
# Rope project settings # Rope project settings
.ropeproject .ropeproject
viera.yaml # config and database
viera.jsonld *.yaml
*.jsonld
# config file *.sqlite3
relay.yaml
relay.jsonld

View file

@ -1,43 +1,35 @@
# this is the path that the object graph will get dumped to (in JSON-LD format), # [string] Domain the relay will be hosted on
# you probably shouldn't change it, but you can if you want. domain: relay.example.com
db: relay.jsonld
# Listener # [string] Address the relay will listen on
listen: 0.0.0.0 listen: 0.0.0.0
# [integer] Port the relay will listen on
port: 8080 port: 8080
# Note # [integer] Number of push workers to start (will get removed in a future update)
note: "Make a note about your instance here." workers: 8
# Number of worker threads to start. If 0, use asyncio futures instead of threads. # [string] Database backend to use. Valid values: sqlite, postgres
workers: 0 database_type: sqlite
# Maximum number of inbox posts to do at once # [string] Path to the sqlite database file if the sqlite backend is in use
# If workers is set to 1 or above, this is the max for each worker sqlite_path: relay.sqlite3
push_limit: 512
# The amount of json objects to cache from GET requests # settings for the postgresql backend
json_cache: 1024 postgres:
ap: # [string] hostname or unix socket to connect to
# This is used for generating activitypub messages, as well as instructions for host: /var/run/postgresql
# linking AP identities. It should be an SSL-enabled domain reachable by https.
host: 'relay.example.com'
blocked_instances: # [integer] port of the server
- 'bad-instance.example.com' port: 5432
- 'another-bad-instance.example.com'
whitelist_enabled: false # [string] username to use when logging into the server (default is the current system username)
user: null
whitelist: # [string] password of the user
- 'good-instance.example.com' pass: null
- 'another.good-instance.example.com'
# uncomment the lines below to prevent certain activitypub software from posting # [string] name of the database to use
# to the relay (all known relays by default). this uses the software name in nodeinfo name: activityrelay
#blocked_software:
#- 'activityrelay'
#- 'aoderelay'
#- 'social.seattle.wa.us-relay'
#- 'unciarelay'

View file

@ -8,52 +8,41 @@ import traceback
import typing import typing
from aiohttp import web from aiohttp import web
from aputils.signer import Signer
from datetime import datetime, timedelta from datetime import datetime, timedelta
from . import logger as logging from . import logger as logging
from .config import RelayConfig from .config import Config
from .database import RelayDatabase from .database import get_database
from .http_client import HttpClient from .http_client import HttpClient
from .misc import check_open_port from .misc import check_open_port
from .views import VIEWS from .views import VIEWS
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Database
from typing import Any from typing import Any
from .misc import Message from .misc import Message
# pylint: disable=unsubscriptable-object # pylint: disable=unsubscriptable-object
class Application(web.Application): class Application(web.Application):
DEFAULT: Application = None
def __init__(self, cfgpath: str): def __init__(self, cfgpath: str):
web.Application.__init__(self) web.Application.__init__(self)
Application.DEFAULT = self
self['signer'] = None
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['workers'] = [] self['workers'] = []
self['last_worker'] = 0 self['last_worker'] = 0
self['start_time'] = None self['start_time'] = None
self['running'] = False self['running'] = False
self['config'] = RelayConfig(cfgpath)
if not self.config.load():
self.config.save()
if self.config.is_docker:
self.config.update({
'db': '/data/relay.jsonld',
'listen': '0.0.0.0',
'port': 8080
})
self['database'] = RelayDatabase(self.config)
self.database.load()
self['client'] = HttpClient(
database = self.database,
limit = self.config.push_limit,
timeout = self.config.timeout,
cache_size = self.config.json_cache
)
for path, view in VIEWS: for path, view in VIEWS:
self.router.add_view(path, view) self.router.add_view(path, view)
@ -65,15 +54,29 @@ class Application(web.Application):
@property @property
def config(self) -> RelayConfig: def config(self) -> Config:
return self['config'] return self['config']
@property @property
def database(self) -> RelayDatabase: def database(self) -> Database:
return self['database'] return self['database']
@property
def signer(self) -> Signer:
return self['signer']
@signer.setter
def signer(self, value: Signer | str) -> None:
if isinstance(value, Signer):
self['signer'] = value
return
self['signer'] = Signer(value, self.config.keyid)
@property @property
def uptime(self) -> timedelta: def uptime(self) -> timedelta:
if not self['start_time']: if not self['start_time']:
@ -118,7 +121,7 @@ class Application(web.Application):
logging.info( logging.info(
'Starting webserver at %s (%s:%i)', 'Starting webserver at %s (%s:%i)',
self.config.host, self.config.domain,
self.config.listen, self.config.listen,
self.config.port self.config.port
) )
@ -179,12 +182,7 @@ class PushWorker(threading.Thread):
async def handle_queue(self) -> None: async def handle_queue(self) -> None:
self.client = HttpClient( self.client = HttpClient()
database = self.app.database,
limit = self.app.config.push_limit,
timeout = self.app.config.timeout,
cache_size = self.app.config.json_cache
)
while self.app['running']: while self.app['running']:
try: try:

View file

@ -1,17 +1,128 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
import typing import typing
import yaml
from aputils.signer import Signer from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
from . import logger as logging from . import logger as logging
from .misc import Message, boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Iterator, Optional from typing import Any, Iterator, Optional
from .config import RelayConfig
from .misc import Message
# pylint: disable=duplicate-code
class RelayConfig(dict):
def __init__(self, path: str):
dict.__init__(self, {})
if self.is_docker:
path = '/data/config.yaml'
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
@property
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self["host"]}/actor'
@property
def inbox(self) -> str:
return f'https://{self["host"]}/inbox'
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def load(self) -> None:
self.reset()
options = {}
try:
options['Loader'] = yaml.FullLoader
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
return
if not config:
return
for key, value in config.items():
if key in ['ap']:
for k, v in value.items():
if k not in self:
continue
self[k] = v
continue
if key not in self:
continue
self[key] = value
class RelayDatabase(dict): class RelayDatabase(dict):
@ -37,9 +148,7 @@ class RelayDatabase(dict):
return tuple(data['inbox'] for data in self['relay-list'].values()) return tuple(data['inbox'] for data in self['relay-list'].values())
def load(self) -> bool: def load(self) -> None:
new_db = True
try: try:
with self.config.db.open() as fd: with self.config.db.open() as fd:
data = json.load(fd) data = json.load(fd)
@ -65,17 +174,9 @@ class RelayDatabase(dict):
self['relay-list'] = data.get('relay-list', {}) self['relay-list'] = data.get('relay-list', {})
for domain, instance in self['relay-list'].items(): for domain, instance in self['relay-list'].items():
if self.config.is_banned(domain) or \
(self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
self.del_inbox(domain)
continue
if not instance.get('domain'): if not instance.get('domain'):
instance['domain'] = domain instance['domain'] = domain
new_db = False
except FileNotFoundError: except FileNotFoundError:
pass pass
@ -83,17 +184,6 @@ class RelayDatabase(dict):
if self.config.db.stat().st_size > 0: if self.config.db.stat().st_size > 0:
raise e from None raise e from None
if not self['private-key']:
logging.info('No actor keys present, generating 4096-bit RSA keypair.')
self.signer = Signer.new(self.config.keyid, size=4096)
self['private-key'] = self.signer.export()
else:
self.signer = Signer(self['private-key'], self.config.keyid)
self.save()
return not new_db
def save(self) -> None: def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd: with self.config.db.open('w', encoding = 'UTF-8') as fd:

View file

@ -1,76 +1,73 @@
from __future__ import annotations from __future__ import annotations
import getpass
import os import os
import typing import typing
import yaml import yaml
from functools import cached_property
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse
from .misc import DotDict, boolean from .misc import IS_DOCKER
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any from typing import Any, Optional
from .database import RelayDatabase
RELAY_SOFTWARE = [ DEFAULTS: dict[str, Any] = {
'activityrelay', # https://git.pleroma.social/pleroma/relay 'listen': '0.0.0.0',
'aoderelay', # https://git.asonix.dog/asonix/relay 'port': 8080,
'feditools-relay' # https://git.ptzo.gdn/feditools/relay 'domain': 'relay.example.com',
] 'workers': len(os.sched_getaffinity(0)),
'db_type': 'sqlite',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay'
}
APKEYS = [ if IS_DOCKER:
'host', DEFAULTS['sq_path'] = '/data/relay.jsonld'
'whitelist_enabled',
'blocked_software',
'blocked_instances',
'whitelist'
]
class RelayConfig(DotDict): class Config:
__slots__ = ('path', ) def __init__(self, path: str, load: Optional[bool] = False):
self.path = Path(path).expanduser().resolve()
def __init__(self, path: str | Path): self.listen = None
DotDict.__init__(self, {}) self.port = None
self.domain = None
self.workers = None
self.db_type = None
self.sq_path = None
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
if self.is_docker: if load:
path = '/data/config.yaml' try:
self.load()
self._path = Path(path).expanduser().resolve() except FileNotFoundError:
self.reset() self.save()
def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
@property @property
def db(self) -> RelayDatabase: def sqlite_path(self) -> Path:
return Path(self['db']).expanduser().resolve() return Path(self.sq_path).expanduser().resolve()
@property @property
def actor(self) -> str: def actor(self) -> str:
return f'https://{self.host}/actor' return f'https://{self.domain}/actor'
@property @property
def inbox(self) -> str: def inbox(self) -> str:
return f'https://{self.host}/inbox' return f'https://{self.domain}/inbox'
@property @property
@ -78,115 +75,7 @@ class RelayConfig(DotDict):
return f'{self.actor}#main-key' return f'{self.actor}#main-key'
@cached_property def load(self) -> None:
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def ban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_banned(instance):
return False
self.blocked_instances.append(instance)
return True
def unban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.blocked_instances.remove(instance)
return True
except ValueError:
return False
def ban_software(self, software: str) -> bool:
if self.is_banned_software(software):
return False
self.blocked_software.append(software)
return True
def unban_software(self, software: str) -> bool:
try:
self.blocked_software.remove(software)
return True
except ValueError:
return False
def add_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_whitelisted(instance):
return False
self.whitelist.append(instance)
return True
def del_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.whitelist.remove(instance)
return True
except ValueError:
return False
def is_banned(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.blocked_instances
def is_banned_software(self, software: str) -> bool:
if not software:
return False
return software.lower() in self.blocked_software
def is_whitelisted(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.whitelist
def load(self) -> bool:
self.reset() self.reset()
options = {} options = {}
@ -197,50 +86,69 @@ class RelayConfig(DotDict):
except AttributeError: except AttributeError:
pass pass
try: with self.path.open('r', encoding = 'UTF-8') as fd:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options) config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {})
except FileNotFoundError:
return False
if not config: if not config:
return False raise ValueError('Config is empty')
for key, value in config.items(): if IS_DOCKER:
if key in ['ap']: self.listen = '0.0.0.0'
for k, v in value.items(): self.port = 8080
if k not in self: self.sq_path = '/data/relay.jsonld'
else:
self.set('listen', config.get('listen', DEFAULTS['listen']))
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
for key in DEFAULTS:
if not key.startswith('pg'):
continue continue
self[k] = v try:
self.set(key, pgcfg[key[3:]])
except KeyError:
continue continue
if key not in self:
continue
self[key] = value def reset(self) -> None:
for key, value in DEFAULTS.items():
if self.host.endswith('example.com'): setattr(self, key, value)
return False
return True
def save(self) -> None: def save(self) -> None:
self.path.parent.mkdir(exist_ok = True, parents = True)
config = { config = {
# just turning config.db into a string is good enough for now
'db': str(self.db),
'listen': self.listen, 'listen': self.listen,
'port': self.port, 'port': self.port,
'note': self.note, 'domain': self.domain,
'push_limit': self.push_limit, 'database_type': self.db_type,
'workers': self.workers, 'sqlite_path': self.sq_path,
'json_cache': self.json_cache, 'postgres': {
'timeout': self.timeout, 'host': self.pg_host,
'ap': {key: self[key] for key in APKEYS} 'port': self.pg_port,
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
}
} }
with self._path.open('w', encoding = 'utf-8') as fd: with self.path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys = False) yaml.dump(config, fd, sort_keys = False)
def set(self, key: str, value: Any) -> None:
if key not in DEFAULTS:
raise KeyError(key)
if key in ('port', 'pg_port', 'workers') and not isinstance(value, int):
value = int(value)
setattr(self, key, value)

View file

@ -0,0 +1,63 @@
from __future__ import annotations
import tinysql
import typing
from importlib.resources import files as pkgfiles
from .config import get_default_value
from .connection import Connection
from .schema import VERSIONS, migrate_0
from .. import logger as logging
if typing.TYPE_CHECKING:
from typing import Optional
from .config import Config
def get_database(config: Config, migrate: Optional[bool] = True) -> tinysql.Database:
if config.db_type == "sqlite":
db = tinysql.Database.sqlite(config.sq_path, connection_class = Connection)
elif config.db_type == "postgres":
db = tinysql.Database.postgres(
config.pg_name,
config.pg_host,
config.pg_port,
config.pg_user,
config.pg_pass,
connection_class = Connection
)
db.load_prepared_statements(pkgfiles("relay").joinpath("database", "statements.sql"))
if not migrate:
return db
with db.connection() as conn:
if 'config' not in conn.get_tables():
logging.info("Creating database tables")
migrate_0(conn)
return db
schema_ver = conn.get_config('schema-version')
if schema_ver < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS:
if schema_ver < ver:
conn.begin()
func(conn)
conn.put_config('schema-version', ver)
conn.commit()
if (privkey := conn.get_config('private-key')):
conn.app.signer = privkey
logging.set_level(conn.get_config('log-level'))
return db

44
relay/database/config.py Normal file
View file

@ -0,0 +1,44 @@
from __future__ import annotations
import typing
from .. import logger as logging
from ..misc import boolean
if typing.TYPE_CHECKING:
from typing import Any, Callable
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
'schema-version': ('int', 20240119),
'log-level': ('loglevel', logging.LogLevel.INFO),
'note': ('str', 'Make a note about your instance here.'),
'private-key': ('str', None),
'whitelist-enabled': ('bool', False)
}
# serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'loglevel': (lambda x: x.name, logging.LogLevel.parse)
}
def get_default_value(key: str) -> Any:
return CONFIG_DEFAULTS[key][1]
def get_default_type(key: str) -> str:
return CONFIG_DEFAULTS[key][0]
def serialize(key: str, value: Any) -> str:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][0](value)
def deserialize(key: str, value: str) -> Any:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][1](value)

View file

@ -0,0 +1,295 @@
from __future__ import annotations
import tinysql
import typing
from datetime import datetime, timezone
from urllib.parse import urlparse
from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize
from .. import logger as logging
from ..misc import get_app
if typing.TYPE_CHECKING:
from tinysql import Cursor, Row
from typing import Any, Iterator, Optional
from .application import Application
from ..misc import Message
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'activity-relay', # https://github.com/yukimochi/Activity-Relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
]
class Connection(tinysql.Connection):
@property
def app(self) -> Application:
return get_app()
def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for inbox in self.execute('SELECT * FROM inboxes'):
if inbox['domain'] not in src_domains:
yield inbox['inbox']
def exec_statement(self, name: str, params: Optional[dict[str, Any]] = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params)
def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
with self.exec_statement('get-config', {'key': key}) as cur:
if not (row := cur.one()):
return get_default_value(key)
if row['value']:
return deserialize(row['key'], row['value'])
return None
def get_config_all(self) -> dict[str, Any]:
with self.exec_statement('get-config-all') as cur:
db_config = {row['key']: row['value'] for row in cur}
config = {}
for key, data in CONFIG_DEFAULTS.items():
try:
config[key] = deserialize(key, db_config[key])
except KeyError:
if key == 'schema-version':
config[key] = 0
else:
config[key] = data[1]
return config
def put_config(self, key: str, value: Any) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
if key == 'private-key':
self.app.signer = value
elif key == 'log-level':
value = logging.LogLevel.parse(value)
logging.set_level(value)
params = {
'key': key,
'value': serialize(key, value) if value is not None else None,
'type': get_default_type(key)
}
with self.exec_statement('put-config', params):
return value
def get_inbox(self, value: str) -> Row:
with self.exec_statement('get-inbox', {'value': value}) as cur:
return cur.one()
def put_inbox(self,
domain: str,
inbox: str,
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> Row:
params = {
'domain': domain,
'inbox': inbox,
'actor': actor,
'followid': followid,
'software': software,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-inbox', params) as cur:
return cur.one()
def update_inbox(self,
inbox: str,
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> Row:
if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"')
data = {}
if actor:
data['actor'] = actor
if followid:
data['followid'] = followid
if software:
data['software'] = software
statement = tinysql.Update('inboxes', data, inbox = inbox)
with self.query(statement):
return self.get_inbox(inbox)
def del_inbox(self, value: str) -> bool:
with self.exec_statement('del-inbox', {'value': value}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_ban(self, domain: str) -> Row:
if domain.startswith('http'):
domain = urlparse(domain).netloc
with self.exec_statement('get-domain-ban', {'domain': domain}) as cur:
return cur.one()
def put_domain_ban(self,
domain: str,
reason: Optional[str] = None,
note: Optional[str] = None) -> Row:
params = {
'domain': domain,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-ban', params) as cur:
return cur.one()
def update_domain_ban(self,
domain: str,
reason: Optional[str] = None,
note: Optional[str] = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('domain_bans', params, domain = domain)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_domain_ban(domain)
def del_domain_ban(self, domain: str) -> bool:
with self.exec_statement('del-domain-ban', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_software_ban(self, name: str) -> Row:
with self.exec_statement('get-software-ban', {'name': name}) as cur:
return cur.one()
def put_software_ban(self,
name: str,
reason: Optional[str] = None,
note: Optional[str] = None) -> Row:
params = {
'name': name,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-software-ban', params) as cur:
return cur.one()
def update_software_ban(self,
name: str,
reason: Optional[str] = None,
note: Optional[str] = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('software_bans', params, name = name)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_software_ban(name)
def del_software_ban(self, name: str) -> bool:
with self.exec_statement('del-software-ban', {'name': name}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_whitelist(self, domain: str) -> Row:
with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one()
def put_domain_whitelist(self, domain: str) -> Row:
params = {
'domain': domain,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-whitelist', params) as cur:
return cur.one()
def del_domain_whitelist(self, domain: str) -> bool:
with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1

60
relay/database/schema.py Normal file
View file

@ -0,0 +1,60 @@
from __future__ import annotations
import typing
from tinysql import Column, Connection, Table
from .config import get_default_value
if typing.TYPE_CHECKING:
from typing import Callable
VERSIONS: list[Callable] = []
TABLES: list[Table] = [
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False),
Column('value', 'text'),
Column('type', 'text', default = 'str')
),
Table(
'inboxes',
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True),
Column('inbox', 'text', unique = True, nullable = False),
Column('followid', 'text'),
Column('software', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'whitelist',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('created', 'timestamp')
),
Table(
'instance_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'software_bans',
Column('name', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
)
]
def version(func: Callable) -> Callable:
ver = int(func.replace('migrate_', ''))
VERSIONS[ver] = func
return func
def migrate_0(conn: Connection) -> None:
conn.create_tables(TABLES)
conn.put_config('schema-version', get_default_value('schema-version'))

View file

@ -0,0 +1,79 @@
-- name: get-config
SELECT * FROM config WHERE key = :key
-- name: get-config-all
SELECT * FROM config
-- name: put-config
INSERT INTO config (key, value, type)
VALUES (:key, :value, :type)
ON CONFLICT (key) DO UPDATE SET value = :value
RETURNING *
-- name: del-config
DELETE FROM config
WHERE key = :key
-- name: get-inbox
SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value
-- name: put-inbox
INSERT INTO inboxes (domain, actor, inbox, followid, software, created)
VALUES (:domain, :actor, :inbox, :followid, :software, :created)
ON CONFLICT (domain) DO UPDATE SET followid = :followid
RETURNING *
-- name: del-inbox
DELETE FROM inboxes
WHERE domain = :value or inbox = :value or actor = :value
-- name: get-software-ban
SELECT * FROM software_bans WHERE name = :name
-- name: put-software-ban
INSERT INTO software_bans (name, reason, note, created)
VALUES (:name, :reason, :note, :created)
RETURNING *
-- name: del-software-ban
DELETE FROM software_bans
WHERE name = :name
-- name: get-domain-ban
SELECT * FROM domain_bans WHERE domain = :domain
-- name: put-domain-ban
INSERT INTO domain_bans (domain, reason, note, created)
VALUES (:domain, :reason, :note, :created)
RETURNING *
-- name: del-domain-ban
DELETE FROM domain_bans
WHERE domain = :domain
-- name: get-domain-whitelist
SELECT * FROM whitelist WHERE domain = :domain
-- name: put-domain-whitelist
INSERT INTO whitelist (domain, created)
VALUES (:domain, :created)
RETURNING *
-- name: del-domain-whitelist
DELETE FROM whitelist
WHERE domain = :domain

View file

@ -13,11 +13,10 @@ from urllib.parse import urlparse
from . import __version__ from . import __version__
from . import logger as logging from . import logger as logging
from .misc import MIMETYPES, Message from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from .database import RelayDatabase
HEADERS = { HEADERS = {
@ -28,12 +27,10 @@ HEADERS = {
class HttpClient: class HttpClient:
def __init__(self, def __init__(self,
database: RelayDatabase,
limit: Optional[int] = 100, limit: Optional[int] = 100,
timeout: Optional[int] = 10, timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024): cache_size: Optional[int] = 1024):
self.database = database
self.cache = LRUCache(cache_size) self.cache = LRUCache(cache_size)
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
@ -98,7 +95,7 @@ class HttpClient:
headers = {} headers = {}
if sign_headers: if sign_headers:
headers.update(self.database.signer.sign_headers('GET', url, algorithm='original')) get_app().signer.sign_headers('GET', url, algorithm = 'original')
try: try:
logging.debug('Fetching resource: %s', url) logging.debug('Fetching resource: %s', url)
@ -150,23 +147,24 @@ class HttpClient:
async def post(self, url: str, message: Message) -> None: async def post(self, url: str, message: Message) -> None:
await self.open() await self.open()
instance = self.database.get_inbox(url) with get_app().database.connection() as conn:
instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now ## Using the old algo by default is probably a better idea right now
if instance and instance.get('software') in {'mastodon'}: if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019' algorithm = 'hs2019'
else: else:
algorithm = 'original' algorithm = 'original'
headers = {'Content-Type': 'application/activity+json'} headers = {'Content-Type': 'application/activity+json'}
headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm)) headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
try: try:
logging.verbose('Sending "%s" to %s', message.type, url) logging.verbose('Sending "%s" to %s', message.type, url)
async with self._session.post(url, headers=headers, data=message.to_json()) as resp: async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
## Not expecting a response, so just return # Not expecting a response, so just return
if resp.status in {200, 202}: if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url) logging.verbose('Successfully sent "%s" to %s', message.type, url)
return return
@ -181,7 +179,7 @@ class HttpClient:
except (AsyncTimeoutError, ClientConnectionError): except (AsyncTimeoutError, ClientConnectionError):
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
## prevent workers from being brought down # prevent workers from being brought down
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
@ -211,16 +209,16 @@ class HttpClient:
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None: async def get(*args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient(database) as client: async with HttpClient() as client:
return await client.get(*args, **kwargs) return await client.get(*args, **kwargs)
async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None: async def post(*args: Any, **kwargs: Any) -> None:
async with HttpClient(database) as client: async with HttpClient() as client:
return await client.post(*args, **kwargs) return await client.post(*args, **kwargs)
async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None: async def fetch_nodeinfo(*args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient(database) as client: async with HttpClient() as client:
return await client.fetch_nodeinfo(*args, **kwargs) return await client.fetch_nodeinfo(*args, **kwargs)

View file

@ -6,22 +6,49 @@ import click
import platform import platform
import typing import typing
from aputils.signer import Signer
from pathlib import Path
from shutil import copyfile
from urllib.parse import urlparse from urllib.parse import urlparse
from . import misc, __version__ from . import __version__
from . import http_client as http from . import http_client as http
from . import logger as logging
from .application import Application from .application import Application
from .config import RELAY_SOFTWARE from .compat import RelayConfig, RelayDatabase
from .database import get_database
from .database.connection import RELAY_SOFTWARE
from .misc import IS_DOCKER, Message, check_open_port
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any from tinysql import Row
from typing import Any, Optional
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation # pylint: disable=unsubscriptable-object,unsupported-assignment-operation
app = None CONFIG_IGNORE = (
CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'} 'schema-version',
'private-key'
)
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True) @click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@ -29,11 +56,10 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@click.version_option(version=__version__, prog_name='ActivityRelay') @click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context @click.pass_context
def cli(ctx: click.Context, config: str) -> None: def cli(ctx: click.Context, config: str) -> None:
global app ctx.obj = Application(config)
app = Application(config)
if not ctx.invoked_subcommand: if not ctx.invoked_subcommand:
if app.config.host.endswith('example.com'): if ctx.obj.config.domain.endswith('example.com'):
cli_setup.callback() cli_setup.callback()
else: else:
@ -41,46 +67,92 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup') @cli.command('setup')
def cli_setup() -> None: @click.pass_context
def cli_setup(ctx: click.Context) -> None:
'Generate a new config' 'Generate a new config'
while True: while True:
app.config.host = click.prompt( ctx.obj.config.domain = click.prompt(
'What domain will the relay be hosted on?', 'What domain will the relay be hosted on?',
default = app.config.host default = ctx.obj.config.domain
) )
if not app.config.host.endswith('example.com'): if not ctx.obj.config.domain.endswith('example.com'):
break break
click.echo('The domain must not be example.com') click.echo('The domain must not end with "example.com"')
if not app.config.is_docker: if not IS_DOCKER:
app.config.listen = click.prompt( ctx.obj.config.listen = click.prompt(
'Which address should the relay listen on?', 'Which address should the relay listen on?',
default = app.config.listen default = ctx.obj.config.listen
) )
while True: ctx.obj.config.port = click.prompt(
app.config.port = click.prompt(
'What TCP port should the relay listen on?', 'What TCP port should the relay listen on?',
default = app.config.port, default = ctx.obj.config.port,
type = int type = int
) )
break ctx.obj.config.db_type = click.prompt(
'Which database backend will be used?',
default = ctx.obj.config.db_type,
type = click.Choice(['postgres', 'sqlite'], case_sensitive = False)
)
app.config.save() if ctx.obj.config.db_type == 'sqlite':
ctx.obj.config.sq_path = click.prompt(
'Where should the database be stored?',
default = ctx.obj.config.sq_path
)
if not app.config.is_docker and click.confirm('Relay all setup! Would you like to run it now?'): elif ctx.obj.config.db_type == 'postgres':
ctx.obj.config.pg_name = click.prompt(
'What is the name of the database?',
default = ctx.obj.config.pg_name
)
ctx.obj.config.pg_host = click.prompt(
'What IP address or hostname does the server listen on?',
default = ctx.obj.config.pg_host
)
ctx.obj.config.pg_port = click.prompt(
'What port does the server listen on?',
default = ctx.obj.config.pg_port,
type = int
)
ctx.obj.config.pg_user = click.prompt(
'Which user will authenticate with the server?',
default = ctx.obj.config.pg_user
)
ctx.obj.config.pg_pass = click.prompt(
'User password: ',
hide_input = True
) or None
ctx.obj.config.save()
config = {
'private-key': Signer.new('n/a').export()
}
with ctx.obj.database.connection() as conn:
for key, value in config.items():
conn.put_config(key, value)
if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback() cli_run.callback()
@cli.command('run') @cli.command('run')
def cli_run() -> None: @click.pass_context
def cli_run(ctx: click.Context) -> None:
'Run the relay' 'Run the relay'
if app.config.host.endswith('example.com'): if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
click.echo( click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".' 'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
) )
@ -104,25 +176,126 @@ def cli_run() -> None:
click.echo(pip_command) click.echo(pip_command)
return return
if not misc.check_open_port(app.config.listen, app.config.port): if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port):
click.echo(f'Error: A server is already running on port {app.config.port}') click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
return return
app.run() ctx.obj.run()
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the new config file')
@click.pass_context
def cli_convert(ctx: click.Context, old_config: str) -> None:
'Convert an old config and jsonld database to the new format.'
old_config = Path(old_config).expanduser().resolve()
backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml')
if str(old_config) == str(ctx.obj.config.path) and not backup.exists():
logging.info('Created backup config @ %s', backup)
copyfile(ctx.obj.config.path, backup)
config = RelayConfig(old_config)
config.load()
database = RelayDatabase(config)
database.load()
ctx.obj.config.set('listen', config['listen'])
ctx.obj.config.set('port', config['port'])
ctx.obj.config.set('workers', config['workers'])
ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3'))
with get_database(ctx.obj.config) as db:
with db.connection() as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar(
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
for inbox in inboxes:
if inbox['software'] in ('akkoma', 'pleroma'):
actor = f'https://{inbox["domain"]}/relay'
elif inbox['software'] == 'mastodon':
actor = f'https://{inbox["domain"]}/actor'
else:
actor = None
conn.put_inbox(
inbox['domain'],
inbox['inbox'],
actor = actor,
followid = inbox['followid'],
software = inbox['software']
)
with click.progressbar(
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
) as banned_software:
for software in banned_software:
conn.put_software_ban(
software,
reason = 'relay' if software in RELAY_SOFTWARE else None
)
with click.progressbar(
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
) as banned_software:
for domain in banned_software:
conn.put_domain_ban(domain)
with click.progressbar(
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
) as whitelist:
for instance in whitelist:
conn.put_domain_whitelist(instance)
click.echo('Finished converting old config and database :3')
@cli.command('edit-config')
@click.option('--editor', '-e', help = 'Text editor to use')
@click.pass_context
def cli_editconfig(ctx: click.Context, editor: str) -> None:
'Edit the config file'
click.edit(
editor = editor,
filename = str(ctx.obj.config.path)
)
@cli.group('config') @cli.group('config')
def cli_config() -> None: def cli_config() -> None:
'Manage the relay config' 'Manage the relay settings stored in the database'
@cli_config.command('list') @cli_config.command('list')
def cli_config_list() -> None: @click.pass_context
def cli_config_list(ctx: click.Context) -> None:
'List the current relay config' 'List the current relay config'
click.echo('Relay Config:') click.echo('Relay Config:')
for key, value in app.config.items(): with ctx.obj.database.connection() as conn:
for key, value in conn.get_config_all().items():
if key not in CONFIG_IGNORE: if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20) key = f'{key}:'.ljust(20)
click.echo(f'- {key} {value}') click.echo(f'- {key} {value}')
@ -131,13 +304,14 @@ def cli_config_list() -> None:
@cli_config.command('set') @cli_config.command('set')
@click.argument('key') @click.argument('key')
@click.argument('value') @click.argument('value')
def cli_config_set(key: str, value: Any) -> None: @click.pass_context
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value' 'Set a config value'
app.config[key] = value with ctx.obj.database.connection() as conn:
app.config.save() new_value = conn.put_config(key, value)
print(f'{key}: {app.config[key]}') print(f'{key}: {repr(new_value)}')
@cli.group('inbox') @cli.group('inbox')
@ -146,37 +320,36 @@ def cli_inbox() -> None:
@cli_inbox.command('list') @cli_inbox.command('list')
def cli_inbox_list() -> None: @click.pass_context
def cli_inbox_list(ctx: click.Context) -> None:
'List the connected instances or relays' 'List the connected instances or relays'
click.echo('Connected to the following instances or relays:') click.echo('Connected to the following instances or relays:')
for inbox in app.database.inboxes: with ctx.obj.database.connection() as conn:
click.echo(f'- {inbox}') for inbox in conn.execute('SELECT * FROM inboxes'):
click.echo(f'- {inbox["inbox"]}')
@cli_inbox.command('follow') @cli_inbox.command('follow')
@click.argument('actor') @click.argument('actor')
def cli_inbox_follow(actor: str) -> None: @click.pass_context
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)' 'Follow an actor (Relay must be running)'
if app.config.is_banned(actor): with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
if not actor.startswith('http'): if (inbox_data := conn.get_inbox(actor)):
domain = actor
actor = f'https://{actor}/actor'
else:
domain = urlparse(actor).hostname
try:
inbox_data = app.database['relay-list'][domain]
inbox = inbox_data['inbox'] inbox = inbox_data['inbox']
except KeyError: else:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True)) if not actor.startswith('http'):
actor = f'https://{actor}/actor'
actor_data = asyncio.run(http.get(actor, sign_headers = True))
if not actor_data: if not actor_data:
click.echo(f'Failed to fetch actor: {actor}') click.echo(f'Failed to fetch actor: {actor}')
@ -184,90 +357,109 @@ def cli_inbox_follow(actor: str) -> None:
inbox = actor_data.shared_inbox inbox = actor_data.shared_inbox
message = misc.Message.new_follow( message = Message.new_follow(
host = app.config.host, host = ctx.obj.config.domain,
actor = actor actor = actor
) )
asyncio.run(http.post(app.database, inbox, message)) asyncio.run(http.post(inbox, message))
click.echo(f'Sent follow message to actor: {actor}') click.echo(f'Sent follow message to actor: {actor}')
@cli_inbox.command('unfollow') @cli_inbox.command('unfollow')
@click.argument('actor') @click.argument('actor')
def cli_inbox_unfollow(actor: str) -> None: @click.pass_context
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)' 'Unfollow an actor (Relay must be running)'
if not actor.startswith('http'): inbox_data: Row = None
domain = actor
actor = f'https://{actor}/actor'
else: with ctx.obj.database.connection() as conn:
domain = urlparse(actor).hostname if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
try: if (inbox_data := conn.get_inbox(actor)):
inbox_data = app.database['relay-list'][domain]
inbox = inbox_data['inbox'] inbox = inbox_data['inbox']
message = misc.Message.new_unfollow( message = Message.new_unfollow(
host = app.config.host, host = ctx.obj.config.domain,
actor = actor, actor = actor,
follow = inbox_data['followid'] follow = inbox_data['followid']
) )
except KeyError: else:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True)) if not actor.startswith('http'):
actor = f'https://{actor}/actor'
actor_data = asyncio.run(http.get(actor, sign_headers = True))
inbox = actor_data.shared_inbox inbox = actor_data.shared_inbox
message = misc.Message.new_unfollow( message = Message.new_unfollow(
host = app.config.host, host = ctx.obj.config.domain,
actor = actor, actor = actor,
follow = { follow = {
'type': 'Follow', 'type': 'Follow',
'object': actor, 'object': actor,
'actor': f'https://{app.config.host}/actor' 'actor': f'https://{ctx.obj.config.domain}/actor'
} }
) )
asyncio.run(http.post(app.database, inbox, message)) asyncio.run(http.post(inbox, message))
click.echo(f'Sent unfollow message to: {actor}') click.echo(f'Sent unfollow message to: {actor}')
@cli_inbox.command('add') @cli_inbox.command('add')
@click.argument('inbox') @click.argument('inbox')
def cli_inbox_add(inbox: str) -> None: @click.option('--actor', '-a', help = 'Actor url for the inbox')
@click.option('--followid', '-f', help = 'Url for the follow activity')
@click.option('--software', '-s', type = click.Choice(SOFTWARE))
@click.pass_context
def cli_inbox_add(
ctx: click.Context,
inbox: str,
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> None:
'Add an inbox to the database' 'Add an inbox to the database'
if not inbox.startswith('http'): if not inbox.startswith('http'):
domain = inbox
inbox = f'https://{inbox}/inbox' inbox = f'https://{inbox}/inbox'
if app.config.is_banned(inbox): else:
click.echo(f'Error: Refusing to add banned inbox: {inbox}') domain = urlparse(inbox).netloc
if not actor and software:
try:
actor = ACTOR_FORMATS[software].format(domain = domain)
except KeyError:
pass
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Refusing to add banned inbox: {inbox}')
return return
if app.database.get_inbox(inbox): if conn.get_inbox(inbox):
click.echo(f'Error: Inbox already in database: {inbox}') click.echo(f'Error: Inbox already in database: {inbox}')
return return
app.database.add_inbox(inbox) conn.put_inbox(domain, inbox, actor, followid, software)
app.database.save()
click.echo(f'Added inbox to the database: {inbox}') click.echo(f'Added inbox to the database: {inbox}')
@cli_inbox.command('remove') @cli_inbox.command('remove')
@click.argument('inbox') @click.argument('inbox')
def cli_inbox_remove(inbox: str) -> None: @click.pass_context
def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
'Remove an inbox from the database' 'Remove an inbox from the database'
try: with ctx.obj.database.connection() as conn:
dbinbox = app.database.get_inbox(inbox, fail=True) if not conn.del_inbox(inbox):
click.echo(f'Inbox not in database: {inbox}')
except KeyError:
click.echo(f'Error: Inbox does not exist: {inbox}')
return return
app.database.del_inbox(dbinbox['domain'])
app.database.save()
click.echo(f'Removed inbox from the database: {inbox}') click.echo(f'Removed inbox from the database: {inbox}')
@ -277,47 +469,76 @@ def cli_instance() -> None:
@cli_instance.command('list') @cli_instance.command('list')
def cli_instance_list() -> None: @click.pass_context
def cli_instance_list(ctx: click.Context) -> None:
'List all banned instances' 'List all banned instances'
click.echo('Banned instances or relays:') click.echo('Banned domains:')
for domain in app.config.blocked_instances: with ctx.obj.database.connection() as conn:
click.echo(f'- {domain}') for instance in conn.execute('SELECT * FROM domain_bans'):
if instance['reason']:
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
else:
click.echo(f'- {instance["domain"]}')
@cli_instance.command('ban') @cli_instance.command('ban')
@click.argument('target') @click.argument('domain')
def cli_instance_ban(target: str) -> None: @click.option('--reason', '-r', help = 'Public note about why the domain is banned')
@click.option('--note', '-n', help = 'Internal note that will only be seen by admins and mods')
@click.pass_context
def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Ban an instance and remove the associated inbox if it exists' 'Ban an instance and remove the associated inbox if it exists'
if target.startswith('http'): with ctx.obj.database.connection() as conn:
target = urlparse(target).hostname if conn.get_domain_ban(domain):
click.echo(f'Domain already banned: {domain}')
if app.config.ban_instance(target):
app.config.save()
if app.database.del_inbox(target):
app.database.save()
click.echo(f'Banned instance: {target}')
return return
click.echo(f'Instance already banned: {target}') conn.put_domain_ban(domain, reason, note)
conn.del_inbox(domain)
click.echo(f'Banned instance: {domain}')
@cli_instance.command('unban') @cli_instance.command('unban')
@click.argument('target') @click.argument('domain')
def cli_instance_unban(target: str) -> None: @click.pass_context
def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance' 'Unban an instance'
if app.config.unban_instance(target): with ctx.obj.database.connection() as conn:
app.config.save() if not conn.del_domain_ban(domain):
click.echo(f'Instance wasn\'t banned: {domain}')
click.echo(f'Unbanned instance: {target}')
return return
click.echo(f'Instance wasn\'t banned: {target}') click.echo(f'Unbanned instance: {domain}')
@cli_instance.command('update')
@click.argument('domain')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a domain ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_domain_ban(domain, reason, note)):
click.echo(f'Failed to update domain ban: {domain}')
return
click.echo(f'Updated domain ban: {domain}')
if row['reason']:
click.echo(f'- {row["domain"]} ({row["reason"]})')
else:
click.echo(f'- {row["domain"]}')
@cli.group('software') @cli.group('software')
@ -326,79 +547,131 @@ def cli_software() -> None:
@cli_software.command('list') @cli_software.command('list')
def cli_software_list() -> None: @click.pass_context
def cli_software_list(ctx: click.Context) -> None:
'List all banned software' 'List all banned software'
click.echo('Banned software:') click.echo('Banned software:')
for software in app.config.blocked_software: with ctx.obj.database.connection() as conn:
click.echo(f'- {software}') for software in conn.execute('SELECT * FROM software_bans'):
if software['reason']:
click.echo(f'- {software["name"]} ({software["reason"]})')
else:
click.echo(f'- {software["name"]}')
@cli_software.command('ban') @cli_software.command('ban')
@click.option(
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.argument('name') @click.argument('name')
def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None: @click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option(
'--fetch-nodeinfo', '-f',
is_flag = True,
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
)
@click.pass_context
def cli_software_ban(ctx: click.Context,
name: str,
reason: str,
note: str,
fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays' 'Ban software. Use RELAYS for NAME to ban relays'
with ctx.obj.database.connection() as conn:
if name == 'RELAYS': if name == 'RELAYS':
for software in RELAY_SOFTWARE: for software in RELAY_SOFTWARE:
app.config.ban_software(software) if conn.get_software_ban(software):
click.echo(f'Relay already banned: {software}')
continue
conn.put_software_ban(software, reason or 'relay', note)
app.config.save()
click.echo('Banned all relay software') click.echo('Banned all relay software')
return return
if fetch_nodeinfo: if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
if not nodeinfo: if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}') click.echo(f'Failed to fetch software name from domain: {name}')
return
name = nodeinfo.sw_name name = nodeinfo.sw_name
if app.config.ban_software(name): if conn.get_software_ban(name):
app.config.save() click.echo(f'Software already banned: {name}')
click.echo(f'Banned software: {name}')
return return
click.echo(f'Software already banned: {name}') if not conn.put_software_ban(name, reason, note):
click.echo(f'Failed to ban software: {name}')
return
click.echo(f'Banned software: {name}')
@cli_software.command('unban') @cli_software.command('unban')
@click.option(
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.argument('name') @click.argument('name')
def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None: @click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option(
'--fetch-nodeinfo', '-f',
is_flag = True,
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
)
@click.pass_context
def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays' 'Ban software. Use RELAYS for NAME to unban relays'
with ctx.obj.database.connection() as conn:
if name == 'RELAYS': if name == 'RELAYS':
for software in RELAY_SOFTWARE: for software in RELAY_SOFTWARE:
app.config.unban_software(software) if not conn.del_software_ban(software):
click.echo(f'Relay was not banned: {software}')
app.config.save()
click.echo('Unbanned all relay software') click.echo('Unbanned all relay software')
return return
if fetch_nodeinfo: if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name)) nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
if not nodeinfo: if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}') click.echo(f'Failed to fetch software name from domain: {name}')
return
name = nodeinfo.sw_name name = nodeinfo.sw_name
if app.config.unban_software(name): if not conn.del_software_ban(name):
app.config.save() click.echo(f'Software was not banned: {name}')
click.echo(f'Unbanned software: {name}')
return return
click.echo(f'Software wasn\'t banned: {name}') click.echo(f'Unbanned software: {name}')
@cli_software.command('update')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a software ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_software_ban(name, reason, note)):
click.echo(f'Failed to update software ban: {name}')
return
click.echo(f'Updated software ban: {name}')
if row['reason']:
click.echo(f'- {row["name"]} ({row["reason"]})')
else:
click.echo(f'- {row["name"]}')
@cli.group('whitelist') @cli.group('whitelist')
@ -407,52 +680,64 @@ def cli_whitelist() -> None:
@cli_whitelist.command('list') @cli_whitelist.command('list')
def cli_whitelist_list() -> None: @click.pass_context
def cli_whitelist_list(ctx: click.Context) -> None:
'List all the instances in the whitelist' 'List all the instances in the whitelist'
click.echo('Current whitelisted domains') click.echo('Current whitelisted domains:')
for domain in app.config.whitelist: with ctx.obj.database.connection() as conn:
click.echo(f'- {domain}') for domain in conn.execute('SELECT * FROM whitelist'):
click.echo(f'- {domain["domain"]}')
@cli_whitelist.command('add') @cli_whitelist.command('add')
@click.argument('instance') @click.argument('domain')
def cli_whitelist_add(instance: str) -> None: @click.pass_context
'Add an instance to the whitelist' def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
'Add a domain to the whitelist'
if not app.config.add_whitelist(instance): with ctx.obj.database.connection() as conn:
click.echo(f'Instance already in the whitelist: {instance}') if conn.get_domain_whitelist(domain):
click.echo(f'Instance already in the whitelist: {domain}')
return return
app.config.save() conn.put_domain_whitelist(domain)
click.echo(f'Instance added to the whitelist: {instance}') click.echo(f'Instance added to the whitelist: {domain}')
@cli_whitelist.command('remove') @cli_whitelist.command('remove')
@click.argument('instance') @click.argument('domain')
def cli_whitelist_remove(instance: str) -> None: @click.pass_context
def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
'Remove an instance from the whitelist' 'Remove an instance from the whitelist'
if not app.config.del_whitelist(instance): with ctx.obj.database.connection() as conn:
click.echo(f'Instance not in the whitelist: {instance}') if not conn.del_domain_whitelist(domain):
click.echo(f'Domain not in the whitelist: {domain}')
return return
app.config.save() if conn.get_config('whitelist-enabled'):
if conn.del_inbox(domain):
click.echo(f'Removed inbox for domain: {domain}')
if app.config.whitelist_enabled: click.echo(f'Removed domain from the whitelist: {domain}')
if app.database.del_inbox(instance):
app.database.save()
click.echo(f'Removed instance from the whitelist: {instance}')
@cli_whitelist.command('import') @cli_whitelist.command('import')
def cli_whitelist_import() -> None: @click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist' 'Add all current inboxes to the whitelist'
for domain in app.database.hostnames: with ctx.obj.database.connection() as conn:
cli_whitelist_add.callback(domain) for inbox in conn.execute('SELECT * FROM inboxes').all():
if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
continue
conn.put_domain_whitelist(inbox['domain'])
click.echo('Imported whitelist from inboxes')
def main() -> None: def main() -> None:

View file

@ -1,32 +1,27 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
import socket import socket
import traceback
import typing import typing
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from aiohttp.web_exceptions import HTTPMethodNotAllowed from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.message import Message as ApMessage from aputils.message import Message as ApMessage
from functools import cached_property from functools import cached_property
from json.decoder import JSONDecodeError
from uuid import uuid4 from uuid import uuid4
from . import logger as logging
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Coroutine, Generator, Optional, Type from typing import Any, Coroutine, Generator, Optional, Type
from aputils.signer import Signer
from .application import Application from .application import Application
from .config import RelayConfig from .config import Config
from .database import RelayDatabase from .database import Database
from .http_client import HttpClient from .http_client import HttpClient
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = { MIMETYPES = {
'activity': 'application/activity+json', 'activity': 'application/activity+json',
'html': 'text/html', 'html': 'text/html',
@ -77,91 +72,13 @@ def check_open_port(host: str, port: int) -> bool:
return False return False
class DotDict(dict): def get_app() -> Application:
def __init__(self, _data: dict[str, Any], **kwargs: Any): from .application import Application # pylint: disable=import-outside-toplevel
dict.__init__(self)
self.update(_data, **kwargs) if not Application.DEFAULT:
raise ValueError('No default application set')
return Application.DEFAULT
def __getattr__(self, key: str) -> str:
try:
return self[key]
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith('_'):
super().__setattr__(key, value)
else:
self[key] = value
def __setitem__(self, key: str, value: Any) -> None:
if type(value) is dict: # pylint: disable=unidiomatic-typecheck
value = DotDict(value)
super().__setitem__(key, value)
def __delattr__(self, key: str) -> None:
try:
dict.__delitem__(self, key)
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
@classmethod
def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]:
if not data:
raise JSONDecodeError('Empty body', data, 1)
try:
return cls(json.loads(data))
except ValueError:
raise JSONDecodeError('Invalid body', data, 1) from None
@classmethod
def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]:
data = cls({})
for chunk in sig.strip().split(','):
key, value = chunk.split('=', 1)
value = value.strip('\"')
if key == 'headers':
value = value.split()
data[key.lower()] = value
return data
def to_json(self, indent: Optional[int | str] = None) -> str:
return json.dumps(self, indent=indent)
def update(self, _data: dict[str, Any], **kwargs: Any) -> None:
if isinstance(_data, dict):
for key, value in _data.items():
self[key] = value
elif isinstance(_data, (list, tuple, set)):
for key, value in _data:
self[key] = value
for key, value in kwargs.items():
self[key] = value
class Message(ApMessage): class Message(ApMessage):
@ -181,7 +98,7 @@ class Message(ApMessage):
'followers': f'https://{host}/followers', 'followers': f'https://{host}/followers',
'following': f'https://{host}/following', 'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox', 'inbox': f'https://{host}/inbox',
'url': f'https://{host}/inbox', 'url': f'https://{host}/',
'endpoints': { 'endpoints': {
'sharedInbox': f'https://{host}/inbox' 'sharedInbox': f'https://{host}/inbox'
}, },
@ -310,16 +227,6 @@ class Response(AiohttpResponse):
class View(AbstractView): class View(AbstractView):
def __init__(self, request: AiohttpRequest):
AbstractView.__init__(self, request)
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
def __await__(self) -> Generator[Response]: def __await__(self) -> Generator[Response]:
method = self.request.method.upper() method = self.request.method.upper()
@ -363,94 +270,10 @@ class View(AbstractView):
@property @property
def config(self) -> RelayConfig: def config(self) -> Config:
return self.app.config return self.app.config
@property @property
def database(self) -> RelayDatabase: def database(self) -> Database:
return self.app.database return self.app.database
# todo: move to views.ActorView
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
self.instance = self.database.get_inbox(self.actor.inbox)
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import tinysql
import typing import typing
from cachetools import LRUCache from cachetools import LRUCache
@ -8,7 +9,7 @@ from . import logger as logging
from .misc import Message from .misc import Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .misc import View from .views import ActorView
cache = LRUCache(1024) cache = LRUCache(1024)
@ -16,74 +17,76 @@ cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool: def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason # pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in a 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False return False
## make sure the actor is an application # make sure the actor is an application
if actor.type != 'Application': if actor.type != 'Application':
return True return True
return False return False
async def handle_relay(view: View) -> None: async def handle_relay(view: ActorView) -> None:
if view.message.object_id in cache: if view.message.object_id in cache:
logging.verbose('already relayed %s', view.message.object_id) logging.verbose('already relayed %s', view.message.object_id)
return return
message = Message.new_announce(view.config.host, view.message.object_id) message = Message.new_announce(view.config.domain, view.message.object_id)
cache[view.message.object_id] = message.id cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
inboxes = view.database.distill_inboxes(view.message) with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
for inbox in inboxes:
view.app.push_message(inbox, message) view.app.push_message(inbox, message)
async def handle_forward(view: View) -> None: async def handle_forward(view: ActorView) -> None:
if view.message.id in cache: if view.message.id in cache:
logging.verbose('already forwarded %s', view.message.id) logging.verbose('already forwarded %s', view.message.id)
return return
message = Message.new_announce(view.config.host, view.message) message = Message.new_announce(view.config.domain, view.message)
cache[view.message.id] = message.id cache[view.message.id] = message.id
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
inboxes = view.database.distill_inboxes(view.message) with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
for inbox in inboxes:
view.app.push_message(inbox, message) view.app.push_message(inbox, message)
async def handle_follow(view: View) -> None: async def handle_follow(view: ActorView) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None software = nodeinfo.sw_name if nodeinfo else None
## reject if software used by actor is banned with view.database.connection() as conn:
# reject if software used by actor is banned
if view.config.is_banned_software(software): if view.config.is_banned_software(software):
view.app.push_message( view.app.push_message(
view.actor.shared_inbox, view.actor.shared_inbox,
Message.new_response( Message.new_response(
host = view.config.host, host = view.config.domain,
actor = view.actor.id, actor = view.actor.id,
followid = view.message.id, followid = view.message.id,
accept = False accept = False
) )
) )
return logging.verbose( logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s', 'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id, view.actor.id,
software software
) )
return
## reject if the actor is not an instance actor ## reject if the actor is not an instance actor
if person_check(view.actor, software): if person_check(view.actor, software):
view.app.push_message( view.app.push_message(
view.actor.shared_inbox, view.actor.shared_inbox,
Message.new_response( Message.new_response(
host = view.config.host, host = view.config.domain,
actor = view.actor.id, actor = view.actor.id,
followid = view.message.id, followid = view.message.id,
accept = False accept = False
@ -93,13 +96,26 @@ async def handle_follow(view: View) -> None:
logging.verbose('Non-application actor tried to follow: %s', view.actor.id) logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return return
view.database.add_inbox(view.actor.shared_inbox, view.message.id, software) if conn.get_inbox(view.actor.shared_inbox):
view.database.save() data = {'followid': view.message.id}
statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
with conn.query(statement):
pass
else:
conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
view.actor.id,
view.message.id,
software
)
view.app.push_message( view.app.push_message(
view.actor.shared_inbox, view.actor.shared_inbox,
Message.new_response( Message.new_response(
host = view.config.host, host = view.config.domain,
actor = view.actor.id, actor = view.actor.id,
followid = view.message.id, followid = view.message.id,
accept = True accept = True
@ -112,32 +128,30 @@ async def handle_follow(view: View) -> None:
view.app.push_message( view.app.push_message(
view.actor.shared_inbox, view.actor.shared_inbox,
Message.new_follow( Message.new_follow(
host = view.config.host, host = view.config.domain,
actor = view.actor.id actor = view.actor.id
) )
) )
async def handle_undo(view: View) -> None: async def handle_undo(view: ActorView) -> None:
## If the object is not a Follow, forward it ## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow': if view.message.object['type'] != 'Follow':
return await handle_forward(view) await handle_forward(view)
return
if not view.database.del_inbox(view.actor.domain, view.message.object['id']): with view.database.connection() as conn:
if not conn.del_inbox(view.actor.inbox):
logging.verbose( logging.verbose(
'Failed to delete "%s" with follow ID "%s"', 'Failed to delete "%s" with follow ID "%s"',
view.actor.id, view.actor.id,
view.message.object['id'] view.message.object['id']
) )
return
view.database.save()
view.app.push_message( view.app.push_message(
view.actor.shared_inbox, view.actor.shared_inbox,
Message.new_unfollow( Message.new_unfollow(
host = view.config.host, host = view.config.domain,
actor = view.actor.id, actor = view.actor.id,
follow = view.message follow = view.message
) )
@ -154,7 +168,7 @@ processors = {
} }
async def run_processor(view: View) -> None: async def run_processor(view: ActorView) -> None:
if view.message.type not in processors: if view.message.type not in processors:
logging.verbose( logging.verbose(
'Message type "%s" from actor cannot be handled: %s', 'Message type "%s" from actor cannot be handled: %s',
@ -164,12 +178,21 @@ async def run_processor(view: View) -> None:
return return
if view.instance and not view.instance.get('software'): if view.instance:
nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain']) if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
if nodeinfo: if not view.instance['actor']:
view.instance['software'] = nodeinfo.sw_name with view.database.connection() as conn:
view.database.save() view.instance = conn.update_inbox(
view.instance['inbox'],
actor = view.actor.id
)
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view) await processors[view.message.type](view)

View file

@ -2,8 +2,11 @@ from __future__ import annotations
import asyncio import asyncio
import subprocess import subprocess
import traceback
import typing import typing
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo
from pathlib import Path from pathlib import Path
@ -14,6 +17,7 @@ from .processors import run_processor
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from aputils.signer import Signer
from typing import Callable from typing import Callable
@ -71,11 +75,15 @@ def register_route(*paths: str) -> Callable:
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.connection() as conn:
config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all()
text = HOME_TEMPLATE.format( text = HOME_TEMPLATE.format(
host = self.config.host, host = self.config.domain,
note = self.config.note, note = config['note'],
count = len(self.database.hostnames), count = len(inboxes),
targets = '<br>'.join(self.database.hostnames) targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
) )
return Response.new(text, ctype='html') return Response.new(text, ctype='html')
@ -84,33 +92,45 @@ class HomeView(View):
@register_route('/actor', '/inbox') @register_route('/actor', '/inbox')
class ActorView(View): class ActorView(View):
def __init__(self, request: Request):
View.__init__(self, request)
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
data = Message.new_actor( data = Message.new_actor(
host = self.config.host, host = self.config.domain,
pubkey = self.database.signer.pubkey pubkey = self.app.signer.pubkey
) )
return Response.new(data, ctype='activity') return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
response = await self.get_post_data() if (response := await self.get_post_data()):
if response is not None:
return response return response
with self.database.connection() as conn:
self.instance = conn.get_inbox(self.actor.inbox)
config = conn.get_config_all()
## reject if the actor isn't whitelisted while the whiltelist is enabled ## reject if the actor isn't whitelisted while the whiltelist is enabled
if self.config.whitelist_enabled and not self.config.is_whitelisted(self.actor.domain): if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json') return Response.new_error(403, 'access denied', 'json')
## reject if actor is banned ## reject if actor is banned
if self.config.is_banned(self.actor.domain): if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id) logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json') return Response.new_error(403, 'access denied', 'json')
## reject if activity type isn't 'Follow' and the actor isn't following ## reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.database.get_inbox(self.actor.domain): if self.message.type != 'Follow' and not self.instance:
logging.verbose( logging.verbose(
'Rejected actor for trying to post while not following: %s', 'Rejected actor for trying to post while not following: %s',
self.actor.id self.actor.id
@ -124,6 +144,87 @@ class ActorView(View):
return Response.new(status = 202) return Response.new(status = 202)
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
class WebfingerView(View): class WebfingerView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
@ -133,12 +234,12 @@ class WebfingerView(View):
except KeyError: except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json') return Response.new_error(400, 'missing "resource" query key', 'json')
if subject != f'acct:relay@{self.config.host}': if subject != f'acct:relay@{self.config.domain}':
return Response.new_error(404, 'user not found', 'json') return Response.new_error(404, 'user not found', 'json')
data = Webfinger.new( data = Webfinger.new(
handle = 'relay', handle = 'relay',
domain = self.config.host, domain = self.config.domain,
actor = self.config.actor actor = self.config.actor
) )
@ -148,13 +249,16 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View): class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response: async def get(self, request: Request, niversion: str) -> Response:
with self.database.connection() as conn:
inboxes = conn.execute('SELECT * FROM inboxes').all()
data = { data = {
'name': 'activityrelay', 'name': 'activityrelay',
'version': VERSION, 'version': VERSION,
'protocols': ['activitypub'], 'protocols': ['activitypub'],
'open_regs': not self.config.whitelist_enabled, 'open_regs': not conn.get_config('whitelist-enabled'),
'users': 1, 'users': 1,
'metadata': {'peers': self.database.hostnames} 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
} }
if niversion == '2.1': if niversion == '2.1':
@ -166,5 +270,5 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo') @register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View): class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
data = WellKnownNodeinfo.new_template(self.config.host) data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')

View file

@ -3,3 +3,4 @@ aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0 cachetools>=5.2.0
click>=8.1.2 click>=8.1.2
pyyaml>=6.0 pyyaml>=6.0
tinysql@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.1.tar.gz