Merge branch 'sql' into 'master'

Draft: switch database backend to sql

See merge request pleroma/relay!53
This commit is contained in:
Izalia Mae 2024-01-25 10:58:28 +00:00
commit d42b939db9
26 changed files with 1836 additions and 992 deletions

10
.gitignore vendored
View file

@ -94,9 +94,7 @@ ENV/
# Rope project settings
.ropeproject
viera.yaml
viera.jsonld
# config file
relay.yaml
relay.jsonld
# config and database
*.yaml
*.jsonld
*.sqlite3

1
MANIFEST.in Normal file
View file

@ -0,0 +1 @@
include data/statements.sql

3
dev-requirements.txt Normal file
View file

@ -0,0 +1,3 @@
flake8 == 7.0.0
pyinstaller == 6.3.0
pylint == 3.0

View file

@ -3,11 +3,8 @@
There are a number of commands to manage your relay's database and config. You can add `--help` to
any category or command to get help on that specific option (ex. `activityrelay inbox --help`).
Note: Unless specified, it is recommended to run any commands while the relay is shutdown.
Note 2: `activityrelay` is only available via pip or pipx if `~/.local/bin` is in `$PATH`. If it
isn't, use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if installed
via pipx
Note: `activityrelay` is only available via pip or pipx if `~/.local/bin` is in `$PATH`. If not,
use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if installed via pipx.
## Run
@ -24,6 +21,22 @@ Run the setup wizard to configure your relay.
activityrelay setup
## Convert
Convert the old config and jsonld to the new config and SQL backend. If the old config filename is
not specified, the config will get backed up as `relay.backup.yaml` before converting.
activityrelay convert --old-config relaycfg.yaml
## Edit Config
Open the config file in a text editor. If an editor is not specified with `--editor`, the default
editor will be used.
activityrelay edit-config --editor micro
## Config
Manage the relay config
@ -120,7 +133,7 @@ Remove a domain from the whitelist.
### Import
Add all current inboxes to the whitelist
Add all current inboxes to the whitelist.
activityrelay whitelist import
@ -132,15 +145,15 @@ Manage the instance ban list.
### List
List the currently banned instances
List the currently banned instances.
activityrelay instance list
### Ban
Add an instance to the ban list. If the instance is currently subscribed, remove it from the
database.
Add an instance to the ban list. If the instance is currently subscribed, it will be removed from
the inbox list.
activityrelay instance ban <domain>
@ -152,10 +165,17 @@ Remove an instance from the ban list.
activityrelay instance unban <domain>
### Update
Update the ban reason or note for an instance ban.
activityrelay instance update bad.example.com --reason "the baddest reason"
## Software
Manage the software ban list. To get the correct name, check the software's nodeinfo endpoint.
You can find it at nodeinfo\['software']\['name'].
You can find it at `nodeinfo['software']['name']`.
### List
@ -186,4 +206,12 @@ name via nodeinfo.
If the name is `RELAYS` (case-sensitive), remove all known relay software names from the list.
activityrelay unban [-f/--fetch-nodeinfo] <name, domain, or RELAYS>
activityrelay software unban [-f/--fetch-nodeinfo] <name, domain, or RELAYS>
### Update
Update the ban reason or note for a software ban. Either `--reason` and/or `--note` must be
specified.
activityrelay software update relay.example.com --reason "begone relay"

View file

@ -2,41 +2,23 @@
## General
### DB
### Domain
The path to the database. It contains the relay actor private key and all subscribed
instances. If the path is not absolute, it is relative to the working directory.
Hostname the relay will be hosted on.
db: relay.jsonld
domain: relay.example.com
### Listener
The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc)
is running on the same host, it is recommended to change `listen` to `localhost`
is running on the same host, it is recommended to change `listen` to `localhost` if the reverse
proxy is on the same host.
listen: 0.0.0.0
port: 8080
### Note
A small blurb to describe your relay instance. This will show up on the relay's home page.
note: "Make a note about your instance here."
### Post Limit
The maximum number of messages to send out at once. For each incoming message, a message will be
sent out to every subscribed instance minus the instance which sent the message. This limit
is to prevent too many outgoing connections from being made, so adjust if necessary.
Note: If the `workers` option is set to anything above 0, this limit will be per worker.
push_limit: 512
### Push Workers
The relay can be configured to use threads to push messages out. For smaller relays, this isn't
@ -46,60 +28,59 @@ threads.
workers: 0
### JSON GET cache limit
### Database type
JSON objects (actors, nodeinfo, etc) will get cached when fetched. This will set the max number of
objects to keep in the cache.
SQL database backend to use. Valid values are `sqlite` or `postgres`.
json_cache: 1024
database_type: sqlite
## AP
### Sqlite File Path
Various ActivityPub-related settings
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
directory.
sqlite_path: relay.jsonld
## Postgresql
In order to use the Postgresql backend, the user and database need to be created first.
sudo -u postgres psql -c "CREATE USER activityrelay"
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
### Database Name
Name of the database to use.
name: activityrelay
### Host
The domain your relay will use to identify itself.
Hostname, IP address, or unix socket the server is hosted on.
host: relay.example.com
host: /var/run/postgresql
### Whitelist Enabled
### Port
If set to `true`, only instances in the whitelist can follow the relay. Any subscribed instances
not in the whitelist will be removed from the inbox list on startup.
Port number the server is listening on.
whitelist_enabled: false
port: 5432
### Whitelist
### Username
A list of domains of instances which are allowed to subscribe to your relay.
User to use when logging into the server.
whitelist:
- bad-instance.example.com
- another-bad-instance.example.com
user: null
### Blocked Instances
### Password
A list of instances which are unable to follow the instance. If a subscribed instance is added to
the block list, it will be removed from the inbox list on startup.
Password for the specified user.
blocked_instances:
- bad-instance.example.com
- another-bad-instance.example.com
### Blocked Software
A list of ActivityPub software which cannot follow your relay. This list is empty by default, but
setting this to the below list will block all other relays and prevent relay chains
blocked_software:
- activityrelay
- aoderelay
- social.seattle.wa.us-relay
- unciarelay
pass: null

View file

@ -3,7 +3,7 @@ Description=ActivityPub Relay
[Service]
WorkingDirectory=/home/relay/relay
ExecStart=/usr/bin/python3 -m relay
ExecStart=/usr/bin/python3 -m relay run
[Install]
WantedBy=multi-user.target

View file

