Compare commits

..

No commits in common. "81215a83a4a99d777ffc9d1e1899d863574656e5" and "965ac73c6d84caf0fff7d316b8ae67c4a1473140" have entirely different histories.

26 changed files with 993 additions and 1838 deletions

10
.gitignore vendored
View file

@ -94,7 +94,9 @@ ENV/
# Rope project settings # Rope project settings
.ropeproject .ropeproject
# config and database viera.yaml
*.yaml viera.jsonld
*.jsonld
*.sqlite3 # config file
relay.yaml
relay.jsonld

View file

@ -1 +0,0 @@
include data/statements.sql

View file

@ -1,3 +0,0 @@
flake8 == 7.0.0
pyinstaller == 6.3.0
pylint == 3.0

View file

@ -3,8 +3,11 @@
There are a number of commands to manage your relay's database and config. You can add `--help` to There are a number of commands to manage your relay's database and config. You can add `--help` to
any category or command to get help on that specific option (ex. `activityrelay inbox --help`). any category or command to get help on that specific option (ex. `activityrelay inbox --help`).
Note: `activityrelay` is only available via pip or pipx if `~/.local/bin` is in `$PATH`. If not, Note: Unless specified, it is recommended to run any commands while the relay is shutdown.
use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if installed via pipx.
Note 2: `activityrelay` is only available via pip or pipx if `~/.local/bin` is in `$PATH`. If it
isn't, use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if installed
via pipx
## Run ## Run
@ -21,22 +24,6 @@ Run the setup wizard to configure your relay.
activityrelay setup activityrelay setup
## Convert
Convert the old config and jsonld to the new config and SQL backend. If the old config filename is
not specified, the config will get backed up as `relay.backup.yaml` before converting.
activityrelay convert --old-config relaycfg.yaml
## Edit Config
Open the config file in a text editor. If an editor is not specified with `--editor`, the default
editor will be used.
activityrelay edit-config --editor micro
## Config ## Config
Manage the relay config Manage the relay config
@ -133,7 +120,7 @@ Remove a domain from the whitelist.
### Import ### Import
Add all current inboxes to the whitelist. Add all current inboxes to the whitelist
activityrelay whitelist import activityrelay whitelist import
@ -145,15 +132,15 @@ Manage the instance ban list.
### List ### List
List the currently banned instances. List the currently banned instances
activityrelay instance list activityrelay instance list
### Ban ### Ban
Add an instance to the ban list. If the instance is currently subscribed, it will be removed from Add an instance to the ban list. If the instance is currently subscribed, remove it from the
the inbox list. database.
activityrelay instance ban <domain> activityrelay instance ban <domain>
@ -165,17 +152,10 @@ Remove an instance from the ban list.
activityrelay instance unban <domain> activityrelay instance unban <domain>
### Update
Update the ban reason or note for an instance ban.
activityrelay instance update bad.example.com --reason "the baddest reason"
## Software ## Software
Manage the software ban list. To get the correct name, check the software's nodeinfo endpoint. Manage the software ban list. To get the correct name, check the software's nodeinfo endpoint.
You can find it at `nodeinfo['software']['name']`. You can find it at nodeinfo\['software']\['name'].
### List ### List
@ -206,12 +186,4 @@ name via nodeinfo.
If the name is `RELAYS` (case-sensitive), remove all known relay software names from the list. If the name is `RELAYS` (case-sensitive), remove all known relay software names from the list.
activityrelay software unban [-f/--fetch-nodeinfo] <name, domain, or RELAYS> activityrelay unban [-f/--fetch-nodeinfo] <name, domain, or RELAYS>
### Update
Update the ban reason or note for a software ban. Either `--reason` and/or `--note` must be
specified.
activityrelay software update relay.example.com --reason "begone relay"

View file

@ -2,23 +2,41 @@
## General ## General
### Domain ### DB
Hostname the relay will be hosted on. The path to the database. It contains the relay actor private key and all subscribed
instances. If the path is not absolute, it is relative to the working directory.
domain: relay.example.com db: relay.jsonld
### Listener ### Listener
The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc) The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc)
is running on the same host, it is recommended to change `listen` to `localhost` if the reverse is running on the same host, it is recommended to change `listen` to `localhost`
proxy is on the same host.
listen: 0.0.0.0 listen: 0.0.0.0
port: 8080 port: 8080
### Note
A small blurb to describe your relay instance. This will show up on the relay's home page.
note: "Make a note about your instance here."
### Post Limit
The maximum number of messages to send out at once. For each incoming message, a message will be
sent out to every subscribed instance minus the instance which sent the message. This limit
is to prevent too many outgoing connections from being made, so adjust if necessary.
Note: If the `workers` option is set to anything above 0, this limit will be per worker.
push_limit: 512
### Push Workers ### Push Workers
The relay can be configured to use threads to push messages out. For smaller relays, this isn't The relay can be configured to use threads to push messages out. For smaller relays, this isn't
@ -28,59 +46,60 @@ threads.
workers: 0 workers: 0
### Database type ### JSON GET cache limit
SQL database backend to use. Valid values are `sqlite` or `postgres`. JSON objects (actors, nodeinfo, etc) will get cached when fetched. This will set the max number of
objects to keep in the cache.
database_type: sqlite json_cache: 1024
### Sqlite File Path ## AP
Path to the sqlite database file. If the path is not absolute, it is relative to the config file. Various ActivityPub-related settings
directory.
sqlite_path: relay.jsonld
## Postgresql
In order to use the Postgresql backend, the user and database need to be created first.
sudo -u postgres psql -c "CREATE USER activityrelay"
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
### Database Name
Name of the database to use.
name: activityrelay
### Host ### Host
Hostname, IP address, or unix socket the server is hosted on. The domain your relay will use to identify itself.
host: /var/run/postgresql host: relay.example.com
### Port ### Whitelist Enabled
Port number the server is listening on. If set to `true`, only instances in the whitelist can follow the relay. Any subscribed instances
not in the whitelist will be removed from the inbox list on startup.
port: 5432 whitelist_enabled: false
### Username ### Whitelist
User to use when logging into the server. A list of domains of instances which are allowed to subscribe to your relay.
user: null whitelist:
- bad-instance.example.com
- another-bad-instance.example.com
### Password ### Blocked Instances
Password for the specified user. A list of instances which are unable to follow the instance. If a subscribed instance is added to
the block list, it will be removed from the inbox list on startup.
pass: null blocked_instances:
- bad-instance.example.com
- another-bad-instance.example.com
### Blocked Software
A list of ActivityPub software which cannot follow your relay. This list is empty by default, but
setting this to the below list will block all other relays and prevent relay chains
blocked_software:
- activityrelay
- aoderelay
- social.seattle.wa.us-relay
- unciarelay

View file

@ -3,7 +3,7 @@ Description=ActivityPub Relay
[Service] [Service]
WorkingDirectory=/home/relay/relay WorkingDirectory=/home/relay/relay
ExecStart=/usr/bin/python3 -m relay run ExecStart=/usr/bin/python3 -m relay
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target

View file

