use postgresql/sqlite for database backend

This commit is contained in:
Izalia Mae 2024-01-22 05:32:16 -05:00
parent 9808674b98
commit 5f6aef1871
16 changed files with 1574 additions and 813 deletions

10
.gitignore vendored
View file

@ -94,9 +94,7 @@ ENV/
# Rope project settings
.ropeproject
viera.yaml
viera.jsonld
# config file
relay.yaml
relay.jsonld
# config and database
*.yaml
*.jsonld
*.sqlite3

View file

@ -1,43 +1,35 @@
# this is the path that the object graph will get dumped to (in JSON-LD format),
# you probably shouldn't change it, but you can if you want.
db: relay.jsonld
# [string] Domain the relay will be hosted on
domain: relay.example.com
# Listener
# [string] Address the relay will listen on
listen: 0.0.0.0
# [integer] Port the relay will listen on
port: 8080
# Note
note: "Make a note about your instance here."
# [integer] Number of push workers to start (will get removed in a future update)
workers: 8
# Number of worker threads to start. If 0, use asyncio futures instead of threads.
workers: 0
# [string] Database backend to use. Valid values: sqlite, postgres
database_type: sqlite
# Maximum number of inbox posts to do at once
# If workers is set to 1 or above, this is the max for each worker
push_limit: 512
# [string] Path to the sqlite database file if the sqlite backend is in use
sqlite_path: relay.sqlite3
# The amount of json objects to cache from GET requests
json_cache: 1024
# settings for the postgresql backend
postgres:
ap:
# This is used for generating activitypub messages, as well as instructions for
# linking AP identities. It should be an SSL-enabled domain reachable by https.
host: 'relay.example.com'
# [string] hostname or unix socket to connect to
host: /var/run/postgresql
blocked_instances:
- 'bad-instance.example.com'
- 'another-bad-instance.example.com'
# [integer] port of the server
port: 5432
whitelist_enabled: false
# [string] username to use when logging into the server (default is the current system username)
user: null
whitelist:
- 'good-instance.example.com'
- 'another.good-instance.example.com'
# [string] password of the user
pass: null
# uncomment the lines below to prevent certain activitypub software from posting
# to the relay (all known relays by default). this uses the software name in nodeinfo
#blocked_software:
#- 'activityrelay'
#- 'aoderelay'
#- 'social.seattle.wa.us-relay'
#- 'unciarelay'
# [string] name of the database to use
name: activityrelay

View file

@ -8,52 +8,41 @@ import traceback
import typing
from aiohttp import web
from aputils.signer import Signer
from datetime import datetime, timedelta
from . import logger as logging
from .config import RelayConfig
from .database import RelayDatabase
from .config import Config
from .database import get_database
from .http_client import HttpClient
from .misc import check_open_port
from .views import VIEWS
if typing.TYPE_CHECKING:
from tinysql import Database
from typing import Any
from .misc import Message
# pylint: disable=unsubscriptable-object
class Application(web.Application):
DEFAULT: Application = None
def __init__(self, cfgpath: str):
web.Application.__init__(self)
Application.DEFAULT = self
self['signer'] = None
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['workers'] = []
self['last_worker'] = 0
self['start_time'] = None
self['running'] = False
self['config'] = RelayConfig(cfgpath)
if not self.config.load():
self.config.save()
if self.config.is_docker:
self.config.update({
'db': '/data/relay.jsonld',
'listen': '0.0.0.0',
'port': 8080
})
self['database'] = RelayDatabase(self.config)
self.database.load()
self['client'] = HttpClient(
database = self.database,
limit = self.config.push_limit,
timeout = self.config.timeout,
cache_size = self.config.json_cache
)
for path, view in VIEWS:
self.router.add_view(path, view)
@ -65,15 +54,29 @@ class Application(web.Application):
@property
def config(self) -> RelayConfig:
def config(self) -> Config:
return self['config']
@property
def database(self) -> RelayDatabase:
def database(self) -> Database:
return self['database']
@property
def signer(self) -> Signer:
return self['signer']
@signer.setter
def signer(self, value: Signer | str) -> None:
if isinstance(value, Signer):
self['signer'] = value
return
self['signer'] = Signer(value, self.config.keyid)
@property
def uptime(self) -> timedelta:
if not self['start_time']:
@ -118,7 +121,7 @@ class Application(web.Application):
logging.info(
'Starting webserver at %s (%s:%i)',
self.config.host,
self.config.domain,
self.config.listen,
self.config.port
)
@ -179,12 +182,7 @@ class PushWorker(threading.Thread):
async def handle_queue(self) -> None:
self.client = HttpClient(
database = self.app.database,
limit = self.app.config.push_limit,
timeout = self.app.config.timeout,
cache_size = self.app.config.json_cache
)
self.client = HttpClient()
while self.app['running']:
try:

View file