@ -6,6 +6,23 @@ build-backend = 'setuptools.build_meta'
[tool.pylint.main]
jobs = 0
persistent = true
load-plugins = [
"pylint.extensions.code_style",
"pylint.extensions.comparison_placement",
"pylint.extensions.confusing_elif",
"pylint.extensions.for_any_all",
"pylint.extensions.consider_ternary_expression",
"pylint.extensions.bad_builtin",
"pylint.extensions.dict_init_mutate",
"pylint.extensions.check_elif",
"pylint.extensions.empty_comment",
"pylint.extensions.private_import",
"pylint.extensions.redefined_variable_type",
"pylint.extensions.no_self_use",
"pylint.extensions.overlapping_exceptions",
"pylint.extensions.set_membership",
"pylint.extensions.typing"
]
[tool.pylint.design]
@ -22,6 +39,7 @@ single-line-if-stmt = true
[tool.pylint.messages_control]
disable = [
"fixme",
"broad-exception-caught",
"cyclic-import",
"global-statement",
@ -31,7 +49,8 @@ disable = [
"too-many-public-methods",
"too-many-return-statements",
"wrong-import-order",
"wrong-import-position",
"missing-function-docstring",
"missing-class-docstring"
"missing-class-docstring",
"consider-using-namedtuple-or-dataclass",
"confusing-consecutive-elif"
]

View file

@ -8,7 +8,9 @@ a = Analysis(
['relay/__main__.py'],
pathex=[],
binaries=[],
datas=[],
datas=[
('relay/data', 'relay/data')
],
hiddenimports=[],
hookspath=[],
hooksconfig={},
@ -19,6 +21,7 @@ a = Analysis(
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(

View file

@ -1,43 +1,35 @@
# this is the path that the object graph will get dumped to (in JSON-LD format),
# you probably shouldn't change it, but you can if you want.
db: relay.jsonld
# [string] Domain the relay will be hosted on
domain: relay.example.com
# Listener
# [string] Address the relay will listen on
listen: 0.0.0.0
# [integer] Port the relay will listen on
port: 8080
# Note
note: "Make a note about your instance here."
# [integer] Number of push workers to start (will get removed in a future update)
workers: 8
# Number of worker threads to start. If 0, use asyncio futures instead of threads.
workers: 0
# [string] Database backend to use. Valid values: sqlite, postgres
database_type: sqlite
# Maximum number of inbox posts to do at once
# If workers is set to 1 or above, this is the max for each worker
push_limit: 512
# [string] Path to the sqlite database file if the sqlite backend is in use
sqlite_path: relay.sqlite3
# The amount of json objects to cache from GET requests
json_cache: 1024
# settings for the postgresql backend
postgres:
ap:
# This is used for generating activitypub messages, as well as instructions for
# linking AP identities. It should be an SSL-enabled domain reachable by https.
host: 'relay.example.com'
# [string] hostname or unix socket to connect to
host: /var/run/postgresql
blocked_instances:
- 'bad-instance.example.com'
- 'another-bad-instance.example.com'
# [integer] port of the server
port: 5432
whitelist_enabled: false
# [string] username to use when logging into the server (default is the current system username)
user: null
whitelist:
- 'good-instance.example.com'
- 'another.good-instance.example.com'
# [string] password of the user
pass: null
# uncomment the lines below to prevent certain activitypub software from posting
# to the relay (all known relays by default). this uses the software name in nodeinfo
#blocked_software:
#- 'activityrelay'
#- 'aoderelay'
#- 'social.seattle.wa.us-relay'
#- 'unciarelay'
# [string] name of the database to use
name: activityrelay

View file

@ -8,52 +8,41 @@ import traceback
import typing
from aiohttp import web
from aputils.signer import Signer
from datetime import datetime, timedelta
from . import logger as logging
from .config import RelayConfig
from .database import RelayDatabase
from .config import Config
from .database import get_database
from .http_client import HttpClient
from .misc import check_open_port
from .views import VIEWS
if typing.TYPE_CHECKING:
from tinysql import Database
from typing import Any
from .misc import Message
# pylint: disable=unsubscriptable-object
class Application(web.Application):
DEFAULT: Application = None
def __init__(self, cfgpath: str):
web.Application.__init__(self)
Application.DEFAULT = self
self['signer'] = None
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['workers'] = []
self['last_worker'] = 0
self['start_time'] = None
self['running'] = False
self['config'] = RelayConfig(cfgpath)
if not self.config.load():
self.config.save()
if self.config.is_docker:
self.config.update({
'db': '/data/relay.jsonld',
'listen': '0.0.0.0',
'port': 8080
})
self['database'] = RelayDatabase(self.config)
self.database.load()
self['client'] = HttpClient(
database = self.database,
limit = self.config.push_limit,
timeout = self.config.timeout,
cache_size = self.config.json_cache
)
for path, view in VIEWS:
self.router.add_view(path, view)
@ -65,15 +54,29 @@ class Application(web.Application):
@property
def config(self) -> RelayConfig:
def config(self) -> Config:
return self['config']
@property
def database(self) -> RelayDatabase:
def database(self) -> Database:
return self['database']
@property
def signer(self) -> Signer:
return self['signer']
@signer.setter
def signer(self, value: Signer | str) -> None:
if isinstance(value, Signer):
self['signer'] = value
return
self['signer'] = Signer(value, self.config.keyid)
@property
def uptime(self) -> timedelta:
if not self['start_time']:
@ -118,7 +121,7 @@ class Application(web.Application):
logging.info(
'Starting webserver at %s (%s:%i)',
self.config.host,
self.config.domain,
self.config.listen,
self.config.port
)
@ -179,12 +182,7 @@ class PushWorker(threading.Thread):
async def handle_queue(self) -> None:
self.client = HttpClient(
database = self.app.database,
limit = self.app.config.push_limit,
timeout = self.app.config.timeout,
cache_size = self.app.config.json_cache
)
self.client = HttpClient()
while self.app['running']:
try:

View file

@ -1,17 +1,129 @@
from __future__ import annotations
import json
import os
import typing
import yaml
from aputils.signer import Signer
from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse
from . import logger as logging
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from typing import Iterator, Optional
from .config import RelayConfig
from .misc import Message
from collections.abc import Iterator
from typing import Any
# pylint: disable=duplicate-code
class RelayConfig(dict):
def __init__(self, path: str):
dict.__init__(self, {})
if self.is_docker:
path = '/data/config.yaml'
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in {'blocked_instances', 'blocked_software', 'whitelist'}:
assert isinstance(value, (list, set, tuple))
elif key in {'port', 'workers', 'json_cache', 'timeout'}:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
@property
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self["host"]}/actor'
@property
def inbox(self) -> str:
return f'https://{self["host"]}/inbox'
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def load(self) -> None:
self.reset()
options = {}
try:
options['Loader'] = yaml.FullLoader
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
return
if not config:
return
for key, value in config.items():
if key == 'ap':
for k, v in value.items():
if k not in self:
continue
self[k] = v
continue
if key not in self:
continue
self[key] = value
class RelayDatabase(dict):
@ -37,9 +149,7 @@ class RelayDatabase(dict):
return tuple(data['inbox'] for data in self['relay-list'].values())
def load(self) -> bool:
new_db = True
def load(self) -> None:
try:
with self.config.db.open() as fd:
data = json.load(fd)
@ -65,17 +175,9 @@ class RelayDatabase(dict):
self['relay-list'] = data.get('relay-list', {})
for domain, instance in self['relay-list'].items():
if self.config.is_banned(domain) or \
(self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
self.del_inbox(domain)
continue
if not instance.get('domain'):
instance['domain'] = domain
new_db = False
except FileNotFoundError:
pass
@ -83,24 +185,13 @@ class RelayDatabase(dict):
if self.config.db.stat().st_size > 0:
raise e from None
if not self['private-key']:
logging.info('No actor keys present, generating 4096-bit RSA keypair.')
self.signer = Signer.new(self.config.keyid, size=4096)
self['private-key'] = self.signer.export()
else:
self.signer = Signer(self['private-key'], self.config.keyid)
self.save()
return not new_db
def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd:
json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None:
def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
@ -115,14 +206,13 @@ class RelayDatabase(dict):
def add_inbox(self,
inbox: str,
followid: Optional[str] = None,
software: Optional[str] = None) -> dict[str, str]:
followid: str | None = None,
software: str | None = None) -> dict[str, str]:
assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname
instance = self.get_inbox(domain)
if instance:
if (instance := self.get_inbox(domain)):
if followid:
instance['followid'] = followid
@ -144,12 +234,10 @@ class RelayDatabase(dict):
def del_inbox(self,
domain: str,
followid: Optional[str] = None,
fail: Optional[bool] = False) -> bool:
followid: str = None,
fail: bool = False) -> bool:
data = self.get_inbox(domain, fail=False)
if not data:
if not (data := self.get_inbox(domain, fail=False)):
if fail:
raise KeyError(domain)

View file

@ -1,76 +1,76 @@
from __future__ import annotations
import getpass
import os
import typing
import yaml
from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse
from .misc import DotDict, boolean
from .misc import IS_DOCKER
if typing.TYPE_CHECKING:
from typing import Any
from .database import RelayDatabase
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
]
DEFAULTS: dict[str, Any] = {
'listen': '0.0.0.0',
'port': 8080,
'domain': 'relay.example.com',
'workers': len(os.sched_getaffinity(0)),
'db_type': 'sqlite',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay'
}
APKEYS = [
'host',
'whitelist_enabled',
'blocked_software',
'blocked_instances',
'whitelist'
]
if IS_DOCKER:
DEFAULTS['sq_path'] = '/data/relay.jsonld'
class RelayConfig(DotDict):
__slots__ = ('path', )
class Config:
def __init__(self, path: str, load: bool = False):
self.path = Path(path).expanduser().resolve()
def __init__(self, path: str | Path):
DotDict.__init__(self, {})
self.listen = None
self.port = None
self.domain = None
self.workers = None
self.db_type = None
self.sq_path = None
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
if self.is_docker:
path = '/data/config.yaml'
if load:
try:
self.load()
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
except FileNotFoundError:
self.save()
@property
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
def sqlite_path(self) -> Path:
if not os.path.isabs(self.sq_path):
return self.path.parent.joinpath(self.sq_path).resolve()
return Path(self.sq_path).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self.host}/actor'
return f'https://{self.domain}/actor'
@property
def inbox(self) -> str:
return f'https://{self.host}/inbox'
return f'https://{self.domain}/inbox'
@property
@ -78,115 +78,7 @@ class RelayConfig(DotDict):
return f'{self.actor}#main-key'
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def ban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_banned(instance):
return False
self.blocked_instances.append(instance)
return True
def unban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.blocked_instances.remove(instance)
return True
except ValueError:
return False
def ban_software(self, software: str) -> bool:
if self.is_banned_software(software):
return False
self.blocked_software.append(software)
return True
def unban_software(self, software: str) -> bool:
try:
self.blocked_software.remove(software)
return True
except ValueError:
return False
def add_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_whitelisted(instance):
return False
self.whitelist.append(instance)
return True
def del_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.whitelist.remove(instance)
return True
except ValueError:
return False
def is_banned(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.blocked_instances
def is_banned_software(self, software: str) -> bool:
if not software:
return False
return software.lower() in self.blocked_software
def is_whitelisted(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.whitelist
def load(self) -> bool:
def load(self) -> None:
self.reset()
options = {}
@ -197,50 +89,69 @@ class RelayConfig(DotDict):
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
return False
pgcfg = config.get('postgresql', {})
if not config:
return False
raise ValueError('Config is empty')
for key, value in config.items():
if key in ['ap']:
for k, v in value.items():
if k not in self:
if IS_DOCKER:
self.listen = '0.0.0.0'
self.port = 8080
self.sq_path = '/data/relay.jsonld'
else:
self.set('listen', config.get('listen', DEFAULTS['listen']))
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
for key in DEFAULTS:
if not key.startswith('pg'):
continue
self[k] = v
try:
self.set(key, pgcfg[key[3:]])
except KeyError:
continue
if key not in self:
continue
self[key] = value
if self.host.endswith('example.com'):
return False
return True
def reset(self) -> None:
for key, value in DEFAULTS.items():
setattr(self, key, value)
def save(self) -> None:
self.path.parent.mkdir(exist_ok = True, parents = True)
config = {
# just turning config.db into a string is good enough for now
'db': str(self.db),
'listen': self.listen,
'port': self.port,
'note': self.note,
'push_limit': self.push_limit,
'workers': self.workers,
'json_cache': self.json_cache,
'timeout': self.timeout,
'ap': {key: self[key] for key in APKEYS}
'domain': self.domain,
'database_type': self.db_type,
'sqlite_path': self.sq_path,
'postgres': {
'host': self.pg_host,
'port': self.pg_port,
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
}
}
with self._path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys=False)
with self.path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys = False)
def set(self, key: str, value: Any) -> None:
if key not in DEFAULTS:
raise KeyError(key)
if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
value = int(value)
setattr(self, key, value)

79
relay/data/statements.sql Normal file
View file

@ -0,0 +1,79 @@
-- name: get-config
SELECT * FROM config WHERE key = :key
-- name: get-config-all
SELECT * FROM config
-- name: put-config
INSERT INTO config (key, value, type)
VALUES (:key, :value, :type)
ON CONFLICT (key) DO UPDATE SET value = :value
RETURNING *
-- name: del-config
DELETE FROM config
WHERE key = :key
-- name: get-inbox
SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value
-- name: put-inbox
INSERT INTO inboxes (domain, actor, inbox, followid, software, created)
VALUES (:domain, :actor, :inbox, :followid, :software, :created)
ON CONFLICT (domain) DO UPDATE SET followid = :followid
RETURNING *
-- name: del-inbox
DELETE FROM inboxes
WHERE domain = :value or inbox = :value or actor = :value
-- name: get-software-ban
SELECT * FROM software_bans WHERE name = :name
-- name: put-software-ban
INSERT INTO software_bans (name, reason, note, created)
VALUES (:name, :reason, :note, :created)
RETURNING *
-- name: del-software-ban
DELETE FROM software_bans
WHERE name = :name
-- name: get-domain-ban
SELECT * FROM domain_bans WHERE domain = :domain
-- name: put-domain-ban
INSERT INTO domain_bans (domain, reason, note, created)
VALUES (:domain, :reason, :note, :created)
RETURNING *
-- name: del-domain-ban
DELETE FROM domain_bans
WHERE domain = :domain
-- name: get-domain-whitelist
SELECT * FROM whitelist WHERE domain = :domain
-- name: put-domain-whitelist
INSERT INTO whitelist (domain, created)
VALUES (:domain, :created)
RETURNING *
-- name: del-domain-whitelist
DELETE FROM whitelist
WHERE domain = :domain

View file

@ -0,0 +1,65 @@
from __future__ import annotations
import tinysql
import typing
from importlib.resources import files as pkgfiles
from .config import get_default_value
from .connection import Connection
from .schema import VERSIONS, migrate_0
from .. import logger as logging
if typing.TYPE_CHECKING:
from .config import Config
def get_database(config: Config, migrate: bool = True) -> tinysql.Database:
if config.db_type == "sqlite":
db = tinysql.Database.sqlite(
config.sqlite_path,
connection_class = Connection,
min_connections = 2,
max_connections = 10
)
elif config.db_type == "postgres":
db = tinysql.Database.postgres(
config.pg_name,
config.pg_host,
config.pg_port,
config.pg_user,
config.pg_pass,
connection_class = Connection
)
db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql"))
if not migrate:
return db
with db.connection() as conn:
if 'config' not in conn.get_tables():
logging.info("Creating database tables")
migrate_0(conn)
return db
if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS:
if schema_ver < ver:
conn.begin()
func(conn)
conn.put_config('schema-version', ver)
conn.commit()
if (privkey := conn.get_config('private-key')):
conn.app.signer = privkey
logging.set_level(conn.get_config('log-level'))
return db

45
relay/database/config.py Normal file
View file

@ -0,0 +1,45 @@
from __future__ import annotations
import typing
from .. import logger as logging
from ..misc import boolean
if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
'schema-version': ('int', 20240119),
'log-level': ('loglevel', logging.LogLevel.INFO),
'note': ('str', 'Make a note about your instance here.'),
'private-key': ('str', None),
'whitelist-enabled': ('bool', False)
}
# serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'loglevel': (lambda x: x.name, logging.LogLevel.parse)
}
def get_default_value(key: str) -> Any:
return CONFIG_DEFAULTS[key][1]
def get_default_type(key: str) -> str:
return CONFIG_DEFAULTS[key][0]
def serialize(key: str, value: Any) -> str:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][0](value)
def deserialize(key: str, value: str) -> Any:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][1](value)

View file

@ -0,0 +1,296 @@
from __future__ import annotations
import tinysql
import typing
from datetime import datetime, timezone
from urllib.parse import urlparse
from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize
from .. import logger as logging
from ..misc import get_app
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from tinysql import Cursor, Row
from typing import Any
from .application import Application
from ..misc import Message
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'activity-relay', # https://github.com/yukimochi/Activity-Relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
]
class Connection(tinysql.Connection):
@property
def app(self) -> Application:
return get_app()
def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for inbox in self.execute('SELECT * FROM inboxes'):
if inbox['domain'] not in src_domains:
yield inbox['inbox']
def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params)
def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
with self.exec_statement('get-config', {'key': key}) as cur:
if not (row := cur.one()):
return get_default_value(key)
if row['value']:
return deserialize(row['key'], row['value'])
return None
def get_config_all(self) -> dict[str, Any]:
with self.exec_statement('get-config-all') as cur:
db_config = {row['key']: row['value'] for row in cur}
config = {}
for key, data in CONFIG_DEFAULTS.items():
try:
config[key] = deserialize(key, db_config[key])
except KeyError:
if key == 'schema-version':
config[key] = 0
else:
config[key] = data[1]
return config
def put_config(self, key: str, value: Any) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
if key == 'private-key':
self.app.signer = value
elif key == 'log-level':
value = logging.LogLevel.parse(value)
logging.set_level(value)
params = {
'key': key,
'value': serialize(key, value) if value is not None else None,
'type': get_default_type(key)
}
with self.exec_statement('put-config', params):
return value
def get_inbox(self, value: str) -> Row:
with self.exec_statement('get-inbox', {'value': value}) as cur:
return cur.one()
def put_inbox(self,
domain: str,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
params = {
'domain': domain,
'inbox': inbox,
'actor': actor,
'followid': followid,
'software': software,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-inbox', params) as cur:
return cur.one()
def update_inbox(self,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"')
data = {}
if actor:
data['actor'] = actor
if followid:
data['followid'] = followid
if software:
data['software'] = software
statement = tinysql.Update('inboxes', data, inbox = inbox)
with self.query(statement):
return self.get_inbox(inbox)
def del_inbox(self, value: str) -> bool:
with self.exec_statement('del-inbox', {'value': value}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_ban(self, domain: str) -> Row:
if domain.startswith('http'):
domain = urlparse(domain).netloc
with self.exec_statement('get-domain-ban', {'domain': domain}) as cur:
return cur.one()
def put_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> Row:
params = {
'domain': domain,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-ban', params) as cur:
return cur.one()
def update_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('domain_bans', params, domain = domain)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_domain_ban(domain)
def del_domain_ban(self, domain: str) -> bool:
with self.exec_statement('del-domain-ban', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_software_ban(self, name: str) -> Row:
with self.exec_statement('get-software-ban', {'name': name}) as cur:
return cur.one()
def put_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> Row:
params = {
'name': name,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-software-ban', params) as cur:
return cur.one()
def update_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('software_bans', params, name = name)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_software_ban(name)
def del_software_ban(self, name: str) -> bool:
with self.exec_statement('del-software-ban', {'name': name}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_whitelist(self, domain: str) -> Row:
with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one()
def put_domain_whitelist(self, domain: str) -> Row:
params = {
'domain': domain,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-whitelist', params) as cur:
return cur.one()
def del_domain_whitelist(self, domain: str) -> bool:
with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1

60
relay/database/schema.py Normal file
View file

@ -0,0 +1,60 @@
from __future__ import annotations
import typing
from tinysql import Column, Connection, Table
from .config import get_default_value
if typing.TYPE_CHECKING:
from collections.abc import Callable
VERSIONS: list[Callable] = []
TABLES: list[Table] = [
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False),
Column('value', 'text'),
Column('type', 'text', default = 'str')
),
Table(
'inboxes',
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True),
Column('inbox', 'text', unique = True, nullable = False),
Column('followid', 'text'),
Column('software', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'whitelist',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('created', 'timestamp')
),
Table(
'domain_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'software_bans',
Column('name', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
)
]
def version(func: Callable) -> Callable:
ver = int(func.replace('migrate_', ''))
VERSIONS[ver] = func
return func
def migrate_0(conn: Connection) -> None:
conn.create_tables(TABLES)
conn.put_config('schema-version', get_default_value('schema-version'))

View file

@ -13,11 +13,10 @@ from urllib.parse import urlparse
from . import __version__
from . import logger as logging
from .misc import MIMETYPES, Message
from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from typing import Any, Callable, Optional
from .database import RelayDatabase
from typing import Any
HEADERS = {
@ -27,13 +26,7 @@ HEADERS = {
class HttpClient:
def __init__(self,
database: RelayDatabase,
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
self.database = database
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
self.cache = LRUCache(cache_size)
self.limit = limit
self.timeout = timeout
@ -80,9 +73,9 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches
url: str,
sign_headers: Optional[bool] = False,
loads: Optional[Callable] = None,
force: Optional[bool] = False) -> Message | dict | None:
sign_headers: bool = False,
loads: callable | None = None,
force: bool = False) -> Message | dict | None:
await self.open()
@ -98,7 +91,7 @@ class HttpClient:
headers = {}
if sign_headers:
headers.update(self.database.signer.sign_headers('GET', url, algorithm='original'))
get_app().signer.sign_headers('GET', url, algorithm = 'original')
try:
logging.debug('Fetching resource: %s', url)
@ -150,23 +143,27 @@ class HttpClient:
async def post(self, url: str, message: Message) -> None:
await self.open()
instance = self.database.get_inbox(url)
# todo: cache inboxes to avoid opening a db connection
with get_app().database.connection() as conn:
instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now
if instance and instance.get('software') in {'mastodon'}:
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019'
else:
algorithm = 'original'
# pylint: enable=consider-ternary-expression
headers = {'Content-Type': 'application/activity+json'}
headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm))
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
try:
logging.verbose('Sending "%s" to %s', message.type, url)
async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
## Not expecting a response, so just return
# Not expecting a response, so just return
if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url)
return
@ -181,7 +178,7 @@ class HttpClient:
except (AsyncTimeoutError, ClientConnectionError):
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
## prevent workers from being brought down
# prevent workers from being brought down
except Exception:
traceback.print_exc()
@ -197,7 +194,7 @@ class HttpClient:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None
for version in ['20', '21']:
for version in ('20', '21'):
try:
nodeinfo_url = wk_nodeinfo.get_url(version)
@ -211,16 +208,16 @@ class HttpClient:
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient(database) as client:
async def get(*args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient() as client:
return await client.get(*args, **kwargs)
async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None:
async with HttpClient(database) as client:
async def post(*args: Any, **kwargs: Any) -> None:
async with HttpClient() as client:
return await client.post(*args, **kwargs)
async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient(database) as client:
async def fetch_nodeinfo(*args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient() as client:
return await client.fetch_nodeinfo(*args, **kwargs)

View file

@ -4,20 +4,63 @@ import logging
import os
import typing
from enum import IntEnum
from pathlib import Path
if typing.TYPE_CHECKING:
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
LOG_LEVELS: dict[str, int] = {
'DEBUG': logging.DEBUG,
'VERBOSE': 15,
'INFO': logging.INFO,
'WARNING': logging.WARNING,
'ERROR': logging.ERROR,
'CRITICAL': logging.CRITICAL
}
class LogLevel(IntEnum):
DEBUG = logging.DEBUG
VERBOSE = 15
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR
CRITICAL = logging.CRITICAL
def __str__(self) -> str:
return self.name
@classmethod
def parse(cls: type[IntEnum], data: object) -> IntEnum:
if isinstance(data, cls):
return data
if isinstance(data, str):
data = data.upper()
try:
return cls[data]
except KeyError:
pass
try:
return cls(data)
except ValueError:
pass
raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}')
def get_level() -> LogLevel:
return LogLevel.parse(logging.root.level)
def set_level(level: LogLevel | str) -> None:
logging.root.setLevel(LogLevel.parse(level))
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LogLevel['VERBOSE']):
return
logging.log(LogLevel['VERBOSE'], message, *args, **kwargs)
debug: Callable = logging.debug
@ -27,14 +70,7 @@ error: Callable = logging.error
critical: Callable = logging.critical
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']):
return
logging.log(LOG_LEVELS['VERBOSE'], message, *args, **kwargs)
logging.addLevelName(LOG_LEVELS['VERBOSE'], 'VERBOSE')
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
try:
@ -45,11 +81,11 @@ except KeyError:
try:
log_level = LOG_LEVELS[env_log_level]
log_level = LogLevel[env_log_level]
except KeyError:
logging.warning('Invalid log level: %s', env_log_level)
log_level = logging.INFO
print('Invalid log level:', env_log_level)
log_level = LogLevel['INFO']
handlers = [logging.StreamHandler()]

View file

@ -6,22 +6,49 @@ import click
import platform
import typing
from aputils.signer import Signer
from pathlib import Path
from shutil import copyfile
from urllib.parse import urlparse
from . import misc, __version__
from . import __version__
from . import http_client as http
from . import logger as logging
from .application import Application
from .config import RELAY_SOFTWARE
from .compat import RelayConfig, RelayDatabase
from .database import get_database
from .database.connection import RELAY_SOFTWARE
from .misc import IS_DOCKER, Message, check_open_port
if typing.TYPE_CHECKING:
from tinysql import Row
from typing import Any
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
app = None
CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
CONFIG_IGNORE = (
'schema-version',
'private-key'
)
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@ -29,11 +56,10 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context
def cli(ctx: click.Context, config: str) -> None:
global app
app = Application(config)
ctx.obj = Application(config)
if not ctx.invoked_subcommand:
if app.config.host.endswith('example.com'):
if ctx.obj.config.domain.endswith('example.com'):
cli_setup.callback()
else:
@ -41,46 +67,92 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup')
def cli_setup() -> None:
'Generate a new config'
@click.pass_context
def cli_setup(ctx: click.Context) -> None:
'Generate a new config and create the database'
while True:
app.config.host = click.prompt(
ctx.obj.config.domain = click.prompt(
'What domain will the relay be hosted on?',
default = app.config.host
default = ctx.obj.config.domain
)
if not app.config.host.endswith('example.com'):
if not ctx.obj.config.domain.endswith('example.com'):
break
click.echo('The domain must not be example.com')
click.echo('The domain must not end with "example.com"')
if not app.config.is_docker:
app.config.listen = click.prompt(
if not IS_DOCKER:
ctx.obj.config.listen = click.prompt(
'Which address should the relay listen on?',
default = app.config.listen
default = ctx.obj.config.listen
)
while True:
app.config.port = click.prompt(
ctx.obj.config.port = click.prompt(
'What TCP port should the relay listen on?',
default = app.config.port,
default = ctx.obj.config.port,
type = int
)
break
ctx.obj.config.db_type = click.prompt(
'Which database backend will be used?',
default = ctx.obj.config.db_type,
type = click.Choice(['postgres', 'sqlite'], case_sensitive = False)
)
app.config.save()
if ctx.obj.config.db_type == 'sqlite':
ctx.obj.config.sq_path = click.prompt(
'Where should the database be stored?',
default = ctx.obj.config.sq_path
)
if not app.config.is_docker and click.confirm('Relay all setup! Would you like to run it now?'):
elif ctx.obj.config.db_type == 'postgres':
ctx.obj.config.pg_name = click.prompt(
'What is the name of the database?',
default = ctx.obj.config.pg_name
)
ctx.obj.config.pg_host = click.prompt(
'What IP address or hostname does the server listen on?',
default = ctx.obj.config.pg_host
)
ctx.obj.config.pg_port = click.prompt(
'What port does the server listen on?',
default = ctx.obj.config.pg_port,
type = int
)
ctx.obj.config.pg_user = click.prompt(
'Which user will authenticate with the server?',
default = ctx.obj.config.pg_user
)
ctx.obj.config.pg_pass = click.prompt(
'User password: ',
hide_input = True
) or None
ctx.obj.config.save()
config = {
'private-key': Signer.new('n/a').export()
}
with ctx.obj.database.connection() as conn:
for key, value in config.items():
conn.put_config(key, value)
if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback()
@cli.command('run')
def cli_run() -> None:
@click.pass_context
def cli_run(ctx: click.Context) -> None:
'Run the relay'
if app.config.host.endswith('example.com'):
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
)
@ -104,25 +176,128 @@ def cli_run() -> None:
click.echo(pip_command)
return
if not misc.check_open_port(app.config.listen, app.config.port):
click.echo(f'Error: A server is already running on port {app.config.port}')
if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port):
click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
return
app.run()
ctx.obj.run()
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@click.pass_context
def cli_convert(ctx: click.Context, old_config: str) -> None:
'Convert an old config and jsonld database to the new format.'
old_config = Path(old_config).expanduser().resolve() if old_config else ctx.obj.config.path
backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml')
if str(old_config) == str(ctx.obj.config.path) and not backup.exists():
logging.info('Created backup config @ %s', backup)
copyfile(ctx.obj.config.path, backup)
config = RelayConfig(old_config)
config.load()
database = RelayDatabase(config)
database.load()
ctx.obj.config.set('listen', config['listen'])
ctx.obj.config.set('port', config['port'])
ctx.obj.config.set('workers', config['workers'])
ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3'))
ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save()
with get_database(ctx.obj.config) as db:
with db.connection() as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar(
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay'
elif inbox['software'] == 'mastodon':
actor = f'https://{inbox["domain"]}/actor'
else:
actor = None
conn.put_inbox(
inbox['domain'],
inbox['inbox'],
actor = actor,
followid = inbox['followid'],
software = inbox['software']
)
with click.progressbar(
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
) as banned_software:
for software in banned_software:
conn.put_software_ban(
software,
reason = 'relay' if software in RELAY_SOFTWARE else None
)
with click.progressbar(
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
) as banned_software:
for domain in banned_software:
conn.put_domain_ban(domain)
with click.progressbar(
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
) as whitelist:
for instance in whitelist:
conn.put_domain_whitelist(instance)
click.echo('Finished converting old config and database :3')
@cli.command('edit-config')
@click.option('--editor', '-e', help = 'Text editor to use')
@click.pass_context
def cli_editconfig(ctx: click.Context, editor: str) -> None:
'Edit the config file'
click.edit(
editor = editor,
filename = str(ctx.obj.config.path)
)
@cli.group('config')
def cli_config() -> None:
'Manage the relay config'
'Manage the relay settings stored in the database'
@cli_config.command('list')
def cli_config_list() -> None:
@click.pass_context
def cli_config_list(ctx: click.Context) -> None:
'List the current relay config'
click.echo('Relay Config:')
for key, value in app.config.items():
with ctx.obj.database.connection() as conn:
for key, value in conn.get_config_all().items():
if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20)
click.echo(f'- {key} {value}')
@ -131,13 +306,14 @@ def cli_config_list() -> None:
@cli_config.command('set')
@click.argument('key')
@click.argument('value')
def cli_config_set(key: str, value: Any) -> None:
@click.pass_context
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value'
app.config[key] = value
app.config.save()
with ctx.obj.database.connection() as conn:
new_value = conn.put_config(key, value)
print(f'{key}: {app.config[key]}')
print(f'{key}: {repr(new_value)}')
@cli.group('inbox')
@ -146,128 +322,151 @@ def cli_inbox() -> None:
@cli_inbox.command('list')
def cli_inbox_list() -> None:
@click.pass_context
def cli_inbox_list(ctx: click.Context) -> None:
'List the connected instances or relays'
click.echo('Connected to the following instances or relays:')
for inbox in app.database.inboxes:
click.echo(f'- {inbox}')
with ctx.obj.database.connection() as conn:
for inbox in conn.execute('SELECT * FROM inboxes'):
click.echo(f'- {inbox["inbox"]}')
@cli_inbox.command('follow')
@click.argument('actor')
def cli_inbox_follow(actor: str) -> None:
@click.pass_context
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)'
if app.config.is_banned(actor):
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
if not actor.startswith('http'):
domain = actor
actor = f'https://{actor}/actor'
else:
domain = urlparse(actor).hostname
try:
inbox_data = app.database['relay-list'][domain]
if (inbox_data := conn.get_inbox(actor)):
inbox = inbox_data['inbox']
except KeyError:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
else:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
if not actor_data:
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
click.echo(f'Failed to fetch actor: {actor}')
return
inbox = actor_data.shared_inbox
message = misc.Message.new_follow(
host = app.config.host,
message = Message.new_follow(
host = ctx.obj.config.domain,
actor = actor
)
asyncio.run(http.post(app.database, inbox, message))
asyncio.run(http.post(inbox, message))
click.echo(f'Sent follow message to actor: {actor}')
@cli_inbox.command('unfollow')
@click.argument('actor')
def cli_inbox_unfollow(actor: str) -> None:
@click.pass_context
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)'
if not actor.startswith('http'):
domain = actor
actor = f'https://{actor}/actor'
inbox_data: Row = None
else:
domain = urlparse(actor).hostname
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
try:
inbox_data = app.database['relay-list'][domain]
if (inbox_data := conn.get_inbox(actor)):
inbox = inbox_data['inbox']
message = misc.Message.new_unfollow(
host = app.config.host,
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
follow = inbox_data['followid']
)
except KeyError:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
else:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
actor_data = asyncio.run(http.get(actor, sign_headers = True))
inbox = actor_data.shared_inbox
message = misc.Message.new_unfollow(
host = app.config.host,
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
follow = {
'type': 'Follow',
'object': actor,
'actor': f'https://{app.config.host}/actor'
'actor': f'https://{ctx.obj.config.domain}/actor'
}
)
asyncio.run(http.post(app.database, inbox, message))
asyncio.run(http.post(inbox, message))
click.echo(f'Sent unfollow message to: {actor}')
@cli_inbox.command('add')
@click.argument('inbox')
def cli_inbox_add(inbox: str) -> None:
@click.option('--actor', '-a', help = 'Actor url for the inbox')
@click.option('--followid', '-f', help = 'Url for the follow activity')
@click.option('--software', '-s',
type = click.Choice(SOFTWARE),
help = 'Nodeinfo software name of the instance'
) # noqa: E124
@click.pass_context
def cli_inbox_add(
ctx: click.Context,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> None:
'Add an inbox to the database'
if not inbox.startswith('http'):
domain = inbox
inbox = f'https://{inbox}/inbox'
if app.config.is_banned(inbox):
click.echo(f'Error: Refusing to add banned inbox: {inbox}')
else:
domain = urlparse(inbox).netloc
if not software:
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(domain))):
software = nodeinfo.sw_name
if not actor and software:
try:
actor = ACTOR_FORMATS[software].format(domain = domain)
except KeyError:
pass
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Refusing to add banned inbox: {inbox}')
return
if app.database.get_inbox(inbox):
if conn.get_inbox(inbox):
click.echo(f'Error: Inbox already in database: {inbox}')
return
app.database.add_inbox(inbox)
app.database.save()
conn.put_inbox(domain, inbox, actor, followid, software)
click.echo(f'Added inbox to the database: {inbox}')
@cli_inbox.command('remove')
@click.argument('inbox')
def cli_inbox_remove(inbox: str) -> None:
@click.pass_context
def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
'Remove an inbox from the database'
try:
dbinbox = app.database.get_inbox(inbox, fail=True)
except KeyError:
click.echo(f'Error: Inbox does not exist: {inbox}')
with ctx.obj.database.connection() as conn:
if not conn.del_inbox(inbox):
click.echo(f'Inbox not in database: {inbox}')
return
app.database.del_inbox(dbinbox['domain'])
app.database.save()
click.echo(f'Removed inbox from the database: {inbox}')
@ -277,47 +476,76 @@ def cli_instance() -> None:
@cli_instance.command('list')
def cli_instance_list() -> None:
@click.pass_context
def cli_instance_list(ctx: click.Context) -> None:
'List all banned instances'
click.echo('Banned instances or relays:')
click.echo('Banned domains:')
for domain in app.config.blocked_instances:
click.echo(f'- {domain}')
with ctx.obj.database.connection() as conn:
for instance in conn.execute('SELECT * FROM domain_bans'):
if instance['reason']:
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
else:
click.echo(f'- {instance["domain"]}')
@cli_instance.command('ban')
@click.argument('target')
def cli_instance_ban(target: str) -> None:
@click.argument('domain')
@click.option('--reason', '-r', help = 'Public note about why the domain is banned')
@click.option('--note', '-n', help = 'Internal note that will only be seen by admins and mods')
@click.pass_context
def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Ban an instance and remove the associated inbox if it exists'
if target.startswith('http'):
target = urlparse(target).hostname
if app.config.ban_instance(target):
app.config.save()
if app.database.del_inbox(target):
app.database.save()
click.echo(f'Banned instance: {target}')
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Domain already banned: {domain}')
return
click.echo(f'Instance already banned: {target}')
conn.put_domain_ban(domain, reason, note)
conn.del_inbox(domain)
click.echo(f'Banned instance: {domain}')
@cli_instance.command('unban')
@click.argument('target')
def cli_instance_unban(target: str) -> None:
@click.argument('domain')
@click.pass_context
def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance'
if app.config.unban_instance(target):
app.config.save()
click.echo(f'Unbanned instance: {target}')
with ctx.obj.database.connection() as conn:
if not conn.del_domain_ban(domain):
click.echo(f'Instance wasn\'t banned: {domain}')
return
click.echo(f'Instance wasn\'t banned: {target}')
click.echo(f'Unbanned instance: {domain}')
@cli_instance.command('update')
@click.argument('domain')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a domain ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_domain_ban(domain, reason, note)):
click.echo(f'Failed to update domain ban: {domain}')
return
click.echo(f'Updated domain ban: {domain}')
if row['reason']:
click.echo(f'- {row["domain"]} ({row["reason"]})')
else:
click.echo(f'- {row["domain"]}')
@cli.group('software')
@ -326,79 +554,127 @@ def cli_software() -> None:
@cli_software.command('list')
def cli_software_list() -> None:
@click.pass_context
def cli_software_list(ctx: click.Context) -> None:
'List all banned software'
click.echo('Banned software:')
for software in app.config.blocked_software:
click.echo(f'- {software}')
with ctx.obj.database.connection() as conn:
for software in conn.execute('SELECT * FROM software_bans'):
if software['reason']:
click.echo(f'- {software["name"]} ({software["reason"]})')
else:
click.echo(f'- {software["name"]}')
@cli_software.command('ban')
@click.option(
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.argument('name')
def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None:
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option(
'--fetch-nodeinfo', '-f',
is_flag = True,
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
)
@click.pass_context
def cli_software_ban(ctx: click.Context,
name: str,
reason: str,
note: str,
fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays'
with ctx.obj.database.connection() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
app.config.ban_software(software)
if conn.get_software_ban(software):
click.echo(f'Relay already banned: {software}')
continue
conn.put_software_ban(software, reason or 'relay', note)
app.config.save()
click.echo('Banned all relay software')
return
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
if not nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))):
click.echo(f'Failed to fetch software name from domain: {name}')
return
name = nodeinfo.sw_name
if app.config.ban_software(name):
app.config.save()
click.echo(f'Banned software: {name}')
if conn.get_software_ban(name):
click.echo(f'Software already banned: {name}')
return
click.echo(f'Software already banned: {name}')
if not conn.put_software_ban(name, reason, note):
click.echo(f'Failed to ban software: {name}')
return
click.echo(f'Banned software: {name}')
@cli_software.command('unban')
@click.option(
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.argument('name')
def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None:
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option(
'--fetch-nodeinfo', '-f',
is_flag = True,
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
)
@click.pass_context
def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays'
with ctx.obj.database.connection() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
app.config.unban_software(software)
if not conn.del_software_ban(software):
click.echo(f'Relay was not banned: {software}')
app.config.save()
click.echo('Unbanned all relay software')
return
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
if not nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))):
click.echo(f'Failed to fetch software name from domain: {name}')
return
name = nodeinfo.sw_name
if app.config.unban_software(name):
app.config.save()
click.echo(f'Unbanned software: {name}')
if not conn.del_software_ban(name):
click.echo(f'Software was not banned: {name}')
return
click.echo(f'Software wasn\'t banned: {name}')
click.echo(f'Unbanned software: {name}')
@cli_software.command('update')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a software ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_software_ban(name, reason, note)):
click.echo(f'Failed to update software ban: {name}')
return
click.echo(f'Updated software ban: {name}')
if row['reason']:
click.echo(f'- {row["name"]} ({row["reason"]})')
else:
click.echo(f'- {row["name"]}')
@cli.group('whitelist')
@ -407,52 +683,64 @@ def cli_whitelist() -> None:
@cli_whitelist.command('list')
def cli_whitelist_list() -> None:
@click.pass_context
def cli_whitelist_list(ctx: click.Context) -> None:
'List all the instances in the whitelist'
click.echo('Current whitelisted domains')
click.echo('Current whitelisted domains:')
for domain in app.config.whitelist:
click.echo(f'- {domain}')
with ctx.obj.database.connection() as conn:
for domain in conn.execute('SELECT * FROM whitelist'):
click.echo(f'- {domain["domain"]}')
@cli_whitelist.command('add')
@click.argument('instance')
def cli_whitelist_add(instance: str) -> None:
'Add an instance to the whitelist'
@click.argument('domain')
@click.pass_context
def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
'Add a domain to the whitelist'
if not app.config.add_whitelist(instance):
click.echo(f'Instance already in the whitelist: {instance}')
with ctx.obj.database.connection() as conn:
if conn.get_domain_whitelist(domain):
click.echo(f'Instance already in the whitelist: {domain}')
return
app.config.save()
click.echo(f'Instance added to the whitelist: {instance}')
conn.put_domain_whitelist(domain)
click.echo(f'Instance added to the whitelist: {domain}')
@cli_whitelist.command('remove')
@click.argument('instance')
def cli_whitelist_remove(instance: str) -> None:
@click.argument('domain')
@click.pass_context
def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
'Remove an instance from the whitelist'
if not app.config.del_whitelist(instance):
click.echo(f'Instance not in the whitelist: {instance}')
with ctx.obj.database.connection() as conn:
if not conn.del_domain_whitelist(domain):
click.echo(f'Domain not in the whitelist: {domain}')
return
app.config.save()
if conn.get_config('whitelist-enabled'):
if conn.del_inbox(domain):
click.echo(f'Removed inbox for domain: {domain}')
if app.config.whitelist_enabled:
if app.database.del_inbox(instance):
app.database.save()
click.echo(f'Removed instance from the whitelist: {instance}')
click.echo(f'Removed domain from the whitelist: {domain}')
@cli_whitelist.command('import')
def cli_whitelist_import() -> None:
@click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist'
for domain in app.database.hostnames:
cli_whitelist_add.callback(domain)
with ctx.obj.database.connection() as conn:
for inbox in conn.execute('SELECT * FROM inboxes').all():
if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
continue
conn.put_domain_whitelist(inbox['domain'])
click.echo('Imported whitelist from inboxes')
def main() -> None:

View file

@ -1,32 +1,28 @@
from __future__ import annotations
import json
import os
import socket
import traceback
import typing
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse
from aiohttp.web import Response as AiohttpResponse
from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.message import Message as ApMessage
from functools import cached_property
from json.decoder import JSONDecodeError
from uuid import uuid4
from . import logger as logging
if typing.TYPE_CHECKING:
from typing import Any, Coroutine, Generator, Optional, Type
from aputils.signer import Signer
from collections.abc import Coroutine, Generator
from typing import Any
from .application import Application
from .config import RelayConfig
from .database import RelayDatabase
from .config import Config
from .database import Database
from .http_client import HttpClient
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = {
'activity': 'application/activity+json',
'html': 'text/html',
@ -42,10 +38,10 @@ NODEINFO_NS = {
def boolean(value: Any) -> bool:
if isinstance(value, str):
if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']:
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}:
return True
if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']:
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}:
return False
raise TypeError(f'Cannot parse string "{value}" as a boolean')
@ -77,99 +73,21 @@ def check_open_port(host: str, port: int) -> bool:
return False
class DotDict(dict):
def __init__(self, _data: dict[str, Any], **kwargs: Any):
dict.__init__(self)
def get_app() -> Application:
from .application import Application # pylint: disable=import-outside-toplevel
self.update(_data, **kwargs)
if not Application.DEFAULT:
raise ValueError('No default application set')
def __getattr__(self, key: str) -> str:
try:
return self[key]
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith('_'):
super().__setattr__(key, value)
else:
self[key] = value
def __setitem__(self, key: str, value: Any) -> None:
if type(value) is dict: # pylint: disable=unidiomatic-typecheck
value = DotDict(value)
super().__setitem__(key, value)
def __delattr__(self, key: str) -> None:
try:
dict.__delitem__(self, key)
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
@classmethod
def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]:
if not data:
raise JSONDecodeError('Empty body', data, 1)
try:
return cls(json.loads(data))
except ValueError:
raise JSONDecodeError('Invalid body', data, 1) from None
@classmethod
def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]:
data = cls({})
for chunk in sig.strip().split(','):
key, value = chunk.split('=', 1)
value = value.strip('\"')
if key == 'headers':
value = value.split()
data[key.lower()] = value
return data
def to_json(self, indent: Optional[int | str] = None) -> str:
return json.dumps(self, indent=indent)
def update(self, _data: dict[str, Any], **kwargs: Any) -> None:
if isinstance(_data, dict):
for key, value in _data.items():
self[key] = value
elif isinstance(_data, (list, tuple, set)):
for key, value in _data:
self[key] = value
for key, value in kwargs.items():
self[key] = value
return Application.DEFAULT
class Message(ApMessage):
@classmethod
def new_actor(cls: Type[Message], # pylint: disable=arguments-differ
def new_actor(cls: type[Message], # pylint: disable=arguments-differ
host: str,
pubkey: str,
description: Optional[str] = None) -> Message:
description: str | None = None) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
@ -181,7 +99,7 @@ class Message(ApMessage):
'followers': f'https://{host}/followers',
'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox',
'url': f'https://{host}/inbox',
'url': f'https://{host}/',
'endpoints': {
'sharedInbox': f'https://{host}/inbox'
},
@ -194,7 +112,7 @@ class Message(ApMessage):
@classmethod
def new_announce(cls: Type[Message], host: str, obj: str) -> Message:
def new_announce(cls: type[Message], host: str, obj: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
@ -206,7 +124,7 @@ class Message(ApMessage):
@classmethod
def new_follow(cls: Type[Message], host: str, actor: str) -> Message:
def new_follow(cls: type[Message], host: str, actor: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow',
@ -218,7 +136,7 @@ class Message(ApMessage):
@classmethod
def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message:
def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
@ -230,7 +148,7 @@ class Message(ApMessage):
@classmethod
def new_response(cls: Type[Message],
def new_response(cls: type[Message],
host: str,
actor: str,
followid: str,
@ -262,12 +180,17 @@ class Message(ApMessage):
class Response(AiohttpResponse):
# AiohttpResponse.__len__ method returns 0, so bool(response) always returns False
def __bool__(self) -> bool:
return True
@classmethod
def new(cls: Type[Response],
body: Optional[str | bytes | dict] = '',
status: Optional[int] = 200,
headers: Optional[dict[str, str]] = None,
ctype: Optional[str] = 'text') -> Response:
def new(cls: type[Response],
body: str | bytes | dict = '',
status: int = 200,
headers: dict[str, str] | None = None,
ctype: str = 'text') -> Response:
kwargs = {
'status': status,
@ -288,7 +211,7 @@ class Response(AiohttpResponse):
@classmethod
def new_error(cls: Type[Response],
def new_error(cls: type[Response],
status: int,
body: str | bytes | dict,
ctype: str = 'text') -> Response:
@ -310,23 +233,11 @@ class Response(AiohttpResponse):
class View(AbstractView):
def __init__(self, request: AiohttpRequest):
AbstractView.__init__(self, request)
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
def __await__(self) -> Generator[Response]:
method = self.request.method.upper()
if (self.request.method) not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
if method not in METHODS:
raise HTTPMethodNotAllowed(method, self.allowed_methods)
if not (handler := self.handlers.get(method)):
if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__()
@ -363,94 +274,10 @@ class View(AbstractView):
@property
def config(self) -> RelayConfig:
def config(self) -> Config:
return self.app.config
@property
def database(self) -> RelayDatabase:
def database(self) -> Database:
return self.app.database
# todo: move to views.ActorView
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
self.instance = self.database.get_inbox(self.actor.inbox)
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import tinysql
import typing
from cachetools import LRUCache
@ -8,7 +9,7 @@ from . import logger as logging
from .misc import Message
if typing.TYPE_CHECKING:
from .misc import View
from .views import ActorView
cache = LRUCache(1024)
@ -16,74 +17,76 @@ cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False
## make sure the actor is an application
# make sure the actor is an application
if actor.type != 'Application':
return True
return False
async def handle_relay(view: View) -> None:
async def handle_relay(view: ActorView) -> None:
if view.message.object_id in cache:
logging.verbose('already relayed %s', view.message.object_id)
return
message = Message.new_announce(view.config.host, view.message.object_id)
message = Message.new_announce(view.config.domain, view.message.object_id)
cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message)
inboxes = view.database.distill_inboxes(view.message)
for inbox in inboxes:
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
async def handle_forward(view: View) -> None:
async def handle_forward(view: ActorView) -> None:
if view.message.id in cache:
logging.verbose('already forwarded %s', view.message.id)
return
message = Message.new_announce(view.config.host, view.message)
message = Message.new_announce(view.config.domain, view.message)
cache[view.message.id] = message.id
logging.debug('>> forward: %s', message)
inboxes = view.database.distill_inboxes(view.message)
for inbox in inboxes:
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
async def handle_follow(view: View) -> None:
async def handle_follow(view: ActorView) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
## reject if software used by actor is banned
if view.config.is_banned_software(software):
with view.database.connection() as conn:
# reject if software used by actor is banned
if conn.get_software_ban(software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
return logging.verbose(
logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id,
software
)
return
## reject if the actor is not an instance actor
if person_check(view.actor, software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
@ -93,13 +96,26 @@ async def handle_follow(view: View) -> None:
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
view.database.add_inbox(view.actor.shared_inbox, view.message.id, software)
view.database.save()
if conn.get_inbox(view.actor.shared_inbox):
data = {'followid': view.message.id}
statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
with conn.query(statement):
pass
else:
conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
view.actor.id,
view.message.id,
software
)
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = True
@ -112,32 +128,30 @@ async def handle_follow(view: View) -> None:
view.app.push_message(
view.actor.shared_inbox,
Message.new_follow(
host = view.config.host,
host = view.config.domain,
actor = view.actor.id
)
)
async def handle_undo(view: View) -> None:
async def handle_undo(view: ActorView) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
return await handle_forward(view)
await handle_forward(view)
return
if not view.database.del_inbox(view.actor.domain, view.message.object['id']):
with view.database.connection() as conn:
if not conn.del_inbox(view.actor.id):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
view.message.object['id']
)
return
view.database.save()
view.app.push_message(
view.actor.shared_inbox,
Message.new_unfollow(
host = view.config.host,
host = view.config.domain,
actor = view.actor.id,
follow = view.message
)
@ -154,7 +168,7 @@ processors = {
}
async def run_processor(view: View) -> None:
async def run_processor(view: ActorView) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@ -164,12 +178,21 @@ async def run_processor(view: View) -> None:
return
if view.instance and not view.instance.get('software'):
nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain'])
if view.instance:
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
if nodeinfo:
view.instance['software'] = nodeinfo.sw_name
view.database.save()
if not view.instance['actor']:
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
actor = view.actor.id
)
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view)

View file

@ -2,8 +2,11 @@ from __future__ import annotations
import asyncio
import subprocess
import traceback
import typing
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo
from pathlib import Path
@ -14,7 +17,8 @@ from .processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from typing import Callable
from aputils.signer import Signer
from collections.abc import Callable
VIEWS = []
@ -71,11 +75,15 @@ def register_route(*paths: str) -> Callable:
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
with self.database.connection() as conn:
config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all()
text = HOME_TEMPLATE.format(
host = self.config.host,
note = self.config.note,
count = len(self.database.hostnames),
targets = '<br>'.join(self.database.hostnames)
host = self.config.domain,
note = config['note'],
count = len(inboxes),
targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
)
return Response.new(text, ctype='html')
@ -84,33 +92,45 @@ class HomeView(View):
@register_route('/actor', '/inbox')
class ActorView(View):
def __init__(self, request: Request):
View.__init__(self, request)
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
async def get(self, request: Request) -> Response:
data = Message.new_actor(
host = self.config.host,
pubkey = self.database.signer.pubkey
host = self.config.domain,
pubkey = self.app.signer.pubkey
)
return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response:
response = await self.get_post_data()
if response is not None:
if response := await self.get_post_data():
return response
with self.database.connection() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all()
## reject if the actor isn't whitelisted while the whiltelist is enabled
if self.config.whitelist_enabled and not self.config.is_whitelisted(self.actor.domain):
if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if actor is banned
if self.config.is_banned(self.actor.domain):
if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.database.get_inbox(self.actor.domain):
if self.message.type != 'Follow' and not self.instance:
logging.verbose(
'Rejected actor for trying to post while not following: %s',
self.actor.id
@ -124,6 +144,87 @@ class ActorView(View):
return Response.new(status = 202)
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if not self.actor:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")
@register_route('/.well-known/webfinger')
class WebfingerView(View):
async def get(self, request: Request) -> Response:
@ -133,12 +234,12 @@ class WebfingerView(View):
except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json')
if subject != f'acct:relay@{self.config.host}':
if subject != f'acct:relay@{self.config.domain}':
return Response.new_error(404, 'user not found', 'json')
data = Webfinger.new(
handle = 'relay',
domain = self.config.host,
domain = self.config.domain,
actor = self.config.actor
)
@ -148,13 +249,16 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response:
with self.database.connection() as conn:
inboxes = conn.execute('SELECT * FROM inboxes').all()
data = {
'name': 'activityrelay',
'version': VERSION,
'protocols': ['activitypub'],
'open_regs': not self.config.whitelist_enabled,
'open_regs': not conn.get_config('whitelist-enabled'),
'users': 1,
'metadata': {'peers': self.database.hostnames}
'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
}
if niversion == '2.1':
@ -166,5 +270,5 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response:
data = WellKnownNodeinfo.new_template(self.config.host)
data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json')

View file

@ -3,3 +3,4 @@ aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0
click>=8.1.2
pyyaml>=6.0
tinysql@git+https://git.barkshark.xyz/barkshark/tinysql@3d94193ee4

View file

@ -10,26 +10,30 @@ license_file = LICENSE
classifiers =
Environment :: Console
License :: OSI Approved :: AGPLv3 License
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
project_urls =
Source = https://git.pleroma.social/pleroma/relay
Tracker = https://git.pleroma.social/pleroma/relay/-/issues
[options]
zip_safe = False
packages = find:
packages =
relay
relay.database
include_package_data = true
install_requires = file: requirements.txt
python_requires = >=3.8
[options.extras_require]
dev =
flake8 = 3.1.0
pyinstaller = 6.3.0
pylint = 3.0
dev = file: dev-requirements.txt
[options.package_data]
relay =
data/statements.sql
[options.entry_points]
console_scripts =
@ -37,7 +41,4 @@ console_scripts =
[flake8]
extend-ignore = ANN101,ANN204,E128,E251,E261,E266,E301,E303,W191
extend-exclude = docs, test*.py
max-line-length = 100
indent-size = 4
select = F401