Compare commits

..

No commits in common. "81215a83a4a99d777ffc9d1e1899d863574656e5" and "965ac73c6d84caf0fff7d316b8ae67c4a1473140" have entirely different histories.

26 changed files with 993 additions and 1838 deletions

10
.gitignore vendored
View file

@ -94,7 +94,9 @@ ENV/
# Rope project settings
.ropeproject
# config and database
*.yaml
*.jsonld
*.sqlite3
viera.yaml
viera.jsonld
# config file
relay.yaml
relay.jsonld

View file

@ -1 +0,0 @@
include data/statements.sql

View file

@ -1,3 +0,0 @@
flake8 == 7.0.0
pyinstaller == 6.3.0
pylint == 3.0

View file

@ -3,8 +3,11 @@
There are a number of commands to manage your relay's database and config. You can add `--help` to
any category or command to get help on that specific option (ex. `activityrelay inbox --help`).
Note: `activityrelay` is only available via pip or pipx if `~/.local/bin` is in `$PATH`. If not,
use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if installed via pipx.
Note: Unless specified, it is recommended to run any commands while the relay is shutdown.
Note 2: `activityrelay` is only available via pip or pipx if `~/.local/bin` is in `$PATH`. If it
isn't, use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if installed
via pipx
## Run
@ -21,22 +24,6 @@ Run the setup wizard to configure your relay.
activityrelay setup
## Convert
Convert the old config and jsonld to the new config and SQL backend. If the old config filename is
not specified, the config will get backed up as `relay.backup.yaml` before converting.
activityrelay convert --old-config relaycfg.yaml
## Edit Config
Open the config file in a text editor. If an editor is not specified with `--editor`, the default
editor will be used.
activityrelay edit-config --editor micro
## Config
Manage the relay config
@ -133,7 +120,7 @@ Remove a domain from the whitelist.
### Import
Add all current inboxes to the whitelist.
Add all current inboxes to the whitelist
activityrelay whitelist import
@ -145,15 +132,15 @@ Manage the instance ban list.
### List
List the currently banned instances.
List the currently banned instances
activityrelay instance list
### Ban
Add an instance to the ban list. If the instance is currently subscribed, it will be removed from
the inbox list.
Add an instance to the ban list. If the instance is currently subscribed, remove it from the
database.
activityrelay instance ban <domain>
@ -165,17 +152,10 @@ Remove an instance from the ban list.
activityrelay instance unban <domain>
### Update
Update the ban reason or note for an instance ban.
activityrelay instance update bad.example.com --reason "the baddest reason"
## Software
Manage the software ban list. To get the correct name, check the software's nodeinfo endpoint.
You can find it at `nodeinfo['software']['name']`.
You can find it at nodeinfo\['software']\['name'].
### List
@ -206,12 +186,4 @@ name via nodeinfo.
If the name is `RELAYS` (case-sensitive), remove all known relay software names from the list.
activityrelay software unban [-f/--fetch-nodeinfo] <name, domain, or RELAYS>
### Update
Update the ban reason or note for a software ban. Either `--reason` and/or `--note` must be
specified.
activityrelay software update relay.example.com --reason "begone relay"
activityrelay unban [-f/--fetch-nodeinfo] <name, domain, or RELAYS>

View file

@ -2,23 +2,41 @@
## General
### Domain
### DB
Hostname the relay will be hosted on.
The path to the database. It contains the relay actor private key and all subscribed
instances. If the path is not absolute, it is relative to the working directory.
domain: relay.example.com
db: relay.jsonld
### Listener
The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc)
is running on the same host, it is recommended to change `listen` to `localhost` if the reverse
proxy is on the same host.
is running on the same host, it is recommended to change `listen` to `localhost`
listen: 0.0.0.0
port: 8080
### Note
A small blurb to describe your relay instance. This will show up on the relay's home page.
note: "Make a note about your instance here."
### Post Limit
The maximum number of messages to send out at once. For each incoming message, a message will be
sent out to every subscribed instance minus the instance which sent the message. This limit
is to prevent too many outgoing connections from being made, so adjust if necessary.
Note: If the `workers` option is set to anything above 0, this limit will be per worker.
push_limit: 512
### Push Workers
The relay can be configured to use threads to push messages out. For smaller relays, this isn't
@ -28,59 +46,60 @@ threads.
workers: 0
### Database type
### JSON GET cache limit
SQL database backend to use. Valid values are `sqlite` or `postgres`.
JSON objects (actors, nodeinfo, etc) will get cached when fetched. This will set the max number of
objects to keep in the cache.
database_type: sqlite
json_cache: 1024
### Sqlite File Path
## AP
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
directory.
sqlite_path: relay.jsonld
## Postgresql
In order to use the Postgresql backend, the user and database need to be created first.
sudo -u postgres psql -c "CREATE USER activityrelay"
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
### Database Name
Name of the database to use.
name: activityrelay
Various ActivityPub-related settings
### Host
Hostname, IP address, or unix socket the server is hosted on.
The domain your relay will use to identify itself.
host: /var/run/postgresql
host: relay.example.com
### Port
### Whitelist Enabled
Port number the server is listening on.
If set to `true`, only instances in the whitelist can follow the relay. Any subscribed instances
not in the whitelist will be removed from the inbox list on startup.
port: 5432
whitelist_enabled: false
### Username
### Whitelist
User to use when logging into the server.
A list of domains of instances which are allowed to subscribe to your relay.
user: null
whitelist:
- bad-instance.example.com
- another-bad-instance.example.com
### Password
### Blocked Instances
Password for the specified user.
A list of instances which are unable to follow the instance. If a subscribed instance is added to
the block list, it will be removed from the inbox list on startup.
pass: null
blocked_instances:
- bad-instance.example.com
- another-bad-instance.example.com
### Blocked Software
A list of ActivityPub software which cannot follow your relay. This list is empty by default, but
setting this to the below list will block all other relays and prevent relay chains
blocked_software:
- activityrelay
- aoderelay
- social.seattle.wa.us-relay
- unciarelay

View file

@ -28,14 +28,14 @@ server {
# logging, mostly for debug purposes. Disable if you wish.
access_log /srv/www/relay.<yourdomain>/logs/access.log;
error_log /srv/www/relay.<yourdomain>/logs/error.log;
ssl_protocols TLSv1.2;
ssl_ciphers EECDH+AESGCM:EECDH+AES;
ssl_ecdh_curve secp384r1;
ssl_prefer_server_ciphers on;
ssl_session_cache shared:SSL:10m;
# ssl certs.
# ssl certs.
ssl_certificate /usr/local/etc/letsencrypt/live/relay.<yourdomain>/fullchain.pem;
ssl_certificate_key /usr/local/etc/letsencrypt/live/relay.<yourdomain>/privkey.pem;
@ -48,7 +48,7 @@ server {
# sts, change if you care.
# add_header Strict-Transport-Security "max-age=31536000; includeSubDomains";
# uncomment this to use a static page in your webroot for your root page.
#location = / {
# index index.html;

View file

@ -3,7 +3,7 @@ Description=ActivityPub Relay
[Service]
WorkingDirectory=/home/relay/relay
ExecStart=/usr/bin/python3 -m relay run
ExecStart=/usr/bin/python3 -m relay
[Install]
WantedBy=multi-user.target

View file

@ -6,23 +6,6 @@ build-backend = 'setuptools.build_meta'
[tool.pylint.main]
jobs = 0
persistent = true
load-plugins = [
"pylint.extensions.code_style",
"pylint.extensions.comparison_placement",
"pylint.extensions.confusing_elif",
"pylint.extensions.for_any_all",
"pylint.extensions.consider_ternary_expression",
"pylint.extensions.bad_builtin",
"pylint.extensions.dict_init_mutate",
"pylint.extensions.check_elif",
"pylint.extensions.empty_comment",
"pylint.extensions.private_import",
"pylint.extensions.redefined_variable_type",
"pylint.extensions.no_self_use",
"pylint.extensions.overlapping_exceptions",
"pylint.extensions.set_membership",
"pylint.extensions.typing"
]
[tool.pylint.design]
@ -39,7 +22,6 @@ single-line-if-stmt = true
[tool.pylint.messages_control]
disable = [
"fixme",
"broad-exception-caught",
"cyclic-import",
"global-statement",
@ -49,8 +31,7 @@ disable = [
"too-many-public-methods",
"too-many-return-statements",
"wrong-import-order",
"wrong-import-position",
"missing-function-docstring",
"missing-class-docstring",
"consider-using-namedtuple-or-dataclass",
"confusing-consecutive-elif"
"missing-class-docstring"
]

View file

@ -5,44 +5,40 @@ block_cipher = None
a = Analysis(
['relay/__main__.py'],
pathex=[],
binaries=[],
datas=[
('relay/data', 'relay/data')
],
hiddenimports=[],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
['relay/__main__.py'],
pathex=[],
binaries=[],
datas=[],
hiddenimports=[],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='activityrelay',
icon=None,
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='activityrelay',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View file

@ -1,35 +1,43 @@
# [string] Domain the relay will be hosted on
domain: relay.example.com
# this is the path that the object graph will get dumped to (in JSON-LD format),
# you probably shouldn't change it, but you can if you want.
db: relay.jsonld
# [string] Address the relay will listen on
# Listener
listen: 0.0.0.0
# [integer] Port the relay will listen on
port: 8080
# [integer] Number of push workers to start (will get removed in a future update)
workers: 8
# Note
note: "Make a note about your instance here."
# [string] Database backend to use. Valid values: sqlite, postgres
database_type: sqlite
# Number of worker threads to start. If 0, use asyncio futures instead of threads.
workers: 0
# [string] Path to the sqlite database file if the sqlite backend is in use
sqlite_path: relay.sqlite3
# Maximum number of inbox posts to do at once
# If workers is set to 1 or above, this is the max for each worker
push_limit: 512
# settings for the postgresql backend
postgres:
# The amount of json objects to cache from GET requests
json_cache: 1024
# [string] hostname or unix socket to connect to
host: /var/run/postgresql
ap:
# This is used for generating activitypub messages, as well as instructions for
# linking AP identities. It should be an SSL-enabled domain reachable by https.
host: 'relay.example.com'
# [integer] port of the server
port: 5432
blocked_instances:
- 'bad-instance.example.com'
- 'another-bad-instance.example.com'
# [string] username to use when logging into the server (default is the current system username)
user: null
whitelist_enabled: false
# [string] password of the user
pass: null
whitelist:
- 'good-instance.example.com'
- 'another.good-instance.example.com'
# [string] name of the database to use
name: activityrelay
# uncomment the lines below to prevent certain activitypub software from posting
# to the relay (all known relays by default). this uses the software name in nodeinfo
#blocked_software:
#- 'activityrelay'
#- 'aoderelay'
#- 'social.seattle.wa.us-relay'
#- 'unciarelay'

View file

@ -8,41 +8,52 @@ import traceback
import typing
from aiohttp import web
from aputils.signer import Signer
from datetime import datetime, timedelta
from . import logger as logging
from .config import Config
from .database import get_database
from .config import RelayConfig
from .database import RelayDatabase
from .http_client import HttpClient
from .misc import check_open_port
from .views import VIEWS
if typing.TYPE_CHECKING:
from tinysql import Database
from typing import Any
from .misc import Message
# pylint: disable=unsubscriptable-object
class Application(web.Application):
DEFAULT: Application = None
class Application(web.Application):
def __init__(self, cfgpath: str):
web.Application.__init__(self)
Application.DEFAULT = self
self['signer'] = None
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['workers'] = []
self['last_worker'] = 0
self['start_time'] = None
self['running'] = False
self['config'] = RelayConfig(cfgpath)
if not self.config.load():
self.config.save()
if self.config.is_docker:
self.config.update({
'db': '/data/relay.jsonld',
'listen': '0.0.0.0',
'port': 8080
})
self['database'] = RelayDatabase(self.config)
self.database.load()
self['client'] = HttpClient(
database = self.database,
limit = self.config.push_limit,
timeout = self.config.timeout,
cache_size = self.config.json_cache
)
for path, view in VIEWS:
self.router.add_view(path, view)
@ -54,29 +65,15 @@ class Application(web.Application):
@property
def config(self) -> Config:
def config(self) -> RelayConfig:
return self['config']
@property
def database(self) -> Database:
def database(self) -> RelayDatabase:
return self['database']
@property
def signer(self) -> Signer:
return self['signer']
@signer.setter
def signer(self, value: Signer | str) -> None:
if isinstance(value, Signer):
self['signer'] = value
return
self['signer'] = Signer(value, self.config.keyid)
@property
def uptime(self) -> timedelta:
if not self['start_time']:
@ -121,7 +118,7 @@ class Application(web.Application):
logging.info(
'Starting webserver at %s (%s:%i)',
self.config.domain,
self.config.host,
self.config.listen,
self.config.port
)
@ -182,7 +179,12 @@ class PushWorker(threading.Thread):
async def handle_queue(self) -> None:
self.client = HttpClient()
self.client = HttpClient(
database = self.app.database,
limit = self.app.config.push_limit,
timeout = self.app.config.timeout,
cache_size = self.app.config.json_cache
)
while self.app['running']:
try:

View file

@ -1,76 +1,76 @@
from __future__ import annotations
import getpass
import os
import typing
import yaml
from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse
from .misc import IS_DOCKER
from .misc import DotDict, boolean
if typing.TYPE_CHECKING:
from typing import Any
from .database import RelayDatabase
DEFAULTS: dict[str, Any] = {
'listen': '0.0.0.0',
'port': 8080,
'domain': 'relay.example.com',
'workers': len(os.sched_getaffinity(0)),
'db_type': 'sqlite',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay'
}
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
]
if IS_DOCKER:
DEFAULTS['sq_path'] = '/data/relay.jsonld'
APKEYS = [
'host',
'whitelist_enabled',
'blocked_software',
'blocked_instances',
'whitelist'
]
class Config:
def __init__(self, path: str, load: bool = False):
self.path = Path(path).expanduser().resolve()
class RelayConfig(DotDict):
__slots__ = ('path', )
self.listen = None
self.port = None
self.domain = None
self.workers = None
self.db_type = None
self.sq_path = None
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
def __init__(self, path: str | Path):
DotDict.__init__(self, {})
if load:
try:
self.load()
if self.is_docker:
path = '/data/config.yaml'
except FileNotFoundError:
self.save()
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
@property
def sqlite_path(self) -> Path:
if not os.path.isabs(self.sq_path):
return self.path.parent.joinpath(self.sq_path).resolve()
return Path(self.sq_path).expanduser().resolve()
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self.domain}/actor'
return f'https://{self.host}/actor'
@property
def inbox(self) -> str:
return f'https://{self.domain}/inbox'
return f'https://{self.host}/inbox'
@property
@ -78,7 +78,115 @@ class Config:
return f'{self.actor}#main-key'
def load(self) -> None:
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def ban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_banned(instance):
return False
self.blocked_instances.append(instance)
return True
def unban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.blocked_instances.remove(instance)
return True
except ValueError:
return False
def ban_software(self, software: str) -> bool:
if self.is_banned_software(software):
return False
self.blocked_software.append(software)
return True
def unban_software(self, software: str) -> bool:
try:
self.blocked_software.remove(software)
return True
except ValueError:
return False
def add_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_whitelisted(instance):
return False
self.whitelist.append(instance)
return True
def del_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.whitelist.remove(instance)
return True
except ValueError:
return False
def is_banned(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.blocked_instances
def is_banned_software(self, software: str) -> bool:
if not software:
return False
return software.lower() in self.blocked_software
def is_whitelisted(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.whitelist
def load(self) -> bool:
self.reset()
options = {}
@ -89,69 +197,50 @@ class Config:
except AttributeError:
pass
with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {})
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
return False
if not config:
raise ValueError('Config is empty')
return False
if IS_DOCKER:
self.listen = '0.0.0.0'
self.port = 8080
self.sq_path = '/data/relay.jsonld'
for key, value in config.items():
if key in ['ap']:
for k, v in value.items():
if k not in self:
continue
else:
self.set('listen', config.get('listen', DEFAULTS['listen']))
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self[k] = v
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
for key in DEFAULTS:
if not key.startswith('pg'):
continue
try:
self.set(key, pgcfg[key[3:]])
except KeyError:
if key not in self:
continue
self[key] = value
def reset(self) -> None:
for key, value in DEFAULTS.items():
setattr(self, key, value)
if self.host.endswith('example.com'):
return False
return True
def save(self) -> None:
self.path.parent.mkdir(exist_ok = True, parents = True)
config = {
# just turning config.db into a string is good enough for now
'db': str(self.db),
'listen': self.listen,
'port': self.port,
'domain': self.domain,
'database_type': self.db_type,
'sqlite_path': self.sq_path,
'postgres': {
'host': self.pg_host,
'port': self.pg_port,
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
}
'note': self.note,
'push_limit': self.push_limit,
'workers': self.workers,
'json_cache': self.json_cache,
'timeout': self.timeout,
'ap': {key: self[key] for key in APKEYS}
}
with self.path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys = False)
def set(self, key: str, value: Any) -> None:
if key not in DEFAULTS:
raise KeyError(key)
if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
value = int(value)
setattr(self, key, value)
with self._path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys=False)

View file

@ -1,79 +0,0 @@
-- name: get-config
SELECT * FROM config WHERE key = :key
-- name: get-config-all
SELECT * FROM config
-- name: put-config
INSERT INTO config (key, value, type)
VALUES (:key, :value, :type)
ON CONFLICT (key) DO UPDATE SET value = :value
RETURNING *
-- name: del-config
DELETE FROM config
WHERE key = :key
-- name: get-inbox
SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value
-- name: put-inbox
INSERT INTO inboxes (domain, actor, inbox, followid, software, created)
VALUES (:domain, :actor, :inbox, :followid, :software, :created)
ON CONFLICT (domain) DO UPDATE SET followid = :followid
RETURNING *
-- name: del-inbox
DELETE FROM inboxes
WHERE domain = :value or inbox = :value or actor = :value
-- name: get-software-ban
SELECT * FROM software_bans WHERE name = :name
-- name: put-software-ban
INSERT INTO software_bans (name, reason, note, created)
VALUES (:name, :reason, :note, :created)
RETURNING *
-- name: del-software-ban
DELETE FROM software_bans
WHERE name = :name
-- name: get-domain-ban
SELECT * FROM domain_bans WHERE domain = :domain
-- name: put-domain-ban
INSERT INTO domain_bans (domain, reason, note, created)
VALUES (:domain, :reason, :note, :created)
RETURNING *
-- name: del-domain-ban
DELETE FROM domain_bans
WHERE domain = :domain
-- name: get-domain-whitelist
SELECT * FROM whitelist WHERE domain = :domain
-- name: put-domain-whitelist
INSERT INTO whitelist (domain, created)
VALUES (:domain, :created)
RETURNING *
-- name: del-domain-whitelist
DELETE FROM whitelist
WHERE domain = :domain

View file

@ -1,129 +1,17 @@
from __future__ import annotations
import json
import os
import typing
import yaml
from functools import cached_property
from pathlib import Path
from aputils.signer import Signer
from urllib.parse import urlparse
from . import logger as logging
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any
# pylint: disable=duplicate-code
class RelayConfig(dict):
def __init__(self, path: str):
dict.__init__(self, {})
if self.is_docker:
path = '/data/config.yaml'
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in {'blocked_instances', 'blocked_software', 'whitelist'}:
assert isinstance(value, (list, set, tuple))
elif key in {'port', 'workers', 'json_cache', 'timeout'}:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
@property
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self["host"]}/actor'
@property
def inbox(self) -> str:
return f'https://{self["host"]}/inbox'
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def load(self) -> None:
self.reset()
options = {}
try:
options['Loader'] = yaml.FullLoader
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
return
if not config:
return
for key, value in config.items():
if key == 'ap':
for k, v in value.items():
if k not in self:
continue
self[k] = v
continue
if key not in self:
continue
self[key] = value
from typing import Iterator, Optional
from .config import RelayConfig
from .misc import Message
class RelayDatabase(dict):
@ -149,7 +37,9 @@ class RelayDatabase(dict):
return tuple(data['inbox'] for data in self['relay-list'].values())
def load(self) -> None:
def load(self) -> bool:
new_db = True
try:
with self.config.db.open() as fd:
data = json.load(fd)
@ -175,9 +65,17 @@ class RelayDatabase(dict):
self['relay-list'] = data.get('relay-list', {})
for domain, instance in self['relay-list'].items():
if self.config.is_banned(domain) or \
(self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
self.del_inbox(domain)
continue
if not instance.get('domain'):
instance['domain'] = domain
new_db = False
except FileNotFoundError:
pass
@ -185,13 +83,24 @@ class RelayDatabase(dict):
if self.config.db.stat().st_size > 0:
raise e from None
if not self['private-key']:
logging.info('No actor keys present, generating 4096-bit RSA keypair.')
self.signer = Signer.new(self.config.keyid, size=4096)
self['private-key'] = self.signer.export()
else:
self.signer = Signer(self['private-key'], self.config.keyid)
self.save()
return not new_db
def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd:
json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None:
def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
@ -206,13 +115,14 @@ class RelayDatabase(dict):
def add_inbox(self,
inbox: str,
followid: str | None = None,
software: str | None = None) -> dict[str, str]:
followid: Optional[str] = None,
software: Optional[str] = None) -> dict[str, str]:
assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname
instance = self.get_inbox(domain)
if (instance := self.get_inbox(domain)):
if instance:
if followid:
instance['followid'] = followid
@ -234,10 +144,12 @@ class RelayDatabase(dict):
def del_inbox(self,
domain: str,
followid: str = None,
fail: bool = False) -> bool:
followid: Optional[str] = None,
fail: Optional[bool] = False) -> bool:
if not (data := self.get_inbox(domain, fail=False)):
data = self.get_inbox(domain, fail=False)
if not data:
if fail:
raise KeyError(domain)

View file

@ -1,65 +0,0 @@
from __future__ import annotations
import tinysql
import typing
from importlib.resources import files as pkgfiles
from .config import get_default_value
from .connection import Connection
from .schema import VERSIONS, migrate_0
from .. import logger as logging
if typing.TYPE_CHECKING:
from .config import Config
def get_database(config: Config, migrate: bool = True) -> tinysql.Database:
if config.db_type == "sqlite":
db = tinysql.Database.sqlite(
config.sqlite_path,
connection_class = Connection,
min_connections = 2,
max_connections = 10
)
elif config.db_type == "postgres":
db = tinysql.Database.postgres(
config.pg_name,
config.pg_host,
config.pg_port,
config.pg_user,
config.pg_pass,
connection_class = Connection
)
db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql"))
if not migrate:
return db
with db.connection() as conn:
if 'config' not in conn.get_tables():
logging.info("Creating database tables")
migrate_0(conn)
return db
if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS:
if schema_ver < ver:
conn.begin()
func(conn)
conn.put_config('schema-version', ver)
conn.commit()
if (privkey := conn.get_config('private-key')):
conn.app.signer = privkey
logging.set_level(conn.get_config('log-level'))
return db

View file

@ -1,45 +0,0 @@
from __future__ import annotations
import typing
from .. import logger as logging
from ..misc import boolean
if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
'schema-version': ('int', 20240119),
'log-level': ('loglevel', logging.LogLevel.INFO),
'note': ('str', 'Make a note about your instance here.'),
'private-key': ('str', None),
'whitelist-enabled': ('bool', False)
}
# serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'loglevel': (lambda x: x.name, logging.LogLevel.parse)
}
def get_default_value(key: str) -> Any:
return CONFIG_DEFAULTS[key][1]
def get_default_type(key: str) -> str:
return CONFIG_DEFAULTS[key][0]
def serialize(key: str, value: Any) -> str:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][0](value)
def deserialize(key: str, value: str) -> Any:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][1](value)

View file

@ -1,296 +0,0 @@
from __future__ import annotations
import tinysql
import typing
from datetime import datetime, timezone
from urllib.parse import urlparse
from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize
from .. import logger as logging
from ..misc import get_app
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from tinysql import Cursor, Row
from typing import Any
from .application import Application
from ..misc import Message
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'activity-relay', # https://github.com/yukimochi/Activity-Relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
]
class Connection(tinysql.Connection):
@property
def app(self) -> Application:
return get_app()
def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for inbox in self.execute('SELECT * FROM inboxes'):
if inbox['domain'] not in src_domains:
yield inbox['inbox']
def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params)
def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
with self.exec_statement('get-config', {'key': key}) as cur:
if not (row := cur.one()):
return get_default_value(key)
if row['value']:
return deserialize(row['key'], row['value'])
return None
def get_config_all(self) -> dict[str, Any]:
with self.exec_statement('get-config-all') as cur:
db_config = {row['key']: row['value'] for row in cur}
config = {}
for key, data in CONFIG_DEFAULTS.items():
try:
config[key] = deserialize(key, db_config[key])
except KeyError:
if key == 'schema-version':
config[key] = 0
else:
config[key] = data[1]
return config
def put_config(self, key: str, value: Any) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
if key == 'private-key':
self.app.signer = value
elif key == 'log-level':
value = logging.LogLevel.parse(value)
logging.set_level(value)
params = {
'key': key,
'value': serialize(key, value) if value is not None else None,
'type': get_default_type(key)
}
with self.exec_statement('put-config', params):
return value
def get_inbox(self, value: str) -> Row:
with self.exec_statement('get-inbox', {'value': value}) as cur:
return cur.one()
def put_inbox(self,
domain: str,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
params = {
'domain': domain,
'inbox': inbox,
'actor': actor,
'followid': followid,
'software': software,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-inbox', params) as cur:
return cur.one()
def update_inbox(self,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"')
data = {}
if actor:
data['actor'] = actor
if followid:
data['followid'] = followid
if software:
data['software'] = software
statement = tinysql.Update('inboxes', data, inbox = inbox)
with self.query(statement):
return self.get_inbox(inbox)
def del_inbox(self, value: str) -> bool:
with self.exec_statement('del-inbox', {'value': value}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_ban(self, domain: str) -> Row:
if domain.startswith('http'):
domain = urlparse(domain).netloc
with self.exec_statement('get-domain-ban', {'domain': domain}) as cur:
return cur.one()
def put_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> Row:
params = {
'domain': domain,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-ban', params) as cur:
return cur.one()
def update_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('domain_bans', params, domain = domain)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_domain_ban(domain)
def del_domain_ban(self, domain: str) -> bool:
with self.exec_statement('del-domain-ban', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_software_ban(self, name: str) -> Row:
with self.exec_statement('get-software-ban', {'name': name}) as cur:
return cur.one()
def put_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> Row:
params = {
'name': name,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-software-ban', params) as cur:
return cur.one()
def update_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('software_bans', params, name = name)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_software_ban(name)
def del_software_ban(self, name: str) -> bool:
with self.exec_statement('del-software-ban', {'name': name}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_whitelist(self, domain: str) -> Row:
with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one()
def put_domain_whitelist(self, domain: str) -> Row:
params = {
'domain': domain,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-whitelist', params) as cur:
return cur.one()
def del_domain_whitelist(self, domain: str) -> bool:
with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1

View file

@ -1,60 +0,0 @@
from __future__ import annotations
import typing
from tinysql import Column, Connection, Table
from .config import get_default_value
if typing.TYPE_CHECKING:
from collections.abc import Callable
VERSIONS: list[Callable] = []
TABLES: list[Table] = [
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False),
Column('value', 'text'),
Column('type', 'text', default = 'str')
),
Table(
'inboxes',
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True),
Column('inbox', 'text', unique = True, nullable = False),
Column('followid', 'text'),
Column('software', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'whitelist',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('created', 'timestamp')
),
Table(
'domain_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'software_bans',
Column('name', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
)
]
def version(func: Callable) -> Callable:
ver = int(func.replace('migrate_', ''))
VERSIONS[ver] = func
return func
def migrate_0(conn: Connection) -> None:
conn.create_tables(TABLES)
conn.put_config('schema-version', get_default_value('schema-version'))

View file

@ -13,10 +13,11 @@ from urllib.parse import urlparse
from . import __version__
from . import logger as logging
from .misc import MIMETYPES, Message, get_app
from .misc import MIMETYPES, Message
if typing.TYPE_CHECKING:
from typing import Any
from typing import Any, Callable, Optional
from .database import RelayDatabase
HEADERS = {
@ -26,7 +27,13 @@ HEADERS = {
class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
def __init__(self,
database: RelayDatabase,
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
self.database = database
self.cache = LRUCache(cache_size)
self.limit = limit
self.timeout = timeout
@ -73,9 +80,9 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches
url: str,
sign_headers: bool = False,
loads: callable | None = None,
force: bool = False) -> Message | dict | None:
sign_headers: Optional[bool] = False,
loads: Optional[Callable] = None,
force: Optional[bool] = False) -> Message | dict | None:
await self.open()
@ -91,7 +98,7 @@ class HttpClient:
headers = {}
if sign_headers:
get_app().signer.sign_headers('GET', url, algorithm = 'original')
headers.update(self.database.signer.sign_headers('GET', url, algorithm='original'))
try:
logging.debug('Fetching resource: %s', url)
@ -143,27 +150,23 @@ class HttpClient:
async def post(self, url: str, message: Message) -> None:
await self.open()
# todo: cache inboxes to avoid opening a db connection
with get_app().database.connection() as conn:
instance = conn.get_inbox(url)
instance = self.database.get_inbox(url)
## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}:
if instance and instance.get('software') in {'mastodon'}:
algorithm = 'hs2019'
else:
algorithm = 'original'
# pylint: enable=consider-ternary-expression
headers = {'Content-Type': 'application/activity+json'}
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm))
try:
logging.verbose('Sending "%s" to %s', message.type, url)
async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
# Not expecting a response, so just return
## Not expecting a response, so just return
if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url)
return
@ -178,7 +181,7 @@ class HttpClient:
except (AsyncTimeoutError, ClientConnectionError):
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
# prevent workers from being brought down
## prevent workers from being brought down
except Exception:
traceback.print_exc()
@ -194,7 +197,7 @@ class HttpClient:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None
for version in ('20', '21'):
for version in ['20', '21']:
try:
nodeinfo_url = wk_nodeinfo.get_url(version)
@ -208,16 +211,16 @@ class HttpClient:
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
async def get(*args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient() as client:
async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient(database) as client:
return await client.get(*args, **kwargs)
async def post(*args: Any, **kwargs: Any) -> None:
async with HttpClient() as client:
async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None:
async with HttpClient(database) as client:
return await client.post(*args, **kwargs)
async def fetch_nodeinfo(*args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient() as client:
async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient(database) as client:
return await client.fetch_nodeinfo(*args, **kwargs)

View file

@ -4,63 +4,20 @@ import logging
import os
import typing
from enum import IntEnum
from pathlib import Path
if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from typing import Any, Callable
class LogLevel(IntEnum):
DEBUG = logging.DEBUG
VERBOSE = 15
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR
CRITICAL = logging.CRITICAL
def __str__(self) -> str:
return self.name
@classmethod
def parse(cls: type[IntEnum], data: object) -> IntEnum:
if isinstance(data, cls):
return data
if isinstance(data, str):
data = data.upper()
try:
return cls[data]
except KeyError:
pass
try:
return cls(data)
except ValueError:
pass
raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}')
def get_level() -> LogLevel:
return LogLevel.parse(logging.root.level)
def set_level(level: LogLevel | str) -> None:
logging.root.setLevel(LogLevel.parse(level))
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LogLevel['VERBOSE']):
return
logging.log(LogLevel['VERBOSE'], message, *args, **kwargs)
LOG_LEVELS: dict[str, int] = {
'DEBUG': logging.DEBUG,
'VERBOSE': 15,
'INFO': logging.INFO,
'WARNING': logging.WARNING,
'ERROR': logging.ERROR,
'CRITICAL': logging.CRITICAL
}
debug: Callable = logging.debug
@ -70,7 +27,14 @@ error: Callable = logging.error
critical: Callable = logging.critical
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']):
return
logging.log(LOG_LEVELS['VERBOSE'], message, *args, **kwargs)
logging.addLevelName(LOG_LEVELS['VERBOSE'], 'VERBOSE')
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
try:
@ -81,11 +45,11 @@ except KeyError:
try:
log_level = LogLevel[env_log_level]
log_level = LOG_LEVELS[env_log_level]
except KeyError:
print('Invalid log level:', env_log_level)
log_level = LogLevel['INFO']
logging.warning('Invalid log level: %s', env_log_level)
log_level = logging.INFO
handlers = [logging.StreamHandler()]

View file

@ -6,49 +6,22 @@ import click
import platform
import typing
from aputils.signer import Signer
from pathlib import Path
from shutil import copyfile
from urllib.parse import urlparse
from . import __version__
from . import misc, __version__
from . import http_client as http
from . import logger as logging
from .application import Application
from .compat import RelayConfig, RelayDatabase
from .database import get_database
from .database.connection import RELAY_SOFTWARE
from .misc import IS_DOCKER, Message, check_open_port
from .config import RELAY_SOFTWARE
if typing.TYPE_CHECKING:
from tinysql import Row
from typing import Any
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
CONFIG_IGNORE = (
'schema-version',
'private-key'
)
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
app = None
CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@ -56,10 +29,11 @@ SOFTWARE = (
@click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context
def cli(ctx: click.Context, config: str) -> None:
ctx.obj = Application(config)
global app
app = Application(config)
if not ctx.invoked_subcommand:
if ctx.obj.config.domain.endswith('example.com'):
if app.config.host.endswith('example.com'):
cli_setup.callback()
else:
@ -67,92 +41,46 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup')
@click.pass_context
def cli_setup(ctx: click.Context) -> None:
'Generate a new config and create the database'
def cli_setup() -> None:
'Generate a new config'
while True:
ctx.obj.config.domain = click.prompt(
app.config.host = click.prompt(
'What domain will the relay be hosted on?',
default = ctx.obj.config.domain
default = app.config.host
)
if not ctx.obj.config.domain.endswith('example.com'):
if not app.config.host.endswith('example.com'):
break
click.echo('The domain must not end with "example.com"')
click.echo('The domain must not be example.com')
if not IS_DOCKER:
ctx.obj.config.listen = click.prompt(
if not app.config.is_docker:
app.config.listen = click.prompt(
'Which address should the relay listen on?',
default = ctx.obj.config.listen
default = app.config.listen
)
ctx.obj.config.port = click.prompt(
'What TCP port should the relay listen on?',
default = ctx.obj.config.port,
type = int
)
while True:
app.config.port = click.prompt(
'What TCP port should the relay listen on?',
default = app.config.port,
type = int
)
ctx.obj.config.db_type = click.prompt(
'Which database backend will be used?',
default = ctx.obj.config.db_type,
type = click.Choice(['postgres', 'sqlite'], case_sensitive = False)
)
break
if ctx.obj.config.db_type == 'sqlite':
ctx.obj.config.sq_path = click.prompt(
'Where should the database be stored?',
default = ctx.obj.config.sq_path
)
app.config.save()
elif ctx.obj.config.db_type == 'postgres':
ctx.obj.config.pg_name = click.prompt(
'What is the name of the database?',
default = ctx.obj.config.pg_name
)
ctx.obj.config.pg_host = click.prompt(
'What IP address or hostname does the server listen on?',
default = ctx.obj.config.pg_host
)
ctx.obj.config.pg_port = click.prompt(
'What port does the server listen on?',
default = ctx.obj.config.pg_port,
type = int
)
ctx.obj.config.pg_user = click.prompt(
'Which user will authenticate with the server?',
default = ctx.obj.config.pg_user
)
ctx.obj.config.pg_pass = click.prompt(
'User password: ',
hide_input = True
) or None
ctx.obj.config.save()
config = {
'private-key': Signer.new('n/a').export()
}
with ctx.obj.database.connection() as conn:
for key, value in config.items():
conn.put_config(key, value)
if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
if not app.config.is_docker and click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback()
@cli.command('run')
@click.pass_context
def cli_run(ctx: click.Context) -> None:
def cli_run() -> None:
'Run the relay'
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
if app.config.host.endswith('example.com'):
click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
)
@ -176,144 +104,40 @@ def cli_run(ctx: click.Context) -> None:
click.echo(pip_command)
return
if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port):
click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
if not misc.check_open_port(app.config.listen, app.config.port):
click.echo(f'Error: A server is already running on port {app.config.port}')
return
ctx.obj.run()
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@click.pass_context
def cli_convert(ctx: click.Context, old_config: str) -> None:
'Convert an old config and jsonld database to the new format.'
old_config = Path(old_config).expanduser().resolve() if old_config else ctx.obj.config.path
backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml')
if str(old_config) == str(ctx.obj.config.path) and not backup.exists():
logging.info('Created backup config @ %s', backup)
copyfile(ctx.obj.config.path, backup)
config = RelayConfig(old_config)
config.load()
database = RelayDatabase(config)
database.load()
ctx.obj.config.set('listen', config['listen'])
ctx.obj.config.set('port', config['port'])
ctx.obj.config.set('workers', config['workers'])
ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3'))
ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save()
with get_database(ctx.obj.config) as db:
with db.connection() as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar(
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay'
elif inbox['software'] == 'mastodon':
actor = f'https://{inbox["domain"]}/actor'
else:
actor = None
conn.put_inbox(
inbox['domain'],
inbox['inbox'],
actor = actor,
followid = inbox['followid'],
software = inbox['software']
)
with click.progressbar(
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
) as banned_software:
for software in banned_software:
conn.put_software_ban(
software,
reason = 'relay' if software in RELAY_SOFTWARE else None
)
with click.progressbar(
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
) as banned_software:
for domain in banned_software:
conn.put_domain_ban(domain)
with click.progressbar(
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
) as whitelist:
for instance in whitelist:
conn.put_domain_whitelist(instance)
click.echo('Finished converting old config and database :3')
@cli.command('edit-config')
@click.option('--editor', '-e', help = 'Text editor to use')
@click.pass_context
def cli_editconfig(ctx: click.Context, editor: str) -> None:
'Edit the config file'
click.edit(
editor = editor,
filename = str(ctx.obj.config.path)
)
app.run()
@cli.group('config')
def cli_config() -> None:
'Manage the relay settings stored in the database'
'Manage the relay config'
@cli_config.command('list')
@click.pass_context
def cli_config_list(ctx: click.Context) -> None:
def cli_config_list() -> None:
'List the current relay config'
click.echo('Relay Config:')
with ctx.obj.database.connection() as conn:
for key, value in conn.get_config_all().items():
if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20)
click.echo(f'- {key} {value}')
for key, value in app.config.items():
if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20)
click.echo(f'- {key} {value}')
@cli_config.command('set')
@click.argument('key')
@click.argument('value')
@click.pass_context
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
def cli_config_set(key: str, value: Any) -> None:
'Set a config value'
with ctx.obj.database.connection() as conn:
new_value = conn.put_config(key, value)
app.config[key] = value
app.config.save()
print(f'{key}: {repr(new_value)}')
print(f'{key}: {app.config[key]}')
@cli.group('inbox')
@ -322,150 +146,127 @@ def cli_inbox() -> None:
@cli_inbox.command('list')
@click.pass_context
def cli_inbox_list(ctx: click.Context) -> None:
def cli_inbox_list() -> None:
'List the connected instances or relays'
click.echo('Connected to the following instances or relays:')
with ctx.obj.database.connection() as conn:
for inbox in conn.execute('SELECT * FROM inboxes'):
click.echo(f'- {inbox["inbox"]}')
for inbox in app.database.inboxes:
click.echo(f'- {inbox}')
@cli_inbox.command('follow')
@click.argument('actor')
@click.pass_context
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_follow(actor: str) -> None:
'Follow an actor (Relay must be running)'
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
if app.config.is_banned(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
if not actor.startswith('http'):
domain = actor
actor = f'https://{actor}/actor'
else:
domain = urlparse(actor).hostname
try:
inbox_data = app.database['relay-list'][domain]
inbox = inbox_data['inbox']
except KeyError:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
if not actor_data:
click.echo(f'Failed to fetch actor: {actor}')
return
if (inbox_data := conn.get_inbox(actor)):
inbox = inbox_data['inbox']
inbox = actor_data.shared_inbox
else:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
click.echo(f'Failed to fetch actor: {actor}')
return
inbox = actor_data.shared_inbox
message = Message.new_follow(
host = ctx.obj.config.domain,
message = misc.Message.new_follow(
host = app.config.host,
actor = actor
)
asyncio.run(http.post(inbox, message))
asyncio.run(http.post(app.database, inbox, message))
click.echo(f'Sent follow message to actor: {actor}')
@cli_inbox.command('unfollow')
@click.argument('actor')
@click.pass_context
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(actor: str) -> None:
'Unfollow an actor (Relay must be running)'
inbox_data: Row = None
if not actor.startswith('http'):
domain = actor
actor = f'https://{actor}/actor'
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
else:
domain = urlparse(actor).hostname
if (inbox_data := conn.get_inbox(actor)):
inbox = inbox_data['inbox']
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
follow = inbox_data['followid']
)
try:
inbox_data = app.database['relay-list'][domain]
inbox = inbox_data['inbox']
message = misc.Message.new_unfollow(
host = app.config.host,
actor = actor,
follow = inbox_data['followid']
)
else:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
except KeyError:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
inbox = actor_data.shared_inbox
message = misc.Message.new_unfollow(
host = app.config.host,
actor = actor,
follow = {
'type': 'Follow',
'object': actor,
'actor': f'https://{app.config.host}/actor'
}
)
actor_data = asyncio.run(http.get(actor, sign_headers = True))
inbox = actor_data.shared_inbox
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
follow = {
'type': 'Follow',
'object': actor,
'actor': f'https://{ctx.obj.config.domain}/actor'
}
)
asyncio.run(http.post(inbox, message))
asyncio.run(http.post(app.database, inbox, message))
click.echo(f'Sent unfollow message to: {actor}')
@cli_inbox.command('add')
@click.argument('inbox')
@click.option('--actor', '-a', help = 'Actor url for the inbox')
@click.option('--followid', '-f', help = 'Url for the follow activity')
@click.option('--software', '-s',
type = click.Choice(SOFTWARE),
help = 'Nodeinfo software name of the instance'
) # noqa: E124
@click.pass_context
def cli_inbox_add(
ctx: click.Context,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> None:
def cli_inbox_add(inbox: str) -> None:
'Add an inbox to the database'
if not inbox.startswith('http'):
domain = inbox
inbox = f'https://{inbox}/inbox'
else:
domain = urlparse(inbox).netloc
if app.config.is_banned(inbox):
click.echo(f'Error: Refusing to add banned inbox: {inbox}')
return
if not software:
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(domain))):
software = nodeinfo.sw_name
if app.database.get_inbox(inbox):
click.echo(f'Error: Inbox already in database: {inbox}')
return
if not actor and software:
try:
actor = ACTOR_FORMATS[software].format(domain = domain)
except KeyError:
pass
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Refusing to add banned inbox: {inbox}')
return
if conn.get_inbox(inbox):
click.echo(f'Error: Inbox already in database: {inbox}')
return
conn.put_inbox(domain, inbox, actor, followid, software)
app.database.add_inbox(inbox)
app.database.save()
click.echo(f'Added inbox to the database: {inbox}')
@cli_inbox.command('remove')
@click.argument('inbox')
@click.pass_context
def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
def cli_inbox_remove(inbox: str) -> None:
'Remove an inbox from the database'
with ctx.obj.database.connection() as conn:
if not conn.del_inbox(inbox):
click.echo(f'Inbox not in database: {inbox}')
return
try:
dbinbox = app.database.get_inbox(inbox, fail=True)
except KeyError:
click.echo(f'Error: Inbox does not exist: {inbox}')
return
app.database.del_inbox(dbinbox['domain'])
app.database.save()
click.echo(f'Removed inbox from the database: {inbox}')
@ -476,76 +277,47 @@ def cli_instance() -> None:
@cli_instance.command('list')
@click.pass_context
def cli_instance_list(ctx: click.Context) -> None:
def cli_instance_list() -> None:
'List all banned instances'
click.echo('Banned domains:')
click.echo('Banned instances or relays:')
with ctx.obj.database.connection() as conn:
for instance in conn.execute('SELECT * FROM domain_bans'):
if instance['reason']:
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
else:
click.echo(f'- {instance["domain"]}')
for domain in app.config.blocked_instances:
click.echo(f'- {domain}')
@cli_instance.command('ban')
@click.argument('domain')
@click.option('--reason', '-r', help = 'Public note about why the domain is banned')
@click.option('--note', '-n', help = 'Internal note that will only be seen by admins and mods')
@click.pass_context
def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None:
@click.argument('target')
def cli_instance_ban(target: str) -> None:
'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Domain already banned: {domain}')
return
if target.startswith('http'):
target = urlparse(target).hostname
conn.put_domain_ban(domain, reason, note)
conn.del_inbox(domain)
click.echo(f'Banned instance: {domain}')
if app.config.ban_instance(target):
app.config.save()
if app.database.del_inbox(target):
app.database.save()
click.echo(f'Banned instance: {target}')
return
click.echo(f'Instance already banned: {target}')
@cli_instance.command('unban')
@click.argument('domain')
@click.pass_context
def cli_instance_unban(ctx: click.Context, domain: str) -> None:
@click.argument('target')
def cli_instance_unban(target: str) -> None:
'Unban an instance'
with ctx.obj.database.connection() as conn:
if not conn.del_domain_ban(domain):
click.echo(f'Instance wasn\'t banned: {domain}')
return
if app.config.unban_instance(target):
app.config.save()
click.echo(f'Unbanned instance: {domain}')
click.echo(f'Unbanned instance: {target}')
return
@cli_instance.command('update')
@click.argument('domain')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a domain ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_domain_ban(domain, reason, note)):
click.echo(f'Failed to update domain ban: {domain}')
return
click.echo(f'Updated domain ban: {domain}')
if row['reason']:
click.echo(f'- {row["domain"]} ({row["reason"]})')
else:
click.echo(f'- {row["domain"]}')
click.echo(f'Instance wasn\'t banned: {target}')
@cli.group('software')
@ -554,127 +326,79 @@ def cli_software() -> None:
@cli_software.command('list')
@click.pass_context
def cli_software_list(ctx: click.Context) -> None:
def cli_software_list() -> None:
'List all banned software'
click.echo('Banned software:')
with ctx.obj.database.connection() as conn:
for software in conn.execute('SELECT * FROM software_bans'):
if software['reason']:
click.echo(f'- {software["name"]} ({software["reason"]})')
else:
click.echo(f'- {software["name"]}')
for software in app.config.blocked_software:
click.echo(f'- {software}')
@cli_software.command('ban')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option(
'--fetch-nodeinfo', '-f',
is_flag = True,
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.pass_context
def cli_software_ban(ctx: click.Context,
name: str,
reason: str,
note: str,
fetch_nodeinfo: bool) -> None:
@click.argument('name')
def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays'
with ctx.obj.database.connection() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
if conn.get_software_ban(software):
click.echo(f'Relay already banned: {software}')
continue
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
app.config.ban_software(software)
conn.put_software_ban(software, reason or 'relay', note)
app.config.save()
click.echo('Banned all relay software')
return
click.echo('Banned all relay software')
return
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))):
click.echo(f'Failed to fetch software name from domain: {name}')
return
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
name = nodeinfo.sw_name
if conn.get_software_ban(name):
click.echo(f'Software already banned: {name}')
return
if not conn.put_software_ban(name, reason, note):
click.echo(f'Failed to ban software: {name}')
return
name = nodeinfo.sw_name
if app.config.ban_software(name):
app.config.save()
click.echo(f'Banned software: {name}')
return
click.echo(f'Software already banned: {name}')
@cli_software.command('unban')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option(
'--fetch-nodeinfo', '-f',
is_flag = True,
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
'--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
)
@click.pass_context
def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None:
@click.argument('name')
def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays'
with ctx.obj.database.connection() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
if not conn.del_software_ban(software):
click.echo(f'Relay was not banned: {software}')
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
app.config.unban_software(software)
click.echo('Unbanned all relay software')
return
app.config.save()
click.echo('Unbanned all relay software')
return
if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))):
click.echo(f'Failed to fetch software name from domain: {name}')
return
if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
name = nodeinfo.sw_name
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
if not conn.del_software_ban(name):
click.echo(f'Software was not banned: {name}')
return
name = nodeinfo.sw_name
if app.config.unban_software(name):
app.config.save()
click.echo(f'Unbanned software: {name}')
return
@cli_software.command('update')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a software ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_software_ban(name, reason, note)):
click.echo(f'Failed to update software ban: {name}')
return
click.echo(f'Updated software ban: {name}')
if row['reason']:
click.echo(f'- {row["name"]} ({row["reason"]})')
else:
click.echo(f'- {row["name"]}')
click.echo(f'Software wasn\'t banned: {name}')
@cli.group('whitelist')
@ -683,64 +407,52 @@ def cli_whitelist() -> None:
@cli_whitelist.command('list')
@click.pass_context
def cli_whitelist_list(ctx: click.Context) -> None:
def cli_whitelist_list() -> None:
'List all the instances in the whitelist'
click.echo('Current whitelisted domains:')
click.echo('Current whitelisted domains')
with ctx.obj.database.connection() as conn:
for domain in conn.execute('SELECT * FROM whitelist'):
click.echo(f'- {domain["domain"]}')
for domain in app.config.whitelist:
click.echo(f'- {domain}')
@cli_whitelist.command('add')
@click.argument('domain')
@click.pass_context
def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
'Add a domain to the whitelist'
@click.argument('instance')
def cli_whitelist_add(instance: str) -> None:
'Add an instance to the whitelist'
with ctx.obj.database.connection() as conn:
if conn.get_domain_whitelist(domain):
click.echo(f'Instance already in the whitelist: {domain}')
return
if not app.config.add_whitelist(instance):
click.echo(f'Instance already in the whitelist: {instance}')
return
conn.put_domain_whitelist(domain)
click.echo(f'Instance added to the whitelist: {domain}')
app.config.save()
click.echo(f'Instance added to the whitelist: {instance}')
@cli_whitelist.command('remove')
@click.argument('domain')
@click.pass_context
def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
@click.argument('instance')
def cli_whitelist_remove(instance: str) -> None:
'Remove an instance from the whitelist'
with ctx.obj.database.connection() as conn:
if not conn.del_domain_whitelist(domain):
click.echo(f'Domain not in the whitelist: {domain}')
return
if not app.config.del_whitelist(instance):
click.echo(f'Instance not in the whitelist: {instance}')
return
if conn.get_config('whitelist-enabled'):
if conn.del_inbox(domain):
click.echo(f'Removed inbox for domain: {domain}')
app.config.save()
click.echo(f'Removed domain from the whitelist: {domain}')
if app.config.whitelist_enabled:
if app.database.del_inbox(instance):
app.database.save()
click.echo(f'Removed instance from the whitelist: {instance}')
@cli_whitelist.command('import')
@click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None:
def cli_whitelist_import() -> None:
'Add all current inboxes to the whitelist'
with ctx.obj.database.connection() as conn:
for inbox in conn.execute('SELECT * FROM inboxes').all():
if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
continue
conn.put_domain_whitelist(inbox['domain'])
click.echo('Imported whitelist from inboxes')
for domain in app.database.hostnames:
cli_whitelist_add.callback(domain)
def main() -> None:

View file

@ -1,28 +1,32 @@
from __future__ import annotations
import json
import os
import socket
import traceback
import typing
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Response as AiohttpResponse
from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse
from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.message import Message as ApMessage
from functools import cached_property
from json.decoder import JSONDecodeError
from uuid import uuid4
from . import logger as logging
if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator
from typing import Any
from typing import Any, Coroutine, Generator, Optional, Type
from aputils.signer import Signer
from .application import Application
from .config import Config
from .database import Database
from .config import RelayConfig
from .database import RelayDatabase
from .http_client import HttpClient
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = {
'activity': 'application/activity+json',
'html': 'text/html',
@ -38,10 +42,10 @@ NODEINFO_NS = {
def boolean(value: Any) -> bool:
if isinstance(value, str):
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}:
if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']:
return True
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}:
if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']:
return False
raise TypeError(f'Cannot parse string "{value}" as a boolean')
@ -73,21 +77,99 @@ def check_open_port(host: str, port: int) -> bool:
return False
def get_app() -> Application:
from .application import Application # pylint: disable=import-outside-toplevel
class DotDict(dict):
def __init__(self, _data: dict[str, Any], **kwargs: Any):
dict.__init__(self)
if not Application.DEFAULT:
raise ValueError('No default application set')
self.update(_data, **kwargs)
return Application.DEFAULT
def __getattr__(self, key: str) -> str:
try:
return self[key]
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith('_'):
super().__setattr__(key, value)
else:
self[key] = value
def __setitem__(self, key: str, value: Any) -> None:
if type(value) is dict: # pylint: disable=unidiomatic-typecheck
value = DotDict(value)
super().__setitem__(key, value)
def __delattr__(self, key: str) -> None:
try:
dict.__delitem__(self, key)
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
@classmethod
def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]:
if not data:
raise JSONDecodeError('Empty body', data, 1)
try:
return cls(json.loads(data))
except ValueError:
raise JSONDecodeError('Invalid body', data, 1) from None
@classmethod
def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]:
data = cls({})
for chunk in sig.strip().split(','):
key, value = chunk.split('=', 1)
value = value.strip('\"')
if key == 'headers':
value = value.split()
data[key.lower()] = value
return data
def to_json(self, indent: Optional[int | str] = None) -> str:
return json.dumps(self, indent=indent)
def update(self, _data: dict[str, Any], **kwargs: Any) -> None:
if isinstance(_data, dict):
for key, value in _data.items():
self[key] = value
elif isinstance(_data, (list, tuple, set)):
for key, value in _data:
self[key] = value
for key, value in kwargs.items():
self[key] = value
class Message(ApMessage):
@classmethod
def new_actor(cls: type[Message], # pylint: disable=arguments-differ
def new_actor(cls: Type[Message], # pylint: disable=arguments-differ
host: str,
pubkey: str,
description: str | None = None) -> Message:
description: Optional[str] = None) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
@ -99,7 +181,7 @@ class Message(ApMessage):
'followers': f'https://{host}/followers',
'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox',
'url': f'https://{host}/',
'url': f'https://{host}/inbox',
'endpoints': {
'sharedInbox': f'https://{host}/inbox'
},
@ -112,7 +194,7 @@ class Message(ApMessage):
@classmethod
def new_announce(cls: type[Message], host: str, obj: str) -> Message:
def new_announce(cls: Type[Message], host: str, obj: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
@ -124,7 +206,7 @@ class Message(ApMessage):
@classmethod
def new_follow(cls: type[Message], host: str, actor: str) -> Message:
def new_follow(cls: Type[Message], host: str, actor: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow',
@ -136,7 +218,7 @@ class Message(ApMessage):
@classmethod
def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message:
def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
@ -148,7 +230,7 @@ class Message(ApMessage):
@classmethod
def new_response(cls: type[Message],
def new_response(cls: Type[Message],
host: str,
actor: str,
followid: str,
@ -180,17 +262,12 @@ class Message(ApMessage):
class Response(AiohttpResponse):
# AiohttpResponse.__len__ method returns 0, so bool(response) always returns False
def __bool__(self) -> bool:
return True
@classmethod
def new(cls: type[Response],
body: str | bytes | dict = '',
status: int = 200,
headers: dict[str, str] | None = None,
ctype: str = 'text') -> Response:
def new(cls: Type[Response],
body: Optional[str | bytes | dict] = '',
status: Optional[int] = 200,
headers: Optional[dict[str, str]] = None,
ctype: Optional[str] = 'text') -> Response:
kwargs = {
'status': status,
@ -211,7 +288,7 @@ class Response(AiohttpResponse):
@classmethod
def new_error(cls: type[Response],
def new_error(cls: Type[Response],
status: int,
body: str | bytes | dict,
ctype: str = 'text') -> Response:
@ -233,11 +310,23 @@ class Response(AiohttpResponse):
class View(AbstractView):
def __await__(self) -> Generator[Response]:
if (self.request.method) not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
def __init__(self, request: AiohttpRequest):
AbstractView.__init__(self, request)
if not (handler := self.handlers.get(self.request.method)):
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
def __await__(self) -> Generator[Response]:
method = self.request.method.upper()
if method not in METHODS:
raise HTTPMethodNotAllowed(method, self.allowed_methods)
if not (handler := self.handlers.get(method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__()
@ -274,10 +363,94 @@ class View(AbstractView):
@property
def config(self) -> Config:
def config(self) -> RelayConfig:
return self.app.config
@property
def database(self) -> Database:
def database(self) -> RelayDatabase:
return self.app.database
# todo: move to views.ActorView
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
self.instance = self.database.get_inbox(self.actor.inbox)
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import tinysql
import typing
from cachetools import LRUCache
@ -9,7 +8,7 @@ from . import logger as logging
from .misc import Message
if typing.TYPE_CHECKING:
from .views import ActorView
from .misc import View
cache = LRUCache(1024)
@ -17,141 +16,128 @@ cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False
# make sure the actor is an application
## make sure the actor is an application
if actor.type != 'Application':
return True
return False
async def handle_relay(view: ActorView) -> None:
async def handle_relay(view: View) -> None:
if view.message.object_id in cache:
logging.verbose('already relayed %s', view.message.object_id)
return
message = Message.new_announce(view.config.domain, view.message.object_id)
message = Message.new_announce(view.config.host, view.message.object_id)
cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
inboxes = view.database.distill_inboxes(view.message)
for inbox in inboxes:
view.app.push_message(inbox, message)
async def handle_forward(view: ActorView) -> None:
async def handle_forward(view: View) -> None:
if view.message.id in cache:
logging.verbose('already forwarded %s', view.message.id)
return
message = Message.new_announce(view.config.domain, view.message)
message = Message.new_announce(view.config.host, view.message)
cache[view.message.id] = message.id
logging.debug('>> forward: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
inboxes = view.database.distill_inboxes(view.message)
for inbox in inboxes:
view.app.push_message(inbox, message)
async def handle_follow(view: ActorView) -> None:
async def handle_follow(view: View) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
with view.database.connection() as conn:
# reject if software used by actor is banned
if conn.get_software_ban(software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id,
software
)
return
## reject if the actor is not an instance actor
if person_check(view.actor, software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
if conn.get_inbox(view.actor.shared_inbox):
data = {'followid': view.message.id}
statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
with conn.query(statement):
pass
else:
conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
view.actor.id,
view.message.id,
software
)
## reject if software used by actor is banned
if view.config.is_banned_software(software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
host = view.config.host,
actor = view.actor.id,
followid = view.message.id,
accept = True
accept = False
)
)
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
view.app.push_message(
view.actor.shared_inbox,
Message.new_follow(
host = view.config.domain,
actor = view.actor.id
)
return logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id,
software
)
## reject if the actor is not an instance actor
if person_check(view.actor, software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
async def handle_undo(view: ActorView) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
await handle_forward(view)
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
with view.database.connection() as conn:
if not conn.del_inbox(view.actor.id):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
view.message.object['id']
view.database.add_inbox(view.actor.shared_inbox, view.message.id, software)
view.database.save()
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
actor = view.actor.id,
followid = view.message.id,
accept = True
)
)
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
view.app.push_message(
view.actor.shared_inbox,
Message.new_follow(
host = view.config.host,
actor = view.actor.id
)
)
async def handle_undo(view: View) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
return await handle_forward(view)
if not view.database.del_inbox(view.actor.domain, view.message.object['id']):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
view.message.object['id']
)
return
view.database.save()
view.app.push_message(
view.actor.shared_inbox,
Message.new_unfollow(
host = view.config.domain,
host = view.config.host,
actor = view.actor.id,
follow = view.message
)
@ -168,7 +154,7 @@ processors = {
}
async def run_processor(view: ActorView) -> None:
async def run_processor(view: View) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@ -178,21 +164,12 @@ async def run_processor(view: ActorView) -> None:
return
if view.instance:
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
if view.instance and not view.instance.get('software'):
nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain'])
if not view.instance['actor']:
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
actor = view.actor.id
)
if nodeinfo:
view.instance['software'] = nodeinfo.sw_name
view.database.save()
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view)

View file

@ -2,11 +2,8 @@ from __future__ import annotations
import asyncio
import subprocess
import traceback
import typing
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo
from pathlib import Path
@ -17,8 +14,7 @@ from .processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from aputils.signer import Signer
from collections.abc import Callable
from typing import Callable
VIEWS = []
@ -75,16 +71,12 @@ def register_route(*paths: str) -> Callable:
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
with self.database.connection() as conn:
config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all()
text = HOME_TEMPLATE.format(
host = self.config.domain,
note = config['note'],
count = len(inboxes),
targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
)
text = HOME_TEMPLATE.format(
host = self.config.host,
note = self.config.note,
count = len(self.database.hostnames),
targets = '<br>'.join(self.database.hostnames)
)
return Response.new(text, ctype='html')
@ -92,137 +84,44 @@ class HomeView(View):
@register_route('/actor', '/inbox')
class ActorView(View):
def __init__(self, request: Request):
View.__init__(self, request)
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
async def get(self, request: Request) -> Response:
data = Message.new_actor(
host = self.config.domain,
pubkey = self.app.signer.pubkey
host = self.config.host,
pubkey = self.database.signer.pubkey
)
return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response:
if response := await self.get_post_data():
response = await self.get_post_data()
if response is not None:
return response
with self.database.connection() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all()
## reject if the actor isn't whitelisted while the whiltelist is enabled
if self.config.whitelist_enabled and not self.config.is_whitelisted(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if the actor isn't whitelisted while the whiltelist is enabled
if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if actor is banned
if self.config.is_banned(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if actor is banned
if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.database.get_inbox(self.actor.domain):
logging.verbose(
'Rejected actor for trying to post while not following: %s',
self.actor.id
)
## reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance:
logging.verbose(
'Rejected actor for trying to post while not following: %s',
self.actor.id
)
return Response.new_error(401, 'access denied', 'json')
return Response.new_error(401, 'access denied', 'json')
logging.debug('>> payload %s', self.message.to_json(4))
logging.debug('>> payload %s', self.message.to_json(4))
asyncio.ensure_future(run_processor(self))
return Response.new(status = 202)
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if not self.actor:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")
asyncio.ensure_future(run_processor(self))
return Response.new(status = 202)
@register_route('/.well-known/webfinger')
@ -234,12 +133,12 @@ class WebfingerView(View):
except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json')
if subject != f'acct:relay@{self.config.domain}':
if subject != f'acct:relay@{self.config.host}':
return Response.new_error(404, 'user not found', 'json')
data = Webfinger.new(
handle = 'relay',
domain = self.config.domain,
domain = self.config.host,
actor = self.config.actor
)
@ -249,17 +148,14 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response:
with self.database.connection() as conn:
inboxes = conn.execute('SELECT * FROM inboxes').all()
data = {
'name': 'activityrelay',
'version': VERSION,
'protocols': ['activitypub'],
'open_regs': not conn.get_config('whitelist-enabled'),
'users': 1,
'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
}
data = {
'name': 'activityrelay',
'version': VERSION,
'protocols': ['activitypub'],
'open_regs': not self.config.whitelist_enabled,
'users': 1,
'metadata': {'peers': self.database.hostnames}
}
if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
@ -270,5 +166,5 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain)
data = WellKnownNodeinfo.new_template(self.config.host)
return Response.new(data, ctype = 'json')

View file

@ -3,4 +3,3 @@ aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0
click>=8.1.2
pyyaml>=6.0
tinysql@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.2a.tar.gz

View file

@ -10,30 +10,26 @@ license_file = LICENSE
classifiers =
Environment :: Console
License :: OSI Approved :: AGPLv3 License
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
project_urls =
Source = https://git.pleroma.social/pleroma/relay
Tracker = https://git.pleroma.social/pleroma/relay/-/issues
[options]
zip_safe = False
packages =
relay
relay.database
include_package_data = true
packages = find:
install_requires = file: requirements.txt
python_requires = >=3.8
[options.extras_require]
dev = file: dev-requirements.txt
[options.package_data]
relay =
data/statements.sql
dev =
flake8 = 3.1.0
pyinstaller = 6.3.0
pylint = 3.0
[options.entry_points]
console_scripts =
@ -41,4 +37,7 @@ console_scripts =
[flake8]
select = F401
extend-ignore = ANN101,ANN204,E128,E251,E261,E266,E301,E303,W191
extend-exclude = docs, test*.py
max-line-length = 100
indent-size = 4