@ -1,17 +1,128 @@
from __future__ import annotations
import json
import os
import typing
import yaml
from aputils.signer import Signer
from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse
from . import logger as logging
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from typing import Iterator, Optional
from .config import RelayConfig
from .misc import Message
from typing import Any, Iterator, Optional
# pylint: disable=duplicate-code
class RelayConfig(dict):
def __init__(self, path: str):
dict.__init__(self, {})
if self.is_docker:
path = '/data/config.yaml'
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
@property
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self["host"]}/actor'
@property
def inbox(self) -> str:
return f'https://{self["host"]}/inbox'
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def load(self) -> None:
self.reset()
options = {}
try:
options['Loader'] = yaml.FullLoader
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
return
if not config:
return
for key, value in config.items():
if key in ['ap']:
for k, v in value.items():
if k not in self:
continue
self[k] = v
continue
if key not in self:
continue
self[key] = value
class RelayDatabase(dict):
@ -37,9 +148,7 @@ class RelayDatabase(dict):
return tuple(data['inbox'] for data in self['relay-list'].values())
def load(self) -> bool:
new_db = True
def load(self) -> None:
try:
with self.config.db.open() as fd:
data = json.load(fd)
@ -65,17 +174,9 @@ class RelayDatabase(dict):
self['relay-list'] = data.get('relay-list', {})
for domain, instance in self['relay-list'].items():
if self.config.is_banned(domain) or \
(self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
self.del_inbox(domain)
continue
if not instance.get('domain'):
instance['domain'] = domain
new_db = False
except FileNotFoundError:
pass
@ -83,17 +184,6 @@ class RelayDatabase(dict):
if self.config.db.stat().st_size > 0:
raise e from None
if not self['private-key']:
logging.info('No actor keys present, generating 4096-bit RSA keypair.')
self.signer = Signer.new(self.config.keyid, size=4096)
self['private-key'] = self.signer.export()
else:
self.signer = Signer(self['private-key'], self.config.keyid)
self.save()
return not new_db
def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd:

View file

@ -1,76 +1,73 @@
from __future__ import annotations
import getpass
import os
import typing
import yaml
from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse
from .misc import DotDict, boolean
from .misc import IS_DOCKER
if typing.TYPE_CHECKING:
from typing import Any
from .database import RelayDatabase
from typing import Any, Optional
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
]
DEFAULTS: dict[str, Any] = {
'listen': '0.0.0.0',
'port': 8080,
'domain': 'relay.example.com',
'workers': len(os.sched_getaffinity(0)),
'db_type': 'sqlite',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay'
}
APKEYS = [
'host',
'whitelist_enabled',
'blocked_software',
'blocked_instances',
'whitelist'
]
if IS_DOCKER:
DEFAULTS['sq_path'] = '/data/relay.jsonld'
class RelayConfig(DotDict):
__slots__ = ('path', )
class Config:
def __init__(self, path: str, load: Optional[bool] = False):
self.path = Path(path).expanduser().resolve()
def __init__(self, path: str | Path):
DotDict.__init__(self, {})
self.listen = None
self.port = None
self.domain = None
self.workers = None
self.db_type = None
self.sq_path = None
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
if self.is_docker:
path = '/data/config.yaml'
if load:
try:
self.load()
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
except FileNotFoundError:
self.save()
@property
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
def sqlite_path(self) -> Path:
return Path(self.sq_path).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self.host}/actor'
return f'https://{self.domain}/actor'
@property
def inbox(self) -> str:
return f'https://{self.host}/inbox'
return f'https://{self.domain}/inbox'
@property
@ -78,115 +75,7 @@ class RelayConfig(DotDict):
return f'{self.actor}#main-key'
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def ban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_banned(instance):
return False
self.blocked_instances.append(instance)
return True
def unban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.blocked_instances.remove(instance)
return True
except ValueError:
return False
def ban_software(self, software: str) -> bool:
if self.is_banned_software(software):
return False
self.blocked_software.append(software)
return True
def unban_software(self, software: str) -> bool:
try:
self.blocked_software.remove(software)
return True
except ValueError:
return False
def add_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_whitelisted(instance):
return False
self.whitelist.append(instance)
return True
def del_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.whitelist.remove(instance)
return True
except ValueError:
return False
def is_banned(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.blocked_instances
def is_banned_software(self, software: str) -> bool:
if not software:
return False
return software.lower() in self.blocked_software
def is_whitelisted(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.whitelist
def load(self) -> bool:
def load(self) -> None:
self.reset()
options = {}
@ -197,50 +86,69 @@ class RelayConfig(DotDict):
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
return False
with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {})
if not config:
return False
raise ValueError('Config is empty')
for key, value in config.items():
if key in ['ap']:
for k, v in value.items():
if k not in self:
continue
if IS_DOCKER:
self.listen = '0.0.0.0'
self.port = 8080
self.sq_path = '/data/relay.jsonld'
self[k] = v
else:
self.set('listen', config.get('listen', DEFAULTS['listen']))
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
for key in DEFAULTS:
if not key.startswith('pg'):
continue
if key not in self:
try:
self.set(key, pgcfg[key[3:]])
except KeyError:
continue
self[key] = value
if self.host.endswith('example.com'):
return False
return True
def reset(self) -> None:
for key, value in DEFAULTS.items():
setattr(self, key, value)
def save(self) -> None:
self.path.parent.mkdir(exist_ok = True, parents = True)
config = {
# just turning config.db into a string is good enough for now
'db': str(self.db),
'listen': self.listen,
'port': self.port,
'note': self.note,
'push_limit': self.push_limit,
'workers': self.workers,
'json_cache': self.json_cache,
'timeout': self.timeout,
'ap': {key: self[key] for key in APKEYS}
'domain': self.domain,
'database_type': self.db_type,
'sqlite_path': self.sq_path,
'postgres': {
'host': self.pg_host,
'port': self.pg_port,
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
}
}
with self._path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys=False)
with self.path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys = False)
def set(self, key: str, value: Any) -> None:
if key not in DEFAULTS:
raise KeyError(key)
if key in ('port', 'pg_port', 'workers') and not isinstance(value, int):
value = int(value)
setattr(self, key, value)

View file

@ -0,0 +1,63 @@
from __future__ import annotations
import tinysql
import typing
from importlib.resources import files as pkgfiles
from .config import get_default_value
from .connection import Connection
from .schema import VERSIONS, migrate_0
from .. import logger as logging
if typing.TYPE_CHECKING:
from typing import Optional
from .config import Config
def get_database(config: Config, migrate: Optional[bool] = True) -> tinysql.Database:
if config.db_type == "sqlite":
db = tinysql.Database.sqlite(config.sq_path, connection_class = Connection)
elif config.db_type == "postgres":
db = tinysql.Database.postgres(
config.pg_name,
config.pg_host,
config.pg_port,
config.pg_user,
config.pg_pass,
connection_class = Connection
)
db.load_prepared_statements(pkgfiles("relay").joinpath("database", "statements.sql"))
if not migrate:
return db
with db.connection() as conn:
if 'config' not in conn.get_tables():
logging.info("Creating database tables")
migrate_0(conn)
return db
schema_ver = conn.get_config('schema-version')
if schema_ver < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS:
if schema_ver < ver:
conn.begin()
func(conn)
conn.put_config('schema-version', ver)
conn.commit()
if (privkey := conn.get_config('private-key')):
conn.app.signer = privkey
logging.set_level(conn.get_config('log-level'))
return db

44
relay/database/config.py Normal file
View file

@ -0,0 +1,44 @@
from __future__ import annotations
import typing
from .. import logger as logging
from ..misc import boolean
if typing.TYPE_CHECKING:
from typing import Any, Callable
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
'schema-version': ('int', 20240119),
'log-level': ('loglevel', logging.LogLevel.INFO),
'note': ('str', 'Make a note about your instance here.'),
'private-key': ('str', None),
'whitelist-enabled': ('bool', False)
}
# serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'loglevel': (lambda x: x.name, logging.LogLevel.parse)
}
def get_default_value(key: str) -> Any:
return CONFIG_DEFAULTS[key][1]
def get_default_type(key: str) -> str:
return CONFIG_DEFAULTS[key][0]
def serialize(key: str, value: Any) -> str:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][0](value)
def deserialize(key: str, value: str) -> Any:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][1](value)