@ -6,23 +6,6 @@ build-backend = 'setuptools.build_meta'
[tool.pylint.main] [tool.pylint.main]
jobs = 0 jobs = 0
persistent = true persistent = true
load-plugins = [
"pylint.extensions.code_style",
"pylint.extensions.comparison_placement",
"pylint.extensions.confusing_elif",
"pylint.extensions.for_any_all",
"pylint.extensions.consider_ternary_expression",
"pylint.extensions.bad_builtin",
"pylint.extensions.dict_init_mutate",
"pylint.extensions.check_elif",
"pylint.extensions.empty_comment",
"pylint.extensions.private_import",
"pylint.extensions.redefined_variable_type",
"pylint.extensions.no_self_use",
"pylint.extensions.overlapping_exceptions",
"pylint.extensions.set_membership",
"pylint.extensions.typing"
]
[tool.pylint.design] [tool.pylint.design]
@ -39,7 +22,6 @@ single-line-if-stmt = true
[tool.pylint.messages_control] [tool.pylint.messages_control]
disable = [ disable = [
"fixme",
"broad-exception-caught", "broad-exception-caught",
"cyclic-import", "cyclic-import",
"global-statement", "global-statement",
@ -49,8 +31,7 @@ disable = [
"too-many-public-methods", "too-many-public-methods",
"too-many-return-statements", "too-many-return-statements",
"wrong-import-order", "wrong-import-order",
"wrong-import-position",
"missing-function-docstring", "missing-function-docstring",
"missing-class-docstring", "missing-class-docstring"
"consider-using-namedtuple-or-dataclass",
"confusing-consecutive-elif"
] ]

View file

@ -5,44 +5,40 @@ block_cipher = None
a = Analysis( a = Analysis(
['relay/__main__.py'], ['relay/__main__.py'],
pathex=[], pathex=[],
binaries=[], binaries=[],
datas=[ datas=[],
('relay/data', 'relay/data') hiddenimports=[],
], hookspath=[],
hiddenimports=[], hooksconfig={},
hookspath=[], runtime_hooks=[],
hooksconfig={}, excludes=[],
runtime_hooks=[], win_no_prefer_redirects=False,
excludes=[], win_private_assemblies=False,
win_no_prefer_redirects=False, cipher=block_cipher,
win_private_assemblies=False, noarchive=False,
cipher=block_cipher,
noarchive=False,
) )
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE( exe = EXE(
pyz, pyz,
a.scripts, a.scripts,
a.binaries, a.binaries,
a.zipfiles, a.zipfiles,
a.datas, a.datas,
[], [],
name='activityrelay', name='activityrelay',
icon=None, debug=False,
debug=False, bootloader_ignore_signals=False,
bootloader_ignore_signals=False, strip=False,
strip=False, upx=True,
upx=True, upx_exclude=[],
upx_exclude=[], runtime_tmpdir=None,
runtime_tmpdir=None, console=True,
console=True, disable_windowed_traceback=False,
disable_windowed_traceback=False, argv_emulation=False,
argv_emulation=False, target_arch=None,
target_arch=None, codesign_identity=None,
codesign_identity=None, entitlements_file=None,
entitlements_file=None,
) )

View file

@ -1,35 +1,43 @@
# [string] Domain the relay will be hosted on # this is the path that the object graph will get dumped to (in JSON-LD format),
domain: relay.example.com # you probably shouldn't change it, but you can if you want.
db: relay.jsonld
# [string] Address the relay will listen on # Listener
listen: 0.0.0.0 listen: 0.0.0.0
# [integer] Port the relay will listen on
port: 8080 port: 8080
# [integer] Number of push workers to start (will get removed in a future update) # Note
workers: 8 note: "Make a note about your instance here."
# [string] Database backend to use. Valid values: sqlite, postgres # Number of worker threads to start. If 0, use asyncio futures instead of threads.
database_type: sqlite workers: 0
# [string] Path to the sqlite database file if the sqlite backend is in use # Maximum number of inbox posts to do at once
sqlite_path: relay.sqlite3 # If workers is set to 1 or above, this is the max for each worker
push_limit: 512
# settings for the postgresql backend # The amount of json objects to cache from GET requests
postgres: json_cache: 1024
# [string] hostname or unix socket to connect to ap:
host: /var/run/postgresql # This is used for generating activitypub messages, as well as instructions for
# linking AP identities. It should be an SSL-enabled domain reachable by https.
host: 'relay.example.com'
# [integer] port of the server blocked_instances:
port: 5432 - 'bad-instance.example.com'
- 'another-bad-instance.example.com'
# [string] username to use when logging into the server (default is the current system username) whitelist_enabled: false
user: null
# [string] password of the user whitelist:
pass: null - 'good-instance.example.com'
- 'another.good-instance.example.com'
# [string] name of the database to use # uncomment the lines below to prevent certain activitypub software from posting
name: activityrelay # to the relay (all known relays by default). this uses the software name in nodeinfo
#blocked_software:
#- 'activityrelay'
#- 'aoderelay'
#- 'social.seattle.wa.us-relay'
#- 'unciarelay'

View file

@ -8,41 +8,52 @@ import traceback
import typing import typing
from aiohttp import web from aiohttp import web
from aputils.signer import Signer
from datetime import datetime, timedelta from datetime import datetime, timedelta
from . import logger as logging from . import logger as logging
from .config import Config from .config import RelayConfig
from .database import get_database from .database import RelayDatabase
from .http_client import HttpClient from .http_client import HttpClient
from .misc import check_open_port from .misc import check_open_port
from .views import VIEWS from .views import VIEWS
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Database
from typing import Any from typing import Any
from .misc import Message from .misc import Message
# pylint: disable=unsubscriptable-object # pylint: disable=unsubscriptable-object
class Application(web.Application):
DEFAULT: Application = None
class Application(web.Application):
def __init__(self, cfgpath: str): def __init__(self, cfgpath: str):
web.Application.__init__(self) web.Application.__init__(self)
Application.DEFAULT = self
self['signer'] = None
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['workers'] = [] self['workers'] = []
self['last_worker'] = 0 self['last_worker'] = 0
self['start_time'] = None self['start_time'] = None
self['running'] = False self['running'] = False
self['config'] = RelayConfig(cfgpath)
if not self.config.load():
self.config.save()
if self.config.is_docker:
self.config.update({
'db': '/data/relay.jsonld',
'listen': '0.0.0.0',
'port': 8080
})
self['database'] = RelayDatabase(self.config)
self.database.load()
self['client'] = HttpClient(
database = self.database,
limit = self.config.push_limit,
timeout = self.config.timeout,
cache_size = self.config.json_cache
)
for path, view in VIEWS: for path, view in VIEWS:
self.router.add_view(path, view) self.router.add_view(path, view)
@ -54,29 +65,15 @@ class Application(web.Application):
@property @property
def config(self) -> Config: def config(self) -> RelayConfig:
return self['config'] return self['config']
@property @property
def database(self) -> Database: def database(self) -> RelayDatabase:
return self['database'] return self['database']
@property
def signer(self) -> Signer:
return self['signer']
@signer.setter
def signer(self, value: Signer | str) -> None:
if isinstance(value, Signer):
self['signer'] = value
return
self['signer'] = Signer(value, self.config.keyid)
@property @property
def uptime(self) -> timedelta: def uptime(self) -> timedelta:
if not self['start_time']: if not self['start_time']:
@ -121,7 +118,7 @@ class Application(web.Application):
logging.info( logging.info(
'Starting webserver at %s (%s:%i)', 'Starting webserver at %s (%s:%i)',
self.config.domain, self.config.host,
self.config.listen, self.config.listen,
self.config.port self.config.port
) )
@ -182,7 +179,12 @@ class PushWorker(threading.Thread):
async def handle_queue(self) -> None: async def handle_queue(self) -> None:
self.client = HttpClient() self.client = HttpClient(
database = self.app.database,
limit = self.app.config.push_limit,
timeout = self.app.config.timeout,
cache_size = self.app.config.json_cache
)
while self.app['running']: while self.app['running']:
try: try:

View file

@ -1,76 +1,76 @@
from __future__ import annotations from __future__ import annotations
import getpass
import os import os
import typing import typing
import yaml import yaml
from functools import cached_property
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse
from .misc import IS_DOCKER from .misc import DotDict, boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any from typing import Any
from .database import RelayDatabase
DEFAULTS: dict[str, Any] = { RELAY_SOFTWARE = [
'listen': '0.0.0.0', 'activityrelay', # https://git.pleroma.social/pleroma/relay
'port': 8080, 'aoderelay', # https://git.asonix.dog/asonix/relay
'domain': 'relay.example.com', 'feditools-relay' # https://git.ptzo.gdn/feditools/relay
'workers': len(os.sched_getaffinity(0)), ]
'db_type': 'sqlite',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay'
}
if IS_DOCKER: APKEYS = [
DEFAULTS['sq_path'] = '/data/relay.jsonld' 'host',
'whitelist_enabled',
'blocked_software',
'blocked_instances',
'whitelist'
]
class Config: class RelayConfig(DotDict):
def __init__(self, path: str, load: bool = False): __slots__ = ('path', )
self.path = Path(path).expanduser().resolve()
self.listen = None def __init__(self, path: str | Path):
self.port = None DotDict.__init__(self, {})
self.domain = None
self.workers = None
self.db_type = None
self.sq_path = None
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
if load: if self.is_docker:
try: path = '/data/config.yaml'
self.load()
except FileNotFoundError: self._path = Path(path).expanduser().resolve()
self.save() self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
@property @property
def sqlite_path(self) -> Path: def db(self) -> RelayDatabase:
if not os.path.isabs(self.sq_path): return Path(self['db']).expanduser().resolve()
return self.path.parent.joinpath(self.sq_path).resolve()
return Path(self.sq_path).expanduser().resolve()
@property @property
def actor(self) -> str: def actor(self) -> str:
return f'https://{self.domain}/actor' return f'https://{self.host}/actor'
@property @property
def inbox(self) -> str: def inbox(self) -> str:
return f'https://{self.domain}/inbox' return f'https://{self.host}/inbox'
@property @property
@ -78,7 +78,115 @@ class Config:
return f'{self.actor}#main-key' return f'{self.actor}#main-key'
def load(self) -> None: @cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def ban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_banned(instance):
return False
self.blocked_instances.append(instance)
return True
def unban_instance(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.blocked_instances.remove(instance)
return True
except ValueError:
return False
def ban_software(self, software: str) -> bool:
if self.is_banned_software(software):
return False
self.blocked_software.append(software)
return True
def unban_software(self, software: str) -> bool:
try:
self.blocked_software.remove(software)
return True
except ValueError:
return False
def add_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
if self.is_whitelisted(instance):
return False
self.whitelist.append(instance)
return True
def del_whitelist(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
try:
self.whitelist.remove(instance)
return True
except ValueError:
return False
def is_banned(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.blocked_instances
def is_banned_software(self, software: str) -> bool:
if not software:
return False
return software.lower() in self.blocked_software
def is_whitelisted(self, instance: str) -> bool:
if instance.startswith('http'):
instance = urlparse(instance).hostname
return instance in self.whitelist
def load(self) -> bool:
self.reset() self.reset()
options = {} options = {}
@ -89,69 +197,50 @@ class Config:
except AttributeError: except AttributeError:
pass pass
with self.path.open('r', encoding = 'UTF-8') as fd: try:
config = yaml.load(fd, **options) with self._path.open('r', encoding = 'UTF-8') as fd:
pgcfg = config.get('postgresql', {}) config = yaml.load(fd, **options)
except FileNotFoundError:
return False
if not config: if not config:
raise ValueError('Config is empty') return False
if IS_DOCKER: for key, value in config.items():
self.listen = '0.0.0.0' if key in ['ap']:
self.port = 8080 for k, v in value.items():
self.sq_path = '/data/relay.jsonld' if k not in self:
continue
else: self[k] = v
self.set('listen', config.get('listen', DEFAULTS['listen']))
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
for key in DEFAULTS:
if not key.startswith('pg'):
continue continue
try: if key not in self:
self.set(key, pgcfg[key[3:]])
except KeyError:
continue continue
self[key] = value
def reset(self) -> None: if self.host.endswith('example.com'):
for key, value in DEFAULTS.items(): return False
setattr(self, key, value)
return True
def save(self) -> None: def save(self) -> None:
self.path.parent.mkdir(exist_ok = True, parents = True)
config = { config = {
# just turning config.db into a string is good enough for now
'db': str(self.db),
'listen': self.listen, 'listen': self.listen,
'port': self.port, 'port': self.port,
'domain': self.domain, 'note': self.note,
'database_type': self.db_type, 'push_limit': self.push_limit,
'sqlite_path': self.sq_path, 'workers': self.workers,
'postgres': { 'json_cache': self.json_cache,
'host': self.pg_host, 'timeout': self.timeout,
'port': self.pg_port, 'ap': {key: self[key] for key in APKEYS}
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
}
} }
with self.path.open('w', encoding = 'utf-8') as fd: with self._path.open('w', encoding = 'utf-8') as fd:
yaml.dump(config, fd, sort_keys = False) yaml.dump(config, fd, sort_keys=False)
def set(self, key: str, value: Any) -> None:
if key not in DEFAULTS:
raise KeyError(key)
if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
value = int(value)
setattr(self, key, value)

View file

@ -1,79 +0,0 @@
-- name: get-config
SELECT * FROM config WHERE key = :key
-- name: get-config-all
SELECT * FROM config
-- name: put-config
INSERT INTO config (key, value, type)
VALUES (:key, :value, :type)
ON CONFLICT (key) DO UPDATE SET value = :value
RETURNING *
-- name: del-config
DELETE FROM config
WHERE key = :key
-- name: get-inbox
SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value
-- name: put-inbox
INSERT INTO inboxes (domain, actor, inbox, followid, software, created)
VALUES (:domain, :actor, :inbox, :followid, :software, :created)
ON CONFLICT (domain) DO UPDATE SET followid = :followid
RETURNING *
-- name: del-inbox
DELETE FROM inboxes
WHERE domain = :value or inbox = :value or actor = :value
-- name: get-software-ban
SELECT * FROM software_bans WHERE name = :name
-- name: put-software-ban
INSERT INTO software_bans (name, reason, note, created)
VALUES (:name, :reason, :note, :created)
RETURNING *
-- name: del-software-ban
DELETE FROM software_bans
WHERE name = :name
-- name: get-domain-ban
SELECT * FROM domain_bans WHERE domain = :domain
-- name: put-domain-ban
INSERT INTO domain_bans (domain, reason, note, created)
VALUES (:domain, :reason, :note, :created)
RETURNING *
-- name: del-domain-ban
DELETE FROM domain_bans
WHERE domain = :domain
-- name: get-domain-whitelist
SELECT * FROM whitelist WHERE domain = :domain
-- name: put-domain-whitelist
INSERT INTO whitelist (domain, created)
VALUES (:domain, :created)
RETURNING *
-- name: del-domain-whitelist
DELETE FROM whitelist
WHERE domain = :domain

View file

@ -1,129 +1,17 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
import typing import typing
import yaml
from functools import cached_property from aputils.signer import Signer
from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
from . import logger as logging from . import logger as logging
from .misc import Message, boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator from typing import Iterator, Optional
from typing import Any from .config import RelayConfig
from .misc import Message
# pylint: disable=duplicate-code
class RelayConfig(dict):
def __init__(self, path: str):
dict.__init__(self, {})
if self.is_docker:
path = '/data/config.yaml'
self._path = Path(path).expanduser().resolve()
self.reset()
def __setitem__(self, key: str, value: Any) -> None:
if key in {'blocked_instances', 'blocked_software', 'whitelist'}:
assert isinstance(value, (list, set, tuple))
elif key in {'port', 'workers', 'json_cache', 'timeout'}:
if not isinstance(value, int):
value = int(value)
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
super().__setitem__(key, value)
@property
def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve()
@property
def actor(self) -> str:
return f'https://{self["host"]}/actor'
@property
def inbox(self) -> str:
return f'https://{self["host"]}/inbox'
@property
def keyid(self) -> str:
return f'{self.actor}#main-key'
@cached_property
def is_docker(self) -> bool:
return bool(os.environ.get('DOCKER_RUNNING'))
def reset(self) -> None:
self.clear()
self.update({
'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
'listen': '0.0.0.0',
'port': 8080,
'note': 'Make a note about your instance here.',
'push_limit': 512,
'json_cache': 1024,
'timeout': 10,
'workers': 0,
'host': 'relay.example.com',
'whitelist_enabled': False,
'blocked_software': [],
'blocked_instances': [],
'whitelist': []
})
def load(self) -> None:
self.reset()
options = {}
try:
options['Loader'] = yaml.FullLoader
except AttributeError:
pass
try:
with self._path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
except FileNotFoundError:
return
if not config:
return
for key, value in config.items():
if key == 'ap':
for k, v in value.items():
if k not in self:
continue
self[k] = v
continue
if key not in self:
continue
self[key] = value
class RelayDatabase(dict): class RelayDatabase(dict):
@ -149,7 +37,9 @@ class RelayDatabase(dict):
return tuple(data['inbox'] for data in self['relay-list'].values()) return tuple(data['inbox'] for data in self['relay-list'].values())
def load(self) -> None: def load(self) -> bool:
new_db = True
try: try:
with self.config.db.open() as fd: with self.config.db.open() as fd:
data = json.load(fd) data = json.load(fd)
@ -175,9 +65,17 @@ class RelayDatabase(dict):
self['relay-list'] = data.get('relay-list', {}) self['relay-list'] = data.get('relay-list', {})
for domain, instance in self['relay-list'].items(): for domain, instance in self['relay-list'].items():
if self.config.is_banned(domain) or \
(self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
self.del_inbox(domain)
continue
if not instance.get('domain'): if not instance.get('domain'):
instance['domain'] = domain instance['domain'] = domain
new_db = False
except FileNotFoundError: except FileNotFoundError:
pass pass
@ -185,13 +83,24 @@ class RelayDatabase(dict):
if self.config.db.stat().st_size > 0: if self.config.db.stat().st_size > 0:
raise e from None raise e from None
if not self['private-key']:
logging.info('No actor keys present, generating 4096-bit RSA keypair.')
self.signer = Signer.new(self.config.keyid, size=4096)
self['private-key'] = self.signer.export()
else:
self.signer = Signer(self['private-key'], self.config.keyid)
self.save()
return not new_db
def save(self) -> None: def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd: with self.config.db.open('w', encoding = 'UTF-8') as fd:
json.dump(self, fd, indent=4) json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None: def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None:
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).hostname domain = urlparse(domain).hostname
@ -206,13 +115,14 @@ class RelayDatabase(dict):
def add_inbox(self, def add_inbox(self,
inbox: str, inbox: str,
followid: str | None = None, followid: Optional[str] = None,
software: str | None = None) -> dict[str, str]: software: Optional[str] = None) -> dict[str, str]:
assert inbox.startswith('https'), 'Inbox must be a url' assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname domain = urlparse(inbox).hostname
instance = self.get_inbox(domain)
if (instance := self.get_inbox(domain)): if instance:
if followid: if followid:
instance['followid'] = followid instance['followid'] = followid
@ -234,10 +144,12 @@ class RelayDatabase(dict):
def del_inbox(self, def del_inbox(self,
domain: str, domain: str,
followid: str = None, followid: Optional[str] = None,
fail: bool = False) -> bool: fail: Optional[bool] = False) -> bool:
if not (data := self.get_inbox(domain, fail=False)): data = self.get_inbox(domain, fail=False)
if not data:
if fail: if fail:
raise KeyError(domain) raise KeyError(domain)

View file

@ -1,65 +0,0 @@
from __future__ import annotations
import tinysql
import typing
from importlib.resources import files as pkgfiles
from .config import get_default_value
from .connection import Connection
from .schema import VERSIONS, migrate_0
from .. import logger as logging
if typing.TYPE_CHECKING:
from .config import Config
def get_database(config: Config, migrate: bool = True) -> tinysql.Database:
if config.db_type == "sqlite":
db = tinysql.Database.sqlite(
config.sqlite_path,
connection_class = Connection,
min_connections = 2,
max_connections = 10
)
elif config.db_type == "postgres":
db = tinysql.Database.postgres(
config.pg_name,
config.pg_host,
config.pg_port,
config.pg_user,
config.pg_pass,
connection_class = Connection
)
db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql"))
if not migrate:
return db
with db.connection() as conn:
if 'config' not in conn.get_tables():
logging.info("Creating database tables")
migrate_0(conn)
return db
if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS:
if schema_ver < ver:
conn.begin()
func(conn)
conn.put_config('schema-version', ver)
conn.commit()
if (privkey := conn.get_config('private-key')):
conn.app.signer = privkey
logging.set_level(conn.get_config('log-level'))
return db

View file

@ -1,45 +0,0 @@
from __future__ import annotations
import typing
from .. import logger as logging
from ..misc import boolean
if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
'schema-version': ('int', 20240119),
'log-level': ('loglevel', logging.LogLevel.INFO),
'note': ('str', 'Make a note about your instance here.'),
'private-key': ('str', None),
'whitelist-enabled': ('bool', False)
}
# serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'loglevel': (lambda x: x.name, logging.LogLevel.parse)
}
def get_default_value(key: str) -> Any:
return CONFIG_DEFAULTS[key][1]
def get_default_type(key: str) -> str:
return CONFIG_DEFAULTS[key][0]
def serialize(key: str, value: Any) -> str:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][0](value)
def deserialize(key: str, value: str) -> Any:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][1](value)

View file

@ -1,296 +0,0 @@
from __future__ import annotations
import tinysql
import typing
from datetime import datetime, timezone
from urllib.parse import urlparse
from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize
from .. import logger as logging
from ..misc import get_app
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from tinysql import Cursor, Row
from typing import Any
from .application import Application
from ..misc import Message
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
'activity-relay', # https://github.com/yukimochi/Activity-Relay
'aoderelay', # https://git.asonix.dog/asonix/relay
'feditools-relay' # https://git.ptzo.gdn/feditools/relay
]
class Connection(tinysql.Connection):
@property
def app(self) -> Application:
return get_app()
def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for inbox in self.execute('SELECT * FROM inboxes'):
if inbox['domain'] not in src_domains:
yield inbox['inbox']
def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params)
def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
with self.exec_statement('get-config', {'key': key}) as cur:
if not (row := cur.one()):
return get_default_value(key)
if row['value']:
return deserialize(row['key'], row['value'])
return None
def get_config_all(self) -> dict[str, Any]:
with self.exec_statement('get-config-all') as cur:
db_config = {row['key']: row['value'] for row in cur}
config = {}
for key, data in CONFIG_DEFAULTS.items():
try:
config[key] = deserialize(key, db_config[key])
except KeyError:
if key == 'schema-version':
config[key] = 0
else:
config[key] = data[1]
return config
def put_config(self, key: str, value: Any) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
if key == 'private-key':
self.app.signer = value
elif key == 'log-level':
value = logging.LogLevel.parse(value)
logging.set_level(value)
params = {
'key': key,
'value': serialize(key, value) if value is not None else None,
'type': get_default_type(key)
}
with self.exec_statement('put-config', params):
return value
def get_inbox(self, value: str) -> Row:
with self.exec_statement('get-inbox', {'value': value}) as cur:
return cur.one()
def put_inbox(self,
domain: str,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
params = {
'domain': domain,
'inbox': inbox,
'actor': actor,
'followid': followid,
'software': software,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-inbox', params) as cur:
return cur.one()
def update_inbox(self,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"')
data = {}
if actor:
data['actor'] = actor
if followid:
data['followid'] = followid
if software:
data['software'] = software
statement = tinysql.Update('inboxes', data, inbox = inbox)
with self.query(statement):
return self.get_inbox(inbox)
def del_inbox(self, value: str) -> bool:
with self.exec_statement('del-inbox', {'value': value}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_ban(self, domain: str) -> Row:
if domain.startswith('http'):
domain = urlparse(domain).netloc
with self.exec_statement('get-domain-ban', {'domain': domain}) as cur:
return cur.one()
def put_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> Row:
params = {
'domain': domain,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-ban', params) as cur:
return cur.one()
def update_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('domain_bans', params, domain = domain)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_domain_ban(domain)
def del_domain_ban(self, domain: str) -> bool:
with self.exec_statement('del-domain-ban', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_software_ban(self, name: str) -> Row:
with self.exec_statement('get-software-ban', {'name': name}) as cur:
return cur.one()
def put_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> Row:
params = {
'name': name,
'reason': reason,
'note': note,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-software-ban', params) as cur:
return cur.one()
def update_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
params = {}
if reason:
params['reason'] = reason
if note:
params['note'] = note
statement = tinysql.Update('software_bans', params, name = name)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return self.get_software_ban(name)
def del_software_ban(self, name: str) -> bool:
with self.exec_statement('del-software-ban', {'name': name}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
def get_domain_whitelist(self, domain: str) -> Row:
with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one()
def put_domain_whitelist(self, domain: str) -> Row:
params = {
'domain': domain,
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-whitelist', params) as cur:
return cur.one()
def del_domain_whitelist(self, domain: str) -> bool:
with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1

View file

@ -1,60 +0,0 @@
from __future__ import annotations
import typing
from tinysql import Column, Connection, Table
from .config import get_default_value
if typing.TYPE_CHECKING:
from collections.abc import Callable
VERSIONS: list[Callable] = []
TABLES: list[Table] = [
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False),
Column('value', 'text'),
Column('type', 'text', default = 'str')
),
Table(
'inboxes',
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True),
Column('inbox', 'text', unique = True, nullable = False),
Column('followid', 'text'),
Column('software', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'whitelist',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('created', 'timestamp')
),
Table(
'domain_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'software_bans',
Column('name', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
)
]
def version(func: Callable) -> Callable:
ver = int(func.replace('migrate_', ''))
VERSIONS[ver] = func
return func
def migrate_0(conn: Connection) -> None:
conn.create_tables(TABLES)
conn.put_config('schema-version', get_default_value('schema-version'))

View file

@ -13,10 +13,11 @@ from urllib.parse import urlparse
from . import __version__ from . import __version__
from . import logger as logging from . import logger as logging
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any from typing import Any, Callable, Optional
from .database import RelayDatabase
HEADERS = { HEADERS = {
@ -26,7 +27,13 @@ HEADERS = {
class HttpClient: class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024): def __init__(self,
database: RelayDatabase,
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
self.database = database
self.cache = LRUCache(cache_size) self.cache = LRUCache(cache_size)
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
@ -73,9 +80,9 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches async def get(self, # pylint: disable=too-many-branches
url: str, url: str,
sign_headers: bool = False, sign_headers: Optional[bool] = False,
loads: callable | None = None, loads: Optional[Callable] = None,
force: bool = False) -> Message | dict | None: force: Optional[bool] = False) -> Message | dict | None:
await self.open() await self.open()
@ -91,7 +98,7 @@ class HttpClient:
headers = {} headers = {}
if sign_headers: if sign_headers:
get_app().signer.sign_headers('GET', url, algorithm = 'original') headers.update(self.database.signer.sign_headers('GET', url, algorithm='original'))
try: try:
logging.debug('Fetching resource: %s', url) logging.debug('Fetching resource: %s', url)
@ -143,27 +150,23 @@ class HttpClient:
async def post(self, url: str, message: Message) -> None: async def post(self, url: str, message: Message) -> None:
await self.open() await self.open()
# todo: cache inboxes to avoid opening a db connection instance = self.database.get_inbox(url)
with get_app().database.connection() as conn:
instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now ## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression if instance and instance.get('software') in {'mastodon'}:
if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019' algorithm = 'hs2019'
else: else:
algorithm = 'original' algorithm = 'original'
# pylint: enable=consider-ternary-expression
headers = {'Content-Type': 'application/activity+json'} headers = {'Content-Type': 'application/activity+json'}
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm)) headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm))
try: try:
logging.verbose('Sending "%s" to %s', message.type, url) logging.verbose('Sending "%s" to %s', message.type, url)
async with self._session.post(url, headers=headers, data=message.to_json()) as resp: async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
# Not expecting a response, so just return ## Not expecting a response, so just return
if resp.status in {200, 202}: if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url) logging.verbose('Successfully sent "%s" to %s', message.type, url)
return return
@ -178,7 +181,7 @@ class HttpClient:
except (AsyncTimeoutError, ClientConnectionError): except (AsyncTimeoutError, ClientConnectionError):
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc) logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
# prevent workers from being brought down ## prevent workers from being brought down
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
@ -194,7 +197,7 @@ class HttpClient:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None return None
for version in ('20', '21'): for version in ['20', '21']:
try: try:
nodeinfo_url = wk_nodeinfo.get_url(version) nodeinfo_url = wk_nodeinfo.get_url(version)
@ -208,16 +211,16 @@ class HttpClient:
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
async def get(*args: Any, **kwargs: Any) -> Message | dict | None: async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient() as client: async with HttpClient(database) as client:
return await client.get(*args, **kwargs) return await client.get(*args, **kwargs)
async def post(*args: Any, **kwargs: Any) -> None: async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None:
async with HttpClient() as client: async with HttpClient(database) as client:
return await client.post(*args, **kwargs) return await client.post(*args, **kwargs)
async def fetch_nodeinfo(*args: Any, **kwargs: Any) -> Nodeinfo | None: async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient() as client: async with HttpClient(database) as client:
return await client.fetch_nodeinfo(*args, **kwargs) return await client.fetch_nodeinfo(*args, **kwargs)

View file

@ -4,63 +4,20 @@ import logging
import os import os
import typing import typing
from enum import IntEnum
from pathlib import Path from pathlib import Path
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable from typing import Any, Callable
from typing import Any
class LogLevel(IntEnum): LOG_LEVELS: dict[str, int] = {
DEBUG = logging.DEBUG 'DEBUG': logging.DEBUG,
VERBOSE = 15 'VERBOSE': 15,
INFO = logging.INFO 'INFO': logging.INFO,
WARNING = logging.WARNING 'WARNING': logging.WARNING,
ERROR = logging.ERROR 'ERROR': logging.ERROR,
CRITICAL = logging.CRITICAL 'CRITICAL': logging.CRITICAL
}
def __str__(self) -> str:
return self.name
@classmethod
def parse(cls: type[IntEnum], data: object) -> IntEnum:
if isinstance(data, cls):
return data
if isinstance(data, str):
data = data.upper()
try:
return cls[data]
except KeyError:
pass
try:
return cls(data)
except ValueError:
pass
raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}')
def get_level() -> LogLevel:
return LogLevel.parse(logging.root.level)
def set_level(level: LogLevel | str) -> None:
logging.root.setLevel(LogLevel.parse(level))
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LogLevel['VERBOSE']):
return
logging.log(LogLevel['VERBOSE'], message, *args, **kwargs)
debug: Callable = logging.debug debug: Callable = logging.debug
@ -70,7 +27,14 @@ error: Callable = logging.error
critical: Callable = logging.critical critical: Callable = logging.critical
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE') def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']):
return
logging.log(LOG_LEVELS['VERBOSE'], message, *args, **kwargs)
logging.addLevelName(LOG_LEVELS['VERBOSE'], 'VERBOSE')
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper() env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
try: try:
@ -81,11 +45,11 @@ except KeyError:
try: try:
log_level = LogLevel[env_log_level] log_level = LOG_LEVELS[env_log_level]
except KeyError: except KeyError:
print('Invalid log level:', env_log_level) logging.warning('Invalid log level: %s', env_log_level)
log_level = LogLevel['INFO'] log_level = logging.INFO
handlers = [logging.StreamHandler()] handlers = [logging.StreamHandler()]

View file

@ -6,49 +6,22 @@ import click
import platform import platform
import typing import typing
from aputils.signer import Signer
from pathlib import Path
from shutil import copyfile
from urllib.parse import urlparse from urllib.parse import urlparse
from . import __version__ from . import misc, __version__
from . import http_client as http from . import http_client as http
from . import logger as logging
from .application import Application from .application import Application
from .compat import RelayConfig, RelayDatabase from .config import RELAY_SOFTWARE
from .database import get_database
from .database.connection import RELAY_SOFTWARE
from .misc import IS_DOCKER, Message, check_open_port
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Row
from typing import Any from typing import Any
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation # pylint: disable=unsubscriptable-object,unsupported-assignment-operation
CONFIG_IGNORE = ( app = None
'schema-version', CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
'private-key'
)
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
'pleroma': 'https://{domain}/relay'
}
SOFTWARE = (
'mastodon',
'akkoma',
'pleroma',
'misskey',
'friendica',
'hubzilla',
'firefish',
'gotosocial'
)
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True) @click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@ -56,10 +29,11 @@ SOFTWARE = (
@click.version_option(version=__version__, prog_name='ActivityRelay') @click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context @click.pass_context
def cli(ctx: click.Context, config: str) -> None: def cli(ctx: click.Context, config: str) -> None:
ctx.obj = Application(config) global app
app = Application(config)
if not ctx.invoked_subcommand: if not ctx.invoked_subcommand:
if ctx.obj.config.domain.endswith('example.com'): if app.config.host.endswith('example.com'):
cli_setup.callback() cli_setup.callback()
else: else:
@ -67,92 +41,46 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup') @cli.command('setup')
@click.pass_context def cli_setup() -> None:
def cli_setup(ctx: click.Context) -> None: 'Generate a new config'
'Generate a new config and create the database'
while True: while True:
ctx.obj.config.domain = click.prompt( app.config.host = click.prompt(
'What domain will the relay be hosted on?', 'What domain will the relay be hosted on?',
default = ctx.obj.config.domain default = app.config.host
) )
if not ctx.obj.config.domain.endswith('example.com'): if not app.config.host.endswith('example.com'):
break break
click.echo('The domain must not end with "example.com"') click.echo('The domain must not be example.com')
if not IS_DOCKER: if not app.config.is_docker:
ctx.obj.config.listen = click.prompt( app.config.listen = click.prompt(
'Which address should the relay listen on?', 'Which address should the relay listen on?',
default = ctx.obj.config.listen default = app.config.listen
) )
ctx.obj.config.port = click.prompt( while True:
'What TCP port should the relay listen on?', app.config.port = click.prompt(
default = ctx.obj.config.port, 'What TCP port should the relay listen on?',
type = int default = app.config.port,
) type = int
)
ctx.obj.config.db_type = click.prompt( break
'Which database backend will be used?',
default = ctx.obj.config.db_type,
type = click.Choice(['postgres', 'sqlite'], case_sensitive = False)
)
if ctx.obj.config.db_type == 'sqlite': app.config.save()
ctx.obj.config.sq_path = click.prompt(
'Where should the database be stored?',
default = ctx.obj.config.sq_path
)
elif ctx.obj.config.db_type == 'postgres': if not app.config.is_docker and click.confirm('Relay all setup! Would you like to run it now?'):
ctx.obj.config.pg_name = click.prompt(
'What is the name of the database?',
default = ctx.obj.config.pg_name
)
ctx.obj.config.pg_host = click.prompt(
'What IP address or hostname does the server listen on?',
default = ctx.obj.config.pg_host
)
ctx.obj.config.pg_port = click.prompt(
'What port does the server listen on?',
default = ctx.obj.config.pg_port,
type = int
)
ctx.obj.config.pg_user = click.prompt(
'Which user will authenticate with the server?',
default = ctx.obj.config.pg_user
)
ctx.obj.config.pg_pass = click.prompt(
'User password: ',
hide_input = True
) or None
ctx.obj.config.save()
config = {
'private-key': Signer.new('n/a').export()
}
with ctx.obj.database.connection() as conn:
for key, value in config.items():
conn.put_config(key, value)
if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback() cli_run.callback()
@cli.command('run') @cli.command('run')
@click.pass_context def cli_run() -> None:
def cli_run(ctx: click.Context) -> None:
'Run the relay' 'Run the relay'
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer: if app.config.host.endswith('example.com'):
click.echo( click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".' 'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
) )
@ -176,144 +104,40 @@ def cli_run(ctx: click.Context) -> None:
click.echo(pip_command) click.echo(pip_command)
return return
if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port): if not misc.check_open_port(app.config.listen, app.config.port):
click.echo(f'Error: A server is already running on port {ctx.obj.config.port}') click.echo(f'Error: A server is already running on port {app.config.port}')
return return
ctx.obj.run() app.run()
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@click.pass_context
def cli_convert(ctx: click.Context, old_config: str) -> None:
'Convert an old config and jsonld database to the new format.'
old_config = Path(old_config).expanduser().resolve() if old_config else ctx.obj.config.path
backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml')
if str(old_config) == str(ctx.obj.config.path) and not backup.exists():
logging.info('Created backup config @ %s', backup)
copyfile(ctx.obj.config.path, backup)
config = RelayConfig(old_config)
config.load()
database = RelayDatabase(config)
database.load()
ctx.obj.config.set('listen', config['listen'])
ctx.obj.config.set('port', config['port'])
ctx.obj.config.set('workers', config['workers'])
ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3'))
ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save()
with get_database(ctx.obj.config) as db:
with db.connection() as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar(
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
) as inboxes:
for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay'
elif inbox['software'] == 'mastodon':
actor = f'https://{inbox["domain"]}/actor'
else:
actor = None
conn.put_inbox(
inbox['domain'],
inbox['inbox'],
actor = actor,
followid = inbox['followid'],
software = inbox['software']
)
with click.progressbar(
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
) as banned_software:
for software in banned_software:
conn.put_software_ban(
software,
reason = 'relay' if software in RELAY_SOFTWARE else None
)
with click.progressbar(
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
) as banned_software:
for domain in banned_software:
conn.put_domain_ban(domain)
with click.progressbar(
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
) as whitelist:
for instance in whitelist:
conn.put_domain_whitelist(instance)
click.echo('Finished converting old config and database :3')
@cli.command('edit-config')
@click.option('--editor', '-e', help = 'Text editor to use')
@click.pass_context
def cli_editconfig(ctx: click.Context, editor: str) -> None:
'Edit the config file'
click.edit(
editor = editor,
filename = str(ctx.obj.config.path)
)
@cli.group('config') @cli.group('config')
def cli_config() -> None: def cli_config() -> None:
'Manage the relay settings stored in the database' 'Manage the relay config'
@cli_config.command('list') @cli_config.command('list')
@click.pass_context def cli_config_list() -> None:
def cli_config_list(ctx: click.Context) -> None:
'List the current relay config' 'List the current relay config'
click.echo('Relay Config:') click.echo('Relay Config:')
with ctx.obj.database.connection() as conn: for key, value in app.config.items():
for key, value in conn.get_config_all().items(): if key not in CONFIG_IGNORE:
if key not in CONFIG_IGNORE: key = f'{key}:'.ljust(20)
key = f'{key}:'.ljust(20) click.echo(f'- {key} {value}')
click.echo(f'- {key} {value}')
@cli_config.command('set') @cli_config.command('set')
@click.argument('key') @click.argument('key')
@click.argument('value') @click.argument('value')
@click.pass_context def cli_config_set(key: str, value: Any) -> None:
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value' 'Set a config value'
with ctx.obj.database.connection() as conn: app.config[key] = value
new_value = conn.put_config(key, value) app.config.save()
print(f'{key}: {repr(new_value)}') print(f'{key}: {app.config[key]}')
@cli.group('inbox') @cli.group('inbox')
@ -322,150 +146,127 @@ def cli_inbox() -> None:
@cli_inbox.command('list') @cli_inbox.command('list')
@click.pass_context def cli_inbox_list() -> None:
def cli_inbox_list(ctx: click.Context) -> None:
'List the connected instances or relays' 'List the connected instances or relays'
click.echo('Connected to the following instances or relays:') click.echo('Connected to the following instances or relays:')
with ctx.obj.database.connection() as conn: for inbox in app.database.inboxes:
for inbox in conn.execute('SELECT * FROM inboxes'): click.echo(f'- {inbox}')
click.echo(f'- {inbox["inbox"]}')
@cli_inbox.command('follow') @cli_inbox.command('follow')
@click.argument('actor') @click.argument('actor')
@click.pass_context def cli_inbox_follow(actor: str) -> None:
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)' 'Follow an actor (Relay must be running)'
with ctx.obj.database.connection() as conn: if app.config.is_banned(actor):
if conn.get_domain_ban(actor): click.echo(f'Error: Refusing to follow banned actor: {actor}')
click.echo(f'Error: Refusing to follow banned actor: {actor}') return
if not actor.startswith('http'):
domain = actor
actor = f'https://{actor}/actor'
else:
domain = urlparse(actor).hostname
try:
inbox_data = app.database['relay-list'][domain]
inbox = inbox_data['inbox']
except KeyError:
actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
if not actor_data:
click.echo(f'Failed to fetch actor: {actor}')
return return
if (inbox_data := conn.get_inbox(actor)): inbox = actor_data.shared_inbox
inbox = inbox_data['inbox']
else: message = misc.Message.new_follow(
if not actor.startswith('http'): host = app.config.host,
actor = f'https://{actor}/actor'
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
click.echo(f'Failed to fetch actor: {actor}')
return
inbox = actor_data.shared_inbox
message = Message.new_follow(
host = ctx.obj.config.domain,
actor = actor actor = actor
) )
asyncio.run(http.post(inbox, message)) asyncio.run(http.post(app.database, inbox, message))
click.echo(f'Sent follow message to actor: {actor}') click.echo(f'Sent follow message to actor: {actor}')
@cli_inbox.command('unfollow') @cli_inbox.command('unfollow')
@click.argument('actor') @click.argument('actor')
@click.pass_context def cli_inbox_unfollow(actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)' 'Unfollow an actor (Relay must be running)'
inbox_data: Row = None if not actor.startswith('http'):
domain = actor
actor = f'https://{actor}/actor'
with ctx.obj.database.connection() as conn: else:
if conn.get_domain_ban(actor): domain = urlparse(actor).hostname
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
if (inbox_data := conn.get_inbox(actor)): try:
inbox = inbox_data['inbox'] inbox_data = app.database['relay-list'][domain]
message = Message.new_unfollow( inbox = inbox_data['inbox']
host = ctx.obj.config.domain, message = misc.Message.new_unfollow(
actor = actor, host = app.config.host,
follow = inbox_data['followid'] actor = actor,
) follow = inbox_data['followid']
)
else: except KeyError:
if not actor.startswith('http'): actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
actor = f'https://{actor}/actor' inbox = actor_data.shared_inbox
message = misc.Message.new_unfollow(
host = app.config.host,
actor = actor,
follow = {
'type': 'Follow',
'object': actor,
'actor': f'https://{app.config.host}/actor'
}
)
actor_data = asyncio.run(http.get(actor, sign_headers = True)) asyncio.run(http.post(app.database, inbox, message))
inbox = actor_data.shared_inbox
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
follow = {
'type': 'Follow',
'object': actor,
'actor': f'https://{ctx.obj.config.domain}/actor'
}
)
asyncio.run(http.post(inbox, message))
click.echo(f'Sent unfollow message to: {actor}') click.echo(f'Sent unfollow message to: {actor}')
@cli_inbox.command('add') @cli_inbox.command('add')
@click.argument('inbox') @click.argument('inbox')
@click.option('--actor', '-a', help = 'Actor url for the inbox') def cli_inbox_add(inbox: str) -> None:
@click.option('--followid', '-f', help = 'Url for the follow activity')
@click.option('--software', '-s',
type = click.Choice(SOFTWARE),
help = 'Nodeinfo software name of the instance'
) # noqa: E124
@click.pass_context
def cli_inbox_add(
ctx: click.Context,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> None:
'Add an inbox to the database' 'Add an inbox to the database'
if not inbox.startswith('http'): if not inbox.startswith('http'):
domain = inbox
inbox = f'https://{inbox}/inbox' inbox = f'https://{inbox}/inbox'
else: if app.config.is_banned(inbox):
domain = urlparse(inbox).netloc click.echo(f'Error: Refusing to add banned inbox: {inbox}')
return
if not software: if app.database.get_inbox(inbox):
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(domain))): click.echo(f'Error: Inbox already in database: {inbox}')
software = nodeinfo.sw_name return
if not actor and software: app.database.add_inbox(inbox)
try: app.database.save()
actor = ACTOR_FORMATS[software].format(domain = domain)
except KeyError:
pass
with ctx.obj.database.connection() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Refusing to add banned inbox: {inbox}')
return
if conn.get_inbox(inbox):
click.echo(f'Error: Inbox already in database: {inbox}')
return
conn.put_inbox(domain, inbox, actor, followid, software)
click.echo(f'Added inbox to the database: {inbox}') click.echo(f'Added inbox to the database: {inbox}')
@cli_inbox.command('remove') @cli_inbox.command('remove')
@click.argument('inbox') @click.argument('inbox')
@click.pass_context def cli_inbox_remove(inbox: str) -> None:
def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
'Remove an inbox from the database' 'Remove an inbox from the database'
with ctx.obj.database.connection() as conn: try:
if not conn.del_inbox(inbox): dbinbox = app.database.get_inbox(inbox, fail=True)
click.echo(f'Inbox not in database: {inbox}')
return except KeyError:
click.echo(f'Error: Inbox does not exist: {inbox}')
return
app.database.del_inbox(dbinbox['domain'])
app.database.save()
click.echo(f'Removed inbox from the database: {inbox}') click.echo(f'Removed inbox from the database: {inbox}')
@ -476,76 +277,47 @@ def cli_instance() -> None:
@cli_instance.command('list') @cli_instance.command('list')
@click.pass_context def cli_instance_list() -> None:
def cli_instance_list(ctx: click.Context) -> None:
'List all banned instances' 'List all banned instances'
click.echo('Banned domains:') click.echo('Banned instances or relays:')
with ctx.obj.database.connection() as conn: for domain in app.config.blocked_instances:
for instance in conn.execute('SELECT * FROM domain_bans'): click.echo(f'- {domain}')
if instance['reason']:
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
else:
click.echo(f'- {instance["domain"]}')
@cli_instance.command('ban') @cli_instance.command('ban')
@click.argument('domain') @click.argument('target')
@click.option('--reason', '-r', help = 'Public note about why the domain is banned') def cli_instance_ban(target: str) -> None:
@click.option('--note', '-n', help = 'Internal note that will only be seen by admins and mods')
@click.pass_context
def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Ban an instance and remove the associated inbox if it exists' 'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.connection() as conn: if target.startswith('http'):
if conn.get_domain_ban(domain): target = urlparse(target).hostname
click.echo(f'Domain already banned: {domain}')
return
conn.put_domain_ban(domain, reason, note) if app.config.ban_instance(target):
conn.del_inbox(domain) app.config.save()
click.echo(f'Banned instance: {domain}')
if app.database.del_inbox(target):
app.database.save()
click.echo(f'Banned instance: {target}')
return
click.echo(f'Instance already banned: {target}')
@cli_instance.command('unban') @cli_instance.command('unban')
@click.argument('domain') @click.argument('target')
@click.pass_context def cli_instance_unban(target: str) -> None:
def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance' 'Unban an instance'
with ctx.obj.database.connection() as conn: if app.config.unban_instance(target):
if not conn.del_domain_ban(domain): app.config.save()
click.echo(f'Instance wasn\'t banned: {domain}')
return
click.echo(f'Unbanned instance: {domain}') click.echo(f'Unbanned instance: {target}')
return
click.echo(f'Instance wasn\'t banned: {target}')
@cli_instance.command('update')
@click.argument('domain')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a domain ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_domain_ban(domain, reason, note)):
click.echo(f'Failed to update domain ban: {domain}')
return
click.echo(f'Updated domain ban: {domain}')
if row['reason']:
click.echo(f'- {row["domain"]} ({row["reason"]})')
else:
click.echo(f'- {row["domain"]}')
@cli.group('software') @cli.group('software')
@ -554,127 +326,79 @@ def cli_software() -> None:
@cli_software.command('list') @cli_software.command('list')
@click.pass_context def cli_software_list() -> None:
def cli_software_list(ctx: click.Context) -> None:
'List all banned software' 'List all banned software'
click.echo('Banned software:') click.echo('Banned software:')
with ctx.obj.database.connection() as conn: for software in app.config.blocked_software:
for software in conn.execute('SELECT * FROM software_bans'): click.echo(f'- {software}')
if software['reason']:
click.echo(f'- {software["name"]} ({software["reason"]})')
else:
click.echo(f'- {software["name"]}')
@cli_software.command('ban') @cli_software.command('ban')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option( @click.option(
'--fetch-nodeinfo', '-f', '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
is_flag = True, help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
) )
@click.pass_context @click.argument('name')
def cli_software_ban(ctx: click.Context, def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None:
name: str,
reason: str,
note: str,
fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays' 'Ban software. Use RELAYS for NAME to ban relays'
with ctx.obj.database.connection() as conn: if name == 'RELAYS':
if name == 'RELAYS': for software in RELAY_SOFTWARE:
for software in RELAY_SOFTWARE: app.config.ban_software(software)
if conn.get_software_ban(software):
click.echo(f'Relay already banned: {software}')
continue
conn.put_software_ban(software, reason or 'relay', note) app.config.save()
click.echo('Banned all relay software')
return
click.echo('Banned all relay software') if fetch_nodeinfo:
return nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
if fetch_nodeinfo: if not nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))): click.echo(f'Failed to fetch software name from domain: {name}')
click.echo(f'Failed to fetch software name from domain: {name}')
return
name = nodeinfo.sw_name name = nodeinfo.sw_name
if conn.get_software_ban(name):
click.echo(f'Software already banned: {name}')
return
if not conn.put_software_ban(name, reason, note):
click.echo(f'Failed to ban software: {name}')
return
if app.config.ban_software(name):
app.config.save()
click.echo(f'Banned software: {name}') click.echo(f'Banned software: {name}')
return
click.echo(f'Software already banned: {name}')
@cli_software.command('unban') @cli_software.command('unban')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.option( @click.option(
'--fetch-nodeinfo', '-f', '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
is_flag = True, help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
) )
@click.pass_context @click.argument('name')
def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None: def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays' 'Ban software. Use RELAYS for NAME to unban relays'
with ctx.obj.database.connection() as conn: if name == 'RELAYS':
if name == 'RELAYS': for software in RELAY_SOFTWARE:
for software in RELAY_SOFTWARE: app.config.unban_software(software)
if not conn.del_software_ban(software):
click.echo(f'Relay was not banned: {software}')
click.echo('Unbanned all relay software') app.config.save()
return click.echo('Unbanned all relay software')
return
if fetch_nodeinfo: if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))): nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
click.echo(f'Failed to fetch software name from domain: {name}')
return
name = nodeinfo.sw_name if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
if not conn.del_software_ban(name): name = nodeinfo.sw_name
click.echo(f'Software was not banned: {name}')
return
if app.config.unban_software(name):
app.config.save()
click.echo(f'Unbanned software: {name}') click.echo(f'Unbanned software: {name}')
return
click.echo(f'Software wasn\'t banned: {name}')
@cli_software.command('update')
@click.argument('name')
@click.option('--reason', '-r')
@click.option('--note', '-n')
@click.pass_context
def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -> None:
'Update the public reason or internal note for a software ban'
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
if not (row := conn.update_software_ban(name, reason, note)):
click.echo(f'Failed to update software ban: {name}')
return
click.echo(f'Updated software ban: {name}')
if row['reason']:
click.echo(f'- {row["name"]} ({row["reason"]})')
else:
click.echo(f'- {row["name"]}')
@cli.group('whitelist') @cli.group('whitelist')
@ -683,64 +407,52 @@ def cli_whitelist() -> None:
@cli_whitelist.command('list') @cli_whitelist.command('list')
@click.pass_context def cli_whitelist_list() -> None:
def cli_whitelist_list(ctx: click.Context) -> None:
'List all the instances in the whitelist' 'List all the instances in the whitelist'
click.echo('Current whitelisted domains:') click.echo('Current whitelisted domains')
with ctx.obj.database.connection() as conn: for domain in app.config.whitelist:
for domain in conn.execute('SELECT * FROM whitelist'): click.echo(f'- {domain}')
click.echo(f'- {domain["domain"]}')
@cli_whitelist.command('add') @cli_whitelist.command('add')
@click.argument('domain') @click.argument('instance')
@click.pass_context def cli_whitelist_add(instance: str) -> None:
def cli_whitelist_add(ctx: click.Context, domain: str) -> None: 'Add an instance to the whitelist'
'Add a domain to the whitelist'
with ctx.obj.database.connection() as conn: if not app.config.add_whitelist(instance):
if conn.get_domain_whitelist(domain): click.echo(f'Instance already in the whitelist: {instance}')
click.echo(f'Instance already in the whitelist: {domain}') return
return
conn.put_domain_whitelist(domain) app.config.save()
click.echo(f'Instance added to the whitelist: {domain}') click.echo(f'Instance added to the whitelist: {instance}')
@cli_whitelist.command('remove') @cli_whitelist.command('remove')
@click.argument('domain') @click.argument('instance')
@click.pass_context def cli_whitelist_remove(instance: str) -> None:
def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
'Remove an instance from the whitelist' 'Remove an instance from the whitelist'
with ctx.obj.database.connection() as conn: if not app.config.del_whitelist(instance):
if not conn.del_domain_whitelist(domain): click.echo(f'Instance not in the whitelist: {instance}')
click.echo(f'Domain not in the whitelist: {domain}') return
return
if conn.get_config('whitelist-enabled'): app.config.save()
if conn.del_inbox(domain):
click.echo(f'Removed inbox for domain: {domain}')
click.echo(f'Removed domain from the whitelist: {domain}') if app.config.whitelist_enabled:
if app.database.del_inbox(instance):
app.database.save()
click.echo(f'Removed instance from the whitelist: {instance}')
@cli_whitelist.command('import') @cli_whitelist.command('import')
@click.pass_context def cli_whitelist_import() -> None:
def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist' 'Add all current inboxes to the whitelist'
with ctx.obj.database.connection() as conn: for domain in app.database.hostnames:
for inbox in conn.execute('SELECT * FROM inboxes').all(): cli_whitelist_add.callback(domain)
if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
continue
conn.put_domain_whitelist(inbox['domain'])
click.echo('Imported whitelist from inboxes')
def main() -> None: def main() -> None:

View file

@ -1,28 +1,32 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
import socket import socket
import traceback
import typing import typing
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse
from aiohttp.web_exceptions import HTTPMethodNotAllowed from aiohttp.web_exceptions import HTTPMethodNotAllowed
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.message import Message as ApMessage from aputils.message import Message as ApMessage
from functools import cached_property from functools import cached_property
from json.decoder import JSONDecodeError
from uuid import uuid4 from uuid import uuid4
from . import logger as logging
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator from typing import Any, Coroutine, Generator, Optional, Type
from typing import Any from aputils.signer import Signer
from .application import Application from .application import Application
from .config import Config from .config import RelayConfig
from .database import Database from .database import RelayDatabase
from .http_client import HttpClient from .http_client import HttpClient
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = { MIMETYPES = {
'activity': 'application/activity+json', 'activity': 'application/activity+json',
'html': 'text/html', 'html': 'text/html',
@ -38,10 +42,10 @@ NODEINFO_NS = {
def boolean(value: Any) -> bool: def boolean(value: Any) -> bool:
if isinstance(value, str): if isinstance(value, str):
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}: if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']:
return True return True
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}: if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']:
return False return False
raise TypeError(f'Cannot parse string "{value}" as a boolean') raise TypeError(f'Cannot parse string "{value}" as a boolean')
@ -73,21 +77,99 @@ def check_open_port(host: str, port: int) -> bool:
return False return False
def get_app() -> Application: class DotDict(dict):
from .application import Application # pylint: disable=import-outside-toplevel def __init__(self, _data: dict[str, Any], **kwargs: Any):
dict.__init__(self)
if not Application.DEFAULT: self.update(_data, **kwargs)
raise ValueError('No default application set')
return Application.DEFAULT
def __getattr__(self, key: str) -> str:
try:
return self[key]
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
def __setattr__(self, key: str, value: Any) -> None:
if key.startswith('_'):
super().__setattr__(key, value)
else:
self[key] = value
def __setitem__(self, key: str, value: Any) -> None:
if type(value) is dict: # pylint: disable=unidiomatic-typecheck
value = DotDict(value)
super().__setitem__(key, value)
def __delattr__(self, key: str) -> None:
try:
dict.__delitem__(self, key)
except KeyError:
raise AttributeError(
f'{self.__class__.__name__} object has no attribute {key}'
) from None
@classmethod
def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]:
if not data:
raise JSONDecodeError('Empty body', data, 1)
try:
return cls(json.loads(data))
except ValueError:
raise JSONDecodeError('Invalid body', data, 1) from None
@classmethod
def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]:
data = cls({})
for chunk in sig.strip().split(','):
key, value = chunk.split('=', 1)
value = value.strip('\"')
if key == 'headers':
value = value.split()
data[key.lower()] = value
return data
def to_json(self, indent: Optional[int | str] = None) -> str:
return json.dumps(self, indent=indent)
def update(self, _data: dict[str, Any], **kwargs: Any) -> None:
if isinstance(_data, dict):
for key, value in _data.items():
self[key] = value
elif isinstance(_data, (list, tuple, set)):
for key, value in _data:
self[key] = value
for key, value in kwargs.items():
self[key] = value
class Message(ApMessage): class Message(ApMessage):
@classmethod @classmethod
def new_actor(cls: type[Message], # pylint: disable=arguments-differ def new_actor(cls: Type[Message], # pylint: disable=arguments-differ
host: str, host: str,
pubkey: str, pubkey: str,
description: str | None = None) -> Message: description: Optional[str] = None) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
@ -99,7 +181,7 @@ class Message(ApMessage):
'followers': f'https://{host}/followers', 'followers': f'https://{host}/followers',
'following': f'https://{host}/following', 'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox', 'inbox': f'https://{host}/inbox',
'url': f'https://{host}/', 'url': f'https://{host}/inbox',
'endpoints': { 'endpoints': {
'sharedInbox': f'https://{host}/inbox' 'sharedInbox': f'https://{host}/inbox'
}, },
@ -112,7 +194,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_announce(cls: type[Message], host: str, obj: str) -> Message: def new_announce(cls: Type[Message], host: str, obj: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -124,7 +206,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_follow(cls: type[Message], host: str, actor: str) -> Message: def new_follow(cls: Type[Message], host: str, actor: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow', 'type': 'Follow',
@ -136,7 +218,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message: def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -148,7 +230,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_response(cls: type[Message], def new_response(cls: Type[Message],
host: str, host: str,
actor: str, actor: str,
followid: str, followid: str,
@ -180,17 +262,12 @@ class Message(ApMessage):
class Response(AiohttpResponse): class Response(AiohttpResponse):
# AiohttpResponse.__len__ method returns 0, so bool(response) always returns False
def __bool__(self) -> bool:
return True
@classmethod @classmethod
def new(cls: type[Response], def new(cls: Type[Response],
body: str | bytes | dict = '', body: Optional[str | bytes | dict] = '',
status: int = 200, status: Optional[int] = 200,
headers: dict[str, str] | None = None, headers: Optional[dict[str, str]] = None,
ctype: str = 'text') -> Response: ctype: Optional[str] = 'text') -> Response:
kwargs = { kwargs = {
'status': status, 'status': status,
@ -211,7 +288,7 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_error(cls: type[Response], def new_error(cls: Type[Response],
status: int, status: int,
body: str | bytes | dict, body: str | bytes | dict,
ctype: str = 'text') -> Response: ctype: str = 'text') -> Response:
@ -233,11 +310,23 @@ class Response(AiohttpResponse):
class View(AbstractView): class View(AbstractView):
def __await__(self) -> Generator[Response]: def __init__(self, request: AiohttpRequest):
if (self.request.method) not in METHODS: AbstractView.__init__(self, request)
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
if not (handler := self.handlers.get(self.request.method)): self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
def __await__(self) -> Generator[Response]:
method = self.request.method.upper()
if method not in METHODS:
raise HTTPMethodNotAllowed(method, self.allowed_methods)
if not (handler := self.handlers.get(method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__() return handler(self.request, **self.request.match_info).__await__()
@ -274,10 +363,94 @@ class View(AbstractView):
@property @property
def config(self) -> Config: def config(self) -> RelayConfig:
return self.app.config return self.app.config
@property @property
def database(self) -> Database: def database(self) -> RelayDatabase:
return self.app.database return self.app.database
# todo: move to views.ActorView
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if self.actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
self.instance = self.database.get_inbox(self.actor.inbox)
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")

View file

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import tinysql
import typing import typing
from cachetools import LRUCache from cachetools import LRUCache
@ -9,7 +8,7 @@ from . import logger as logging
from .misc import Message from .misc import Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .views import ActorView from .misc import View
cache = LRUCache(1024) cache = LRUCache(1024)
@ -17,141 +16,128 @@ cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool: def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason # pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False return False
# make sure the actor is an application ## make sure the actor is an application
if actor.type != 'Application': if actor.type != 'Application':
return True return True
return False return False
async def handle_relay(view: ActorView) -> None: async def handle_relay(view: View) -> None:
if view.message.object_id in cache: if view.message.object_id in cache:
logging.verbose('already relayed %s', view.message.object_id) logging.verbose('already relayed %s', view.message.object_id)
return return
message = Message.new_announce(view.config.domain, view.message.object_id) message = Message.new_announce(view.config.host, view.message.object_id)
cache[view.message.object_id] = message.id cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
with view.database.connection() as conn: inboxes = view.database.distill_inboxes(view.message)
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message) for inbox in inboxes:
view.app.push_message(inbox, message)
async def handle_forward(view: ActorView) -> None: async def handle_forward(view: View) -> None:
if view.message.id in cache: if view.message.id in cache:
logging.verbose('already forwarded %s', view.message.id) logging.verbose('already forwarded %s', view.message.id)
return return
message = Message.new_announce(view.config.domain, view.message) message = Message.new_announce(view.config.host, view.message)
cache[view.message.id] = message.id cache[view.message.id] = message.id
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
with view.database.connection() as conn: inboxes = view.database.distill_inboxes(view.message)
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message) for inbox in inboxes:
view.app.push_message(inbox, message)
async def handle_follow(view: ActorView) -> None: async def handle_follow(view: View) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None software = nodeinfo.sw_name if nodeinfo else None
with view.database.connection() as conn: ## reject if software used by actor is banned
# reject if software used by actor is banned if view.config.is_banned_software(software):
if conn.get_software_ban(software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
logging.verbose(
'Rejected follow from actor for using specific software: actor=%s, software=%s',
view.actor.id,
software
)
return
## reject if the actor is not an instance actor
if person_check(view.actor, software):
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
)
)
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
if conn.get_inbox(view.actor.shared_inbox):
data = {'followid': view.message.id}
statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
with conn.query(statement):
pass
else:
conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
view.actor.id,
view.message.id,
software
)
view.app.push_message( view.app.push_message(
view.actor.shared_inbox, view.actor.shared_inbox,
Message.new_response( Message.new_response(
host = view.config.domain, host = view.config.host,
actor = view.actor.id, actor = view.actor.id,
followid = view.message.id, followid = view.message.id,
accept = True accept = False
) )
) )
# Are Akkoma and Pleroma the only two that expect a follow back? return logging.verbose(
# Ignoring only Mastodon for now 'Rejected follow from actor for using specific software: actor=%s, software=%s',
if software != 'mastodon': view.actor.id,
view.app.push_message( software
view.actor.shared_inbox, )
Message.new_follow(
host = view.config.domain, ## reject if the actor is not an instance actor
actor = view.actor.id if person_check(view.actor, software):
) view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.host,
actor = view.actor.id,
followid = view.message.id,
accept = False
) )
)
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
async def handle_undo(view: ActorView) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
await handle_forward(view)
return return
with view.database.connection() as conn: view.database.add_inbox(view.actor.shared_inbox, view.message.id, software)
if not conn.del_inbox(view.actor.id): view.database.save()
logging.verbose(
'Failed to delete "%s" with follow ID "%s"', view.app.push_message(
view.actor.id, view.actor.shared_inbox,
view.message.object['id'] Message.new_response(
host = view.config.host,
actor = view.actor.id,
followid = view.message.id,
accept = True
)
)
# Are Akkoma and Pleroma the only two that expect a follow back?
# Ignoring only Mastodon for now
if software != 'mastodon':
view.app.push_message(
view.actor.shared_inbox,
Message.new_follow(
host = view.config.host,
actor = view.actor.id
) )
)
async def handle_undo(view: View) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
return await handle_forward(view)
if not view.database.del_inbox(view.actor.domain, view.message.object['id']):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
view.message.object['id']
)
return
view.database.save()
view.app.push_message( view.app.push_message(
view.actor.shared_inbox, view.actor.shared_inbox,
Message.new_unfollow( Message.new_unfollow(
host = view.config.domain, host = view.config.host,
actor = view.actor.id, actor = view.actor.id,
follow = view.message follow = view.message
) )
@ -168,7 +154,7 @@ processors = {
} }
async def run_processor(view: ActorView) -> None: async def run_processor(view: View) -> None:
if view.message.type not in processors: if view.message.type not in processors:
logging.verbose( logging.verbose(
'Message type "%s" from actor cannot be handled: %s', 'Message type "%s" from actor cannot be handled: %s',
@ -178,21 +164,12 @@ async def run_processor(view: ActorView) -> None:
return return
if view.instance: if view.instance and not view.instance.get('software'):
if not view.instance['software']: nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain'])
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
if not view.instance['actor']: if nodeinfo:
with view.database.connection() as conn: view.instance['software'] = nodeinfo.sw_name
view.instance = conn.update_inbox( view.database.save()
view.instance['inbox'],
actor = view.actor.id
)
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view) await processors[view.message.type](view)

View file

@ -2,11 +2,8 @@ from __future__ import annotations
import asyncio import asyncio
import subprocess import subprocess
import traceback
import typing import typing
from aputils.errors import SignatureFailureError
from aputils.misc import Digest, HttpDate, Signature
from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo
from pathlib import Path from pathlib import Path
@ -17,8 +14,7 @@ from .processors import run_processor
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from aputils.signer import Signer from typing import Callable
from collections.abc import Callable
VIEWS = [] VIEWS = []
@ -75,16 +71,12 @@ def register_route(*paths: str) -> Callable:
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.connection() as conn: text = HOME_TEMPLATE.format(
config = conn.get_config_all() host = self.config.host,
inboxes = conn.execute('SELECT * FROM inboxes').all() note = self.config.note,
count = len(self.database.hostnames),
text = HOME_TEMPLATE.format( targets = '<br>'.join(self.database.hostnames)
host = self.config.domain, )
note = config['note'],
count = len(inboxes),
targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
)
return Response.new(text, ctype='html') return Response.new(text, ctype='html')
@ -92,137 +84,44 @@ class HomeView(View):
@register_route('/actor', '/inbox') @register_route('/actor', '/inbox')
class ActorView(View): class ActorView(View):
def __init__(self, request: Request):
View.__init__(self, request)
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.signer: Signer = None
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
data = Message.new_actor( data = Message.new_actor(
host = self.config.domain, host = self.config.host,
pubkey = self.app.signer.pubkey pubkey = self.database.signer.pubkey
) )
return Response.new(data, ctype='activity') return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
if response := await self.get_post_data(): response = await self.get_post_data()
if response is not None:
return response return response
with self.database.connection() as conn: ## reject if the actor isn't whitelisted while the whiltelist is enabled
self.instance = conn.get_inbox(self.actor.shared_inbox) if self.config.whitelist_enabled and not self.config.is_whitelisted(self.actor.domain):
config = conn.get_config_all() logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if the actor isn't whitelisted while the whiltelist is enabled ## reject if actor is banned
if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain): if self.config.is_banned(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id) logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json') return Response.new_error(403, 'access denied', 'json')
## reject if actor is banned ## reject if activity type isn't 'Follow' and the actor isn't following
if conn.get_domain_ban(self.actor.domain): if self.message.type != 'Follow' and not self.database.get_inbox(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id) logging.verbose(
return Response.new_error(403, 'access denied', 'json') 'Rejected actor for trying to post while not following: %s',
self.actor.id
)
## reject if activity type isn't 'Follow' and the actor isn't following return Response.new_error(401, 'access denied', 'json')
if self.message.type != 'Follow' and not self.instance:
logging.verbose(
'Rejected actor for trying to post while not following: %s',
self.actor.id
)
return Response.new_error(401, 'access denied', 'json') logging.debug('>> payload %s', self.message.to_json(4))
logging.debug('>> payload %s', self.message.to_json(4)) asyncio.ensure_future(run_processor(self))
return Response.new(status = 202)
asyncio.ensure_future(run_processor(self))
return Response.new(status = 202)
async def get_post_data(self) -> Response | None:
try:
self.signature = Signature.new_from_signature(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
if not self.actor:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
return Response.new(status=202)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
try:
self.signer = self.actor.signer
except KeyError:
logging.verbose('Actor missing public key: %s', self.signature.keyid)
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
except SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := Digest.new_from_digest(headers.get("digest"))):
if not body:
raise SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise SignatureFailureError("'(created)' header not used")
current_timestamp = HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise SignatureFailureError("Signature does not match")
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
@ -234,12 +133,12 @@ class WebfingerView(View):
except KeyError: except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json') return Response.new_error(400, 'missing "resource" query key', 'json')
if subject != f'acct:relay@{self.config.domain}': if subject != f'acct:relay@{self.config.host}':
return Response.new_error(404, 'user not found', 'json') return Response.new_error(404, 'user not found', 'json')
data = Webfinger.new( data = Webfinger.new(
handle = 'relay', handle = 'relay',
domain = self.config.domain, domain = self.config.host,
actor = self.config.actor actor = self.config.actor
) )
@ -249,17 +148,14 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View): class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response: async def get(self, request: Request, niversion: str) -> Response:
with self.database.connection() as conn: data = {
inboxes = conn.execute('SELECT * FROM inboxes').all() 'name': 'activityrelay',
'version': VERSION,
data = { 'protocols': ['activitypub'],
'name': 'activityrelay', 'open_regs': not self.config.whitelist_enabled,
'version': VERSION, 'users': 1,
'protocols': ['activitypub'], 'metadata': {'peers': self.database.hostnames}
'open_regs': not conn.get_config('whitelist-enabled'), }
'users': 1,
'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
}
if niversion == '2.1': if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay' data['repo'] = 'https://git.pleroma.social/pleroma/relay'
@ -270,5 +166,5 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo') @register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View): class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain) data = WellKnownNodeinfo.new_template(self.config.host)
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')

View file

@ -3,4 +3,3 @@ aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0 cachetools>=5.2.0
click>=8.1.2 click>=8.1.2
pyyaml>=6.0 pyyaml>=6.0
tinysql@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.2a.tar.gz

View file

@ -10,30 +10,26 @@ license_file = LICENSE
classifiers = classifiers =
Environment :: Console Environment :: Console
License :: OSI Approved :: AGPLv3 License License :: OSI Approved :: AGPLv3 License
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
project_urls = project_urls =
Source = https://git.pleroma.social/pleroma/relay Source = https://git.pleroma.social/pleroma/relay
Tracker = https://git.pleroma.social/pleroma/relay/-/issues Tracker = https://git.pleroma.social/pleroma/relay/-/issues
[options] [options]
zip_safe = False zip_safe = False
packages = packages = find:
relay
relay.database
include_package_data = true
install_requires = file: requirements.txt install_requires = file: requirements.txt
python_requires = >=3.8 python_requires = >=3.8
[options.extras_require] [options.extras_require]
dev = file: dev-requirements.txt dev =
flake8 = 3.1.0
[options.package_data] pyinstaller = 6.3.0
relay = pylint = 3.0
data/statements.sql
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
@ -41,4 +37,7 @@ console_scripts =
[flake8] [flake8]
select = F401 extend-ignore = ANN101,ANN204,E128,E251,E261,E266,E301,E303,W191
extend-exclude = docs, test*.py
max-line-length = 100
indent-size = 4