View file

@ -0,0 +1,295 @@
from __future__ import annotations
import tinysql
import typing
from datetime import datetime, timezone
from urllib.parse import urlparse
from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize
from .. import logger as logging
from ..misc import get_app
if typing.TYPE_CHECKING:
from tinysql import Cursor, Row
from typing import Any, Iterator, Optional
from .application import Application
from ..misc import Message
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'activity-relay', # https://github.com/yukimochi/Activity-Relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
]
class Connection(tinysql.Connection):
@property
def app(self) -> Application:
return get_app()
def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for inbox in self.execute('SELECT * FROM inboxes'):
if inbox['domain'] not in src_domains:
yield inbox['inbox']
def exec_statement(self, name: str, params: Optional[dict[str, Any]] = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params)
def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
with self.exec_statement('get-config', {'key': key}) as cur:
if not (row := cur.one()):
return get_default_value(key)
if row['value']:
return deserialize(row['key'], row['value'])
return None
def get_config_all(self) -> dict[str, Any]:
with self.exec_statement('get-config-all') as cur:
db_config = {row['key']: row['value'] for row in cur}
config = {}
for key, data in CONFIG_DEFAULTS.items():
try:
config[key] = deserialize(key, db_config[key])
except KeyError:
if key == 'schema-version':
config[key] = 0
else:
config[key] = data[1]
return config
def put_config(self, key: str, value: Any) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
if key == 'private-key':
self.app.signer = value
elif key == 'log-level':
value = logging.LogLevel.parse(value)
logging.set_level(value)
params = {
'key': key,
'value': serialize(key, value) if value is not None else None,
'type': get_default_type(key)
}
with self.exec_statement('put-config', params):
return value
def get_inbox(self, value: str) -> Row:
with self.exec_statement('get-inbox', {'value': value}) as cur:
return cur.one()
def put_inbox(self,
domain: str,
inbox: str,
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> Row:
params = {
'domain': domain,
'inbox': inbox,
'actor': actor,
'followid': followid,
'software': software,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-inbox', params) as cur:
return cur.one()
def update_inbox(self,
inbox: str,
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> Row:
if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"')
data = {}
if actor:
data['actor'] = actor
if followid:
data['followid'] = followid
if software:
data['software'] = software
statement = tinysql.Update('inboxes', data, inbox = inbox)
with self.query(statement):
return self.get_inbox(inbox)
def del_inbox(self, value: str) -> bool:
with self.exec_statement('del-inbox', {'value': value}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_ban(self, domain: str) -> Row:
if domain.startswith('http'):
domain = urlparse(domain).netloc
with self.exec_statement('get-domain-ban', {'domain': domain}) as cur:
return cur.one()
def put_domain_ban(self,
domain: str,
reason: Optional[str] = None,
note: Optional[str] = None) -> Row:
params = {
'domain': domain,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-ban', params) as cur:
return cur.one()
def update_domain_ban(self,
domain: str,
reason: Optional[str] = None,
note: Optional[str] = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('domain_bans', params, domain = domain)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_domain_ban(domain)
def del_domain_ban(self, domain: str) -> bool:
with self.exec_statement('del-domain-ban', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_software_ban(self, name: str) -> Row:
with self.exec_statement('get-software-ban', {'name': name}) as cur:
return cur.one()
def put_software_ban(self,
name: str,
reason: Optional[str] = None,
note: Optional[str] = None) -> Row:
params = {
'name': name,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-software-ban', params) as cur:
return cur.one()
def update_software_ban(self,
name: str,
reason: Optional[str] = None,
note: Optional[str] = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('software_bans', params, name = name)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_software_ban(name)
def del_software_ban(self, name: str) -> bool:
with self.exec_statement('del-software-ban', {'name': name}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_whitelist(self, domain: str) -> Row:
with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one()
def put_domain_whitelist(self, domain: str) -> Row:
params = {
'domain': domain,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-whitelist', params) as cur:
return cur.one()
def del_domain_whitelist(self, domain: str) -> bool:
with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1

60
relay/database/schema.py Normal file
View file

@ -0,0 +1,60 @@
from __future__ import annotations
import typing
from tinysql import Column, Connection, Table
from .config import get_default_value
if typing.TYPE_CHECKING:
from typing import Callable
VERSIONS: list[Callable] = []
TABLES: list[Table] = [
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False),
Column('value', 'text'),
Column('type', 'text', default = 'str')
),
Table(
'inboxes',
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True),
Column('inbox', 'text', unique = True, nullable = False),
Column('followid', 'text'),
Column('software', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'whitelist',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('created', 'timestamp')
),
Table(
'instance_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'software_bans',
Column('name', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
)
]
def version(func: Callable) -> Callable:
ver = int(func.replace('migrate_', ''))
VERSIONS[ver] = func
return func
def migrate_0(conn: Connection) -> None:
conn.create_tables(TABLES)
conn.put_config('schema-version', get_default_value('schema-version'))

View file

@ -0,0 +1,79 @@
-- name: get-config
SELECT * FROM config WHERE key = :key
-- name: get-config-all
SELECT * FROM config
-- name: put-config
INSERT INTO config (key, value, type)
VALUES (:key, :value, :type)
ON CONFLICT (key) DO UPDATE SET value = :value
RETURNING *
-- name: del-config
DELETE FROM config
WHERE key = :key
-- name: get-inbox
SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value
-- name: put-inbox
INSERT INTO inboxes (domain, actor, inbox, followid, software, created)
VALUES (:domain, :actor, :inbox, :followid, :software, :created)
ON CONFLICT (domain) DO UPDATE SET followid = :followid
RETURNING *
-- name: del-inbox
DELETE FROM inboxes
WHERE domain = :value or inbox = :value or actor = :value
-- name: get-software-ban
SELECT * FROM software_bans WHERE name = :name
-- name: put-software-ban
INSERT INTO software_bans (name, reason, note, created)
VALUES (:name, :reason, :note, :created)
RETURNING *
-- name: del-software-ban
DELETE FROM software_bans
WHERE name = :name
-- name: get-domain-ban
SELECT * FROM domain_bans WHERE domain = :domain
-- name: put-domain-ban
INSERT INTO domain_bans (domain, reason, note, created)
VALUES (:domain, :reason, :note, :created)
RETURNING *
-- name: del-domain-ban
DELETE FROM domain_bans
WHERE domain = :domain
-- name: get-domain-whitelist
SELECT * FROM whitelist WHERE domain = :domain
-- name: put-domain-whitelist
INSERT INTO whitelist (domain, created)
VALUES (:domain, :created)
RETURNING *
-- name: del-domain-whitelist
DELETE FROM whitelist
WHERE domain = :domain

View file

@ -13,11 +13,10 @@ from urllib.parse import urlparse
from . import __version__
from . import logger as logging
from .misc import MIMETYPES, Message
from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from typing import Any, Callable, Optional
from .database import RelayDatabase
HEADERS = {
@ -28,12 +27,10 @@ HEADERS = {
class HttpClient:
def __init__(self,
database: RelayDatabase,
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
self.database = database
self.cache = LRUCache(cache_size)
self.limit = limit
self.timeout = timeout
@ -98,7 +95,7 @@ class HttpClient:
headers = {}
if sign_headers:
headers.update(self.database.signer.sign_headers('GET', url, algorithm='original'))
get_app().signer.sign_headers('GET', url, algorithm = 'original')
try:
logging.debug('Fetching resource: %s', url)
@ -150,23 +147,24 @@ class HttpClient:
async def post(self, url: str, message: Message) -> None:
await self.open()
instance = self.database.get_inbox(url)
with get_app().database.connection() as conn:
instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now
if instance and instance.get('software') in {'mastodon'}:
if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019'
else:
algorithm = 'original'
headers = {'Content-Type': 'application/activity+json'}
headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm))
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
try:
logging.verbose('Sending "%s" to %s', message.type, url)
async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
## Not expecting a response, so just return
# Not expecting a response, so just return
if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url)
return
@ -181,7 +179,7 @@ class HttpClient:
except (AsyncTimeoutError, ClientConnectionError):
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
## prevent workers from being brought down
# prevent workers from being brought down
except Exception:
traceback.print_exc()
@ -211,16 +209,16 @@ class HttpClient:
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient(database) as client:
async def get(*args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient() as client:
return await client.get(*args, **kwargs)
async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None:
async with HttpClient(database) as client:
async def post(*args: Any, **kwargs: Any) -> None:
async with HttpClient() as client:
return await client.post(*args, **kwargs)
async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient(database) as client:
async def fetch_nodeinfo(*args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient() as client:
return await client.fetch_nodeinfo(*args, **kwargs)

View file

@ -6,22 +6,49 @@ import click
import platform
import typing
from aputils.signer import Signer
from pathlib import Path
from shutil import copyfile
from urllib.parse import urlparse
from . import misc, __version__
from . import __version__
from . import http_client as http
from . import logger as logging
from .application import Application
from .config import RELAY_SOFTWARE
from .compat import RelayConfig, RelayDatabase
from .database import get_database
from .database.connection import RELAY_SOFTWARE
from .misc import IS_DOCKER, Message, check_open_port
if typing.TYPE_CHECKING:
from typing import Any
from tinysql import Row
from typing import Any, Optional
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
app = None
CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
CONFIG_IGNORE = (
'schema-version',
'private-key'
)
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@ -29,11 +56,10 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context
def cli(ctx: click.Context, config: str) -> None:
global app
app = Application(config)
ctx.obj = Application(config)
if not ctx.invoked_subcommand:
if app.config.host.endswith('example.com'):
if ctx.obj.config.domain.endswith('example.com'):
cli_setup.callback()
else:
@ -41,46 +67,92 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup')
def cli_setup() -> None:
@click.pass_context
def cli_setup(ctx: click.Context) -> None:
'Generate a new config'
while True:
app.config.host = click.prompt(
ctx.obj.config.domain = click.prompt(
'What domain will the relay be hosted on?',
default = app.config.host
default = ctx.obj.config.domain
)
if not app.config.host.endswith('example.com'):
if not ctx.obj.config.domain.endswith('example.com'):
break
click.echo('The domain must not be example.com')
click.echo('The domain must not end with "example.com"')
if not app.config.is_docker:
app.config.listen = click.prompt(
if not IS_DOCKER:
ctx.obj.config.listen = click.prompt(
'Which address should the relay listen on?',
default = app.config.listen
default = ctx.obj.config.listen
)
while True:
app.config.port = click.prompt(
'What TCP port should the relay listen on?',
default = app.config.port,
type = int
)
ctx.obj.config.port = click.prompt(
'What TCP port should the relay listen on?',
default = ctx.obj.config.port,
type = int
)
break
ctx.obj.config.db_type = click.prompt(
'Which database backend will be used?',
default = ctx.obj.config.db_type,
type = click.Choice(['postgres', 'sqlite'], case_sensitive = False)
)
app.config.save()
if ctx.obj.config.db_type == 'sqlite':
ctx.obj.config.sq_path = click.prompt(
'Where should the database be stored?',
default = ctx.obj.config.sq_path
)
if not app.config.is_docker and click.confirm('Relay all setup! Would you like to run it now?'):
elif ctx.obj.config.db_type == 'postgres':
ctx.obj.config.pg_name = click.prompt(
'What is the name of the database?',
default = ctx.obj.config.pg_name
)
ctx.obj.config.pg_host = click.prompt(
'What IP address or hostname does the server listen on?',
default = ctx.obj.config.pg_host
)
ctx.obj.config.pg_port = click.prompt(
'What port does the server listen on?',
default = ctx.obj.config.pg_port,
type = int
)
ctx.obj.config.pg_user = click.prompt(
'Which user will authenticate with the server?',
default = ctx.obj.config.pg_user
)
ctx.obj.config.pg_pass = click.prompt(
'User password: ',
hide_input = True
) or None
ctx.obj.config.save()
config = {
'private-key': Signer.new('n/a').export()
}
with ctx.obj.database.connection() as conn:
for key, value in config.items():
conn.put_config(key, value)
if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback()
@cli.command('run')
def cli_run() -> None:
@click.pass_context
def cli_run(ctx: click.Context) -> None:
'Run the relay'
if app.config.host.endswith('example.com'):
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
)
@ -104,40 +176,142 @@ def cli_run() -> None:
click.echo(pip_command)
return
if not misc.check_open_port(app.config.listen, app.config.port):
click.echo(f'Error: A server is already running on port {app.config.port}')
if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port):
click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
return
app.run()
ctx.obj.run()
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the new config file')
@click.pass_context
def cli_convert(ctx: click.Context, old_config: str) -> None:
'Convert an old config and jsonld database to the new format.'
old_config = Path(old_config).expanduser().resolve()
backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml')
if str(old_config) == str(ctx.obj.config.path) and not backup.exists():
logging.info('Created backup config @ %s', backup)
copyfile(ctx.obj.config.path, backup)
config = RelayConfig(old_config)
config.load()
database = RelayDatabase(config)
database.load()
ctx.obj.config.set('listen', config['listen'])
ctx.obj.config.set('port', config['port'])
ctx.obj.config.set('workers', config['workers'])
ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3'))
with get_database(ctx.obj.config) as db:
with db.connection() as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar(
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
for inbox in inboxes:
if inbox['software'] in ('akkoma', 'pleroma'):
actor = f'https://{inbox["domain"]}/relay'
elif inbox['software'] == 'mastodon':
actor = f'https://{inbox["domain"]}/actor'
else:
actor = None
conn.put_inbox(
inbox['domain'],
inbox['inbox'],
actor = actor,
followid = inbox['followid'],
software = inbox['software']
)
with click.progressbar(
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
) as banned_software:
for software in banned_software:
conn.put_software_ban(
software,
reason = 'relay' if software in RELAY_SOFTWARE else None
)
with click.progressbar(
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
) as banned_software:
for domain in banned_software:
conn.put_domain_ban(domain)
with click.progressbar(
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
) as whitelist:
for instance in whitelist:
conn.put_domain_whitelist(instance)
click.echo('Finished converting old config and database :3')
@cli.command('edit-config')
@click.option('--editor', '-e', help = 'Text editor to use')
@click.pass_context
def cli_editconfig(ctx: click.Context, editor: str) -> None:
'Edit the config file'
click.edit(
editor = editor,
filename = str(ctx.obj.config.path)
)
@cli.group('config')
def cli_config() -> None:
'Manage the relay config'
'Manage the relay settings stored in the database'
@cli_config.command('list')
def cli_config_list() -> None:
@click.pass_context
def cli_config_list(ctx: click.Context) -> None:
'List the current relay config'
click.echo('Relay Config:')
for key, value in app.config.items():
if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20)
click.echo(f'- {key} {value}')
with ctx.obj.database.connection() as conn:
for key, value in conn.get_config_all().items():
if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20)
click.echo(f'- {key} {value}')
@cli_config.command('set')
@click.argument('key')
@click.argument('value')
def cli_config_set(key: str, value: Any) -> None:
@click.pass_context
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value'
app.config[key] = value
app.config.save()
with ctx.obj.database.connection() as conn:
new_value = conn.put_config(key, value)
print(f'{key}: {app.config[key]}')
print(f'{key}: {repr(new_value)}')
@cli.group('inbox')
@ -146,127 +320,145 @@ def cli_inbox() -> None:
@cli_inbox.command('list')
def cli_inbox_list() -> None:
@click.pass_context
def cli_inbox_list(ctx: click.Context) -> None:
'List the connected instances or relays'
click.echo('Connected to the following instances or relays:')
for inbox in app.database.inboxes:
click.echo(f'- {inbox}')
with ctx.obj.database.connection() as conn:
for inbox in conn.execute('SELECT * FROM inboxes'):
click.echo(f'- {inbox["inbox"]}')
@cli_inbox.command('follow')
@click.argument('actor')
def cli_inbox_follow(actor: str) -> None:
@click.pass_context
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)'
if app.config.is_banned(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
if not actor.startswith('http'):
domain = actor
actor = f'https://{actor}/actor'
else:
domain = urlparse(actor).hostname
try:
inbox_data = app.database['relay-list'][domain]
inbox = inbox_data['inbox']
except KeyError:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
if not actor_data:
click.echo(f'Failed to fetch actor: {actor}')
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
inbox = actor_data.shared_inbox
if (inbox_data := conn.get_inbox(actor)):
inbox = inbox_data['inbox']
message = misc.Message.new_follow(
host = app.config.host,
else:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
actor_data = asyncio.run(http.get(actor, sign_headers = True))
if not actor_data:
click.echo(f'Failed to fetch actor: {actor}')
return
inbox = actor_data.shared_inbox
message = Message.new_follow(
host = ctx.obj.config.domain,
actor = actor
)
asyncio.run(http.post(app.database, inbox, message))
asyncio.run(http.post(inbox, message))
click.echo(f'Sent follow message to actor: {actor}')
@cli_inbox.command('unfollow')
@click.argument('actor')
def cli_inbox_unfollow(actor: str) -> None:
@click.pass_context
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)'
if not actor.startswith('http'):
domain = actor
actor = f'https://{actor}/actor'
inbox_data: Row = None
else:
domain = urlparse(actor).hostname
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
try:
inbox_data = app.database['relay-list'][domain]
inbox = inbox_data['inbox']
message = misc.Message.new_unfollow(
host = app.config.host,
actor = actor,
follow = inbox_data['followid']
)
if (inbox_data := conn.get_inbox(actor)):
inbox = inbox_data['inbox']
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
follow = inbox_data['followid']
)
except KeyError:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
inbox = actor_data.shared_inbox
message = misc.Message.new_unfollow(
host = app.config.host,
actor = actor,
follow = {
'type': 'Follow',
'object': actor,
'actor': f'https://{app.config.host}/actor'
}
)
else:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
asyncio.run(http.post(app.database, inbox, message))
actor_data = asyncio.run(http.get(actor, sign_headers = True))
inbox = actor_data.shared_inbox
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
follow = {
'type': 'Follow',
'object': actor,
'actor': f'https://{ctx.obj.config.domain}/actor'
}
)
asyncio.run(http.post(inbox, message))
click.echo(f'Sent unfollow message to: {actor}')
@cli_inbox.command('add')
@click.argument('inbox')
def cli_inbox_add(inbox: str) -> None:
@click.option('--actor', '-a', help = 'Actor url for the inbox')
@click.option('--followid', '-f', help = 'Url for the follow activity')
@click.option('--software', '-s', type = click.Choice(SOFTWARE))
@click.pass_context
def cli_inbox_add(
ctx: click.Context,
inbox: str,
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> None:
'Add an inbox to the database'
if not inbox.startswith('http'):
domain = inbox
inbox = f'https://{inbox}/inbox'
if app.config.is_banned(inbox):
click.echo(f'Error: Refusing to add banned inbox: {inbox}')
return
else:
domain = urlparse(inbox).netloc
if app.database.get_inbox(inbox):
click.echo(f'Error: Inbox already in database: {inbox}')
return
if not actor and software:
try:
actor = ACTOR_FORMATS[software].format(domain = domain)
app.database.add_inbox(inbox)
app.database.save()
except KeyError:
pass
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Refusing to add banned inbox: {inbox}')
return
if conn.get_inbox(inbox):
click.echo(f'Error: Inbox already in database: {inbox}')
return
conn.put_inbox(domain, inbox, actor, followid, software)
click.echo(f'Added inbox to the database: {inbox}')
@cli_inbox.command('remove')
@click.argument('inbox')
def cli_inbox_remove(inbox: str) -> None:
@click.pass_context
def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
'Remove an inbox from the database'
try:
dbinbox = app.database.get_inbox(inbox, fail=True)
except KeyError:
click.echo(f'Error: Inbox does not exist: {inbox}')
return
app.database.del_inbox(dbinbox['domain'])
app.database.save()
with ctx.obj.database.connection() as conn:
if not conn.del_inbox(inbox):
click.echo(f'Inbox not in database: {inbox}')
return
click.echo(f'Removed inbox from the database: {inbox}')
@ -277,47 +469,76 @@ def cli_instance() -> None:
@cli_instance.command('list')
def cli_instance_list() -> None:
@click.pass_context
def cli_instance_list(ctx: click.Context) -> None:
'List all banned instances'
click.echo('Banned instances or relays:')
click.echo('Banned domains:')
for domain in app.config.blocked_instances:
click.echo(f'- {domain}')
with ctx.obj.database.connection() as conn:
for instance in conn.execute('SELECT * FROM domain_bans'):
if instance['reason']:
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
else:
click.echo(f'- {instance["domain"]}')
@cli_instance.command('ban')
@click.argument('target')
def cli_instance_ban(target: str) -> None:
@click.argument('domain')
@click.option('--reason', '-r', help = 'Public note about why the domain is banned')
@click.option('--note', '-n', help = 'Internal note that will only be seen by admins and mods')
@click.pass_context
def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Ban an instance and remove the associated inbox if it exists'
if target.startswith('http'):
target = urlparse(target).hostname
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Domain already banned: {domain}')
return
if app.config.ban_instance(target):
app.config.save()
if app.database.del_inbox(target):
app.database.save()
click.echo(f'Banned instance: {target}')
return
click.echo(f'Instance already banned: {target}')
conn.put_domain_ban(domain, reason, note)
conn.del_inbox(domain)
click.echo(f'Banned instance: {domain}')
@cli_instance.command('unban')
@click.argument('target')
def cli_instance_unban(target: str) -> None:
@click.argument('domain')
@click.pass_context
def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance'
if app.config.unban_instance(target):
app.config.save()
with ctx.obj.database.connection() as conn:
if not conn.del_domain_ban(domain):
click.echo(f'Instance wasn\'t banned: {domain}')
return
click.echo(f'Unbanned instance: {target}')
return
click.echo(f'Unbanned instance: {domain}')
click.echo(f'Instance wasn\'t banned: {target}')
@cli_instance.command('update')
@click.argument('domain')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a domain ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_domain_ban(domain, reason, note)):
click.echo(f'Failed to update domain ban: {domain}')
return
click.echo(f'Updated domain ban: {domain}')
if row['reason']:
click.echo(f'- {row["domain"]} ({row["reason"]})')
else:
click.echo(f'- {row["domain"]}')
@cli.group('software')
@ -326,79 +547,131 @@ def cli_software() -> None:
@cli_software.command('list')
def cli_software_list() -> None:
@click.pass_context
def cli_software_list(ctx: click.Context) -> None:
'List all banned software'
click.echo('Banned software:')
for software in app.config.blocked_software:
click.echo(f'- {software}')
with ctx.obj.database.connection() as conn:
for software in conn.execute('SELECT * FROM software_bans'):
if software['reason']:
click.echo(f'- {software["name"]} ({software["reason"]})')
else:
click.echo(f'- {software["name"]}')
@cli_software.command('ban')
@click.option(
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.argument('name')
def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None:
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option(
'--fetch-nodeinfo', '-f',
is_flag = True,
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
)
@click.pass_context
def cli_software_ban(ctx: click.Context,
name: str,
reason: str,
note: str,
fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays'
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
app.config.ban_software(software)
with ctx.obj.database.connection() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
if conn.get_software_ban(software):
click.echo(f'Relay already banned: {software}')
continue
app.config.save()
click.echo('Banned all relay software')
return
conn.put_software_ban(software, reason or 'relay', note)
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
click.echo('Banned all relay software')
return
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
name = nodeinfo.sw_name
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
return
name = nodeinfo.sw_name
if conn.get_software_ban(name):
click.echo(f'Software already banned: {name}')
return
if not conn.put_software_ban(name, reason, note):
click.echo(f'Failed to ban software: {name}')
return
if app.config.ban_software(name):
app.config.save()
click.echo(f'Banned software: {name}')
return
click.echo(f'Software already banned: {name}')
@cli_software.command('unban')
@click.option(
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.argument('name')
def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None:
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option(
'--fetch-nodeinfo', '-f',
is_flag = True,
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
)
@click.pass_context
def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays'
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
app.config.unban_software(software)
with ctx.obj.database.connection() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
if not conn.del_software_ban(software):
click.echo(f'Relay was not banned: {software}')
app.config.save()
click.echo('Unbanned all relay software')
return
click.echo('Unbanned all relay software')
return
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
return
name = nodeinfo.sw_name
name = nodeinfo.sw_name
if not conn.del_software_ban(name):
click.echo(f'Software was not banned: {name}')
return
if app.config.unban_software(name):
app.config.save()
click.echo(f'Unbanned software: {name}')
return
click.echo(f'Software wasn\'t banned: {name}')
@cli_software.command('update')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a software ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_software_ban(name, reason, note)):
click.echo(f'Failed to update software ban: {name}')
return
click.echo(f'Updated software ban: {name}')
if row['reason']:
click.echo(f'- {row["name"]} ({row["reason"]})')
else:
click.echo(f'- {row["name"]}')
@cli.group('whitelist')
@ -407,52 +680,64 @@ def cli_whitelist() -> None:
@cli_whitelist.command('list')
def cli_whitelist_list() -> None:
@click.pass_context
def cli_whitelist_list(ctx: click.Context) -> None:
'List all the instances in the whitelist'
click.echo('Current whitelisted domains')
click.echo('Current whitelisted domains:')
for domain in app.config.whitelist:
click.echo(f'- {domain}')
with ctx.obj.database.connection() as conn:
for domain in conn.execute('SELECT * FROM whitelist'):
click.echo(f'- {domain["domain"]}')
@cli_whitelist.command('add')
@click.argument('instance')
def cli_whitelist_add(instance: str) -> None:
'Add an instance to the whitelist'
@click.argument('domain')
@click.pass_context
def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
'Add a domain to the whitelist'
if not app.config.add_whitelist(instance):
click.echo(f'Instance already in the whitelist: {instance}')
return
with ctx.obj.database.connection() as conn:
if conn.get_domain_whitelist(domain):
click.echo(f'Instance already in the whitelist: {domain}')
return
app.config.save()
click.echo(f'Instance added to the whitelist: {instance}')
conn.put_domain_whitelist(domain)
click.echo(f'Instance added to the whitelist: {domain}')
@cli_whitelist.command('remove')
@click.argument('instance')
def cli_whitelist_remove(instance: str) -> None:
@click.argument('domain')
@click.pass_context
def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
'Remove an instance from the whitelist'
if not app.config.del_whitelist(instance):
click.echo(f'Instance not in the whitelist: {instance}')
return
with ctx.obj.database.connection() as conn:
if not conn.del_domain_whitelist(domain):
click.echo(f'Domain not in the whitelist: {domain}')
return
app.config.save()
if conn.get_config('whitelist-enabled'):
if conn.del_inbox(domain):
click.echo(f'Removed inbox for domain: {domain}')
if app.config.whitelist_enabled:
if app.database.del_inbox(instance):
app.database.save()
click.echo(f'Removed instance from the whitelist: {instance}')
click.echo(f'Removed domain from the whitelist: {domain}')
@cli_whitelist.command('import')
def cli_whitelist_import() -> None:
@click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist'
for domain in app.database.hostnames:
cli_whitelist_add.callback(domain)
with ctx.obj.database.connection() as conn:
for inbox in conn.execute('SELECT * FROM inboxes').all():
if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
continue
conn.put_domain_whitelist(inbox['domain'])
click.echo('Imported whitelist from inboxes')
def main() -> None:

View file

@ -1,32 +1,27 @@
from __future__ import annotations
import json
import os
import socket
import traceback
import typing
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse
from aiohttp.web import Response as AiohttpResponse
from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.message import Message as ApMessage
from functools import cached_property
from json.decoder import JSONDecodeError
from uuid import uuid4
from . import logger as logging
if typing.TYPE_CHECKING:
from typing import Any, Coroutine, Generator, Optional, Type
from aputils.signer import Signer
from .application import Application
from .config import RelayConfig
from .database import RelayDatabase
from .config import Config
from .database import Database
from .http_client import HttpClient
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = {
'activity': 'application/activity+json',
'html': 'text/html',
@ -77,91 +72,13 @@ def check_open_port(host: str, port: int) -> bool:
return False
class DotDict(dict):
def __init__(self, _data: dict[str, Any], **kwargs: Any):
dict.__init__(self)
def get_app() -> Application:
from .application import Application # pylint: disable=import-outside-toplevel
self.update(_data, **kwargs)
if not Application.DEFAULT:
raise ValueError('No default application set')
def __getattr__(self, key: str) -> str:
try:
return self[key]
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith('_'):
super().__setattr__(key, value)
else:
self[key] = value
def __setitem__(self, key: str, value: Any) -> None:
if type(value) is dict: # pylint: disable=unidiomatic-typecheck
value = DotDict(value)
super().__setitem__(key, value)
def __delattr__(self, key: str) -> None:
try:
dict.__delitem__(self, key)
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
@classmethod
def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]:
if not data:
raise JSONDecodeError('Empty body', data, 1)
try:
return cls(json.loads(data))
except ValueError:
raise JSONDecodeError('Invalid body', data, 1) from None
@classmethod
def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]:
data = cls({})
for chunk in sig.strip().split(','):
key, value = chunk.split('=', 1)
value = value.strip('\"')
if key == 'headers':
value = value.split()
data[key.lower()] = value
return data
def to_json(self, indent: Optional[int | str] = None) -> str:
return json.dumps(self, indent=indent)
def update(self, _data: dict[str, Any], **kwargs: Any) -> None:
if isinstance(_data, dict):
for key, value in _data.items():
self[key] = value
elif isinstance(_data, (list, tuple, set)):
for key, value in _data:
self[key] = value
for key, value in kwargs.items():
self[key] = value
return Application.DEFAULT
class Message(ApMessage):
@ -181,7 +98,7 @@ class Message(ApMessage):
'followers': f'https://{host}/followers',
'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox',
'url': f'https://{host}/inbox',
'url': f'https://{host}/',
'endpoints': {
'sharedInbox': f'https://{host}/inbox'
},
@ -310,16 +227,6 @@ class Response(AiohttpResponse):
class View(AbstractView):
def __init__(self, request: AiohttpRequest):
AbstractView.__init__(self, request)
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
def __await__(self) -> Generator[Response]:
method = self.request.method.upper()
@ -363,94 +270,10 @@ class View(AbstractView):
@property
def config(self) -> RelayConfig:
def config(self) -> Config:
return self.app.config
@property
def database(self) -> RelayDatabase:
def database(self) -> Database:
return self.app.database
# todo: move to views.ActorView
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
self.instance = self.database.get_inbox(self.actor.inbox)
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import tinysql
import typing
from cachetools import LRUCache
@ -8,7 +9,7 @@ from . import logger as logging
from .misc import Message
if typing.TYPE_CHECKING:
from .misc import View
from .views import ActorView
cache = LRUCache(1024)
@ -16,128 +17,141 @@ cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in a 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False
## make sure the actor is an application
# make sure the actor is an application
if actor.type != 'Application':
return True
return False
async def handle_relay(view: View) -> None:
async def handle_relay(view: ActorView) -> None:
if view.message.object_id in cache:
logging.verbose('already relayed %s', view.message.object_id)
return
message = Message.new_announce(view.config.host, view.message.object_id)
message = Message.new_announce(view.config.domain, view.message.object_id)
cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message)
inboxes = view.database.distill_inboxes(view.message)
for inbox in inboxes:
view.app.push_message(inbox, message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
async def handle_forward(view: View) -> None:
async def handle_forward(view: ActorView) -> None:
if view.message.id in cache:
logging.verbose('already forwarded %s', view.message.id)
return
message = Message.new_announce(view.config.host, view.message)
message = Message.new_announce(view.config.domain, view.message)
cache[view.message.id] = message.id
logging.debug('>> forward: %s', message)
inboxes = view.database.distill_inboxes(view.message)
for inbox in inboxes:
view.app.push_message(inbox, message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
async def handle_follow(view: View) -> None:
async def handle_follow(view: ActorView) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
## reject if software used by actor is banned
if view.config.is_banned_software(software):
with view.database.connection() as conn:
# reject if software used by actor is banned
if view.config.is_banned_software(software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id,
software
)
return
## reject if the actor is not an instance actor
if person_check(view.actor, software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
if conn.get_inbox(view.actor.shared_inbox):
data = {'followid': view.message.id}
statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
with conn.query(statement):
pass
else:
conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
view.actor.id,
view.message.id,
software
)
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
accept = True
)
)
return logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id,
software
)
## reject if the actor is not an instance actor
if person_check(view.actor, software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
actor = view.actor.id,
followid = view.message.id,
accept = False
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
view.app.push_message(
view.actor.shared_inbox,
Message.new_follow(
host = view.config.domain,
actor = view.actor.id
)
)
)
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
view.database.add_inbox(view.actor.shared_inbox, view.message.id, software)
view.database.save()
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
actor = view.actor.id,
followid = view.message.id,
accept = True
)
)
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
view.app.push_message(
view.actor.shared_inbox,
Message.new_follow(
host = view.config.host,
actor = view.actor.id
)
)
async def handle_undo(view: View) -> None:
async def handle_undo(view: ActorView) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
return await handle_forward(view)
if not view.database.del_inbox(view.actor.domain, view.message.object['id']):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
view.message.object['id']
)
await handle_forward(view)
return
view.database.save()
with view.database.connection() as conn:
if not conn.del_inbox(view.actor.inbox):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
view.message.object['id']
)
view.app.push_message(
view.actor.shared_inbox,
Message.new_unfollow(
host = view.config.host,
host = view.config.domain,
actor = view.actor.id,
follow = view.message
)
@ -154,7 +168,7 @@ processors = {
}
async def run_processor(view: View) -> None:
async def run_processor(view: ActorView) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@ -164,12 +178,21 @@ async def run_processor(view: View) -> None:
return
if view.instance and not view.instance.get('software'):
nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain'])
if view.instance:
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
if nodeinfo:
view.instance['software'] = nodeinfo.sw_name
view.database.save()
if not view.instance['actor']:
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
actor = view.actor.id
)
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view)

View file

@ -2,8 +2,11 @@ from __future__ import annotations
import asyncio
import subprocess
import traceback
import typing
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo
from pathlib import Path
@ -14,6 +17,7 @@ from .processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from aputils.signer import Signer
from typing import Callable
@ -71,12 +75,16 @@ def register_route(*paths: str) -> Callable:
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
text = HOME_TEMPLATE.format(
host = self.config.host,
note = self.config.note,
count = len(self.database.hostnames),
targets = '<br>'.join(self.database.hostnames)
)
with self.database.connection() as conn:
config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all()
text = HOME_TEMPLATE.format(
host = self.config.domain,
note = config['note'],
count = len(inboxes),
targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
)
return Response.new(text, ctype='html')
@ -84,44 +92,137 @@ class HomeView(View):
@register_route('/actor', '/inbox')
class ActorView(View):
def __init__(self, request: Request):
View.__init__(self, request)
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
async def get(self, request: Request) -> Response:
data = Message.new_actor(
host = self.config.host,
pubkey = self.database.signer.pubkey
host = self.config.domain,
pubkey = self.app.signer.pubkey
)
return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response:
response = await self.get_post_data()
if response is not None:
if (response := await self.get_post_data()):
return response
## reject if the actor isn't whitelisted while the whiltelist is enabled
if self.config.whitelist_enabled and not self.config.is_whitelisted(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
with self.database.connection() as conn:
self.instance = conn.get_inbox(self.actor.inbox)
config = conn.get_config_all()
## reject if actor is banned
if self.config.is_banned(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if the actor isn't whitelisted while the whiltelist is enabled
if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.database.get_inbox(self.actor.domain):
logging.verbose(
'Rejected actor for trying to post while not following: %s',
self.actor.id
)
## reject if actor is banned
if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
return Response.new_error(401, 'access denied', 'json')
## reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance:
logging.verbose(
'Rejected actor for trying to post while not following: %s',
self.actor.id
)
logging.debug('>> payload %s', self.message.to_json(4))
return Response.new_error(401, 'access denied', 'json')
asyncio.ensure_future(run_processor(self))
return Response.new(status = 202)
logging.debug('>> payload %s', self.message.to_json(4))
asyncio.ensure_future(run_processor(self))
return Response.new(status = 202)
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")
@register_route('/.well-known/webfinger')
@ -133,12 +234,12 @@ class WebfingerView(View):
except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json')
if subject != f'acct:relay@{self.config.host}':
if subject != f'acct:relay@{self.config.domain}':
return Response.new_error(404, 'user not found', 'json')
data = Webfinger.new(
handle = 'relay',
domain = self.config.host,
domain = self.config.domain,
actor = self.config.actor
)
@ -148,14 +249,17 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response:
data = {
'name': 'activityrelay',
'version': VERSION,
'protocols': ['activitypub'],
'open_regs': not self.config.whitelist_enabled,
'users': 1,
'metadata': {'peers': self.database.hostnames}
}
with self.database.connection() as conn:
inboxes = conn.execute('SELECT * FROM inboxes').all()
data = {
'name': 'activityrelay',
'version': VERSION,
'protocols': ['activitypub'],
'open_regs': not conn.get_config('whitelist-enabled'),
'users': 1,
'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
}
if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
@ -166,5 +270,5 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response:
data = WellKnownNodeinfo.new_template(self.config.host)
data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json')

View file

@ -3,3 +3,4 @@ aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0
click>=8.1.2
pyyaml>=6.0
tinysql@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.1.tar.gz