diff --git a/.gitignore b/.gitignore
index ecb6570..737b9a4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -94,9 +94,7 @@ ENV/
# Rope project settings
.ropeproject
-viera.yaml
-viera.jsonld
-
-# config file
-relay.yaml
-relay.jsonld
+# config and database
+*.yaml
+*.jsonld
+*.sqlite3
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..83229d2
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include data/statements.sql
diff --git a/relay.spec b/relay.spec
index 57fedc7..535e5fe 100644
--- a/relay.spec
+++ b/relay.spec
@@ -5,40 +5,43 @@ block_cipher = None
a = Analysis(
- ['relay/__main__.py'],
- pathex=[],
- binaries=[],
- datas=[],
- hiddenimports=[],
- hookspath=[],
- hooksconfig={},
- runtime_hooks=[],
- excludes=[],
- win_no_prefer_redirects=False,
- win_private_assemblies=False,
- cipher=block_cipher,
- noarchive=False,
+ ['relay/__main__.py'],
+ pathex=[],
+ binaries=[],
+ datas=[
+ ('relay/data', 'relay/data')
+ ],
+ hiddenimports=[],
+ hookspath=[],
+ hooksconfig={},
+ runtime_hooks=[],
+ excludes=[],
+ win_no_prefer_redirects=False,
+ win_private_assemblies=False,
+ cipher=block_cipher,
+ noarchive=False,
)
+
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
- pyz,
- a.scripts,
- a.binaries,
- a.zipfiles,
- a.datas,
- [],
- name='activityrelay',
- debug=False,
- bootloader_ignore_signals=False,
- strip=False,
- upx=True,
- upx_exclude=[],
- runtime_tmpdir=None,
- console=True,
- disable_windowed_traceback=False,
- argv_emulation=False,
- target_arch=None,
- codesign_identity=None,
- entitlements_file=None,
+ pyz,
+ a.scripts,
+ a.binaries,
+ a.zipfiles,
+ a.datas,
+ [],
+ name='activityrelay',
+ debug=False,
+ bootloader_ignore_signals=False,
+ strip=False,
+ upx=True,
+ upx_exclude=[],
+ runtime_tmpdir=None,
+ console=True,
+ disable_windowed_traceback=False,
+ argv_emulation=False,
+ target_arch=None,
+ codesign_identity=None,
+ entitlements_file=None,
)
diff --git a/relay.yaml.example b/relay.yaml.example
index 4e35697..90b9e8f 100644
--- a/relay.yaml.example
+++ b/relay.yaml.example
@@ -1,43 +1,35 @@
-# this is the path that the object graph will get dumped to (in JSON-LD format),
-# you probably shouldn't change it, but you can if you want.
-db: relay.jsonld
+# [string] Domain the relay will be hosted on
+domain: relay.example.com
-# Listener
+# [string] Address the relay will listen on
listen: 0.0.0.0
+
+# [integer] Port the relay will listen on
port: 8080
-# Note
-note: "Make a note about your instance here."
+# [integer] Number of push workers to start (will get removed in a future update)
+workers: 8
-# Number of worker threads to start. If 0, use asyncio futures instead of threads.
-workers: 0
+# [string] Database backend to use. Valid values: sqlite, postgres
+database_type: sqlite
-# Maximum number of inbox posts to do at once
-# If workers is set to 1 or above, this is the max for each worker
-push_limit: 512
+# [string] Path to the sqlite database file if the sqlite backend is in use
+sqlite_path: relay.sqlite3
-# The amount of json objects to cache from GET requests
-json_cache: 1024
+# settings for the postgresql backend
+postgres:
-ap:
- # This is used for generating activitypub messages, as well as instructions for
- # linking AP identities. It should be an SSL-enabled domain reachable by https.
- host: 'relay.example.com'
+ # [string] hostname or unix socket to connect to
+ host: /var/run/postgresql
- blocked_instances:
- - 'bad-instance.example.com'
- - 'another-bad-instance.example.com'
+ # [integer] port of the server
+ port: 5432
- whitelist_enabled: false
+ # [string] username to use when logging into the server (default is the current system username)
+ user: null
- whitelist:
- - 'good-instance.example.com'
- - 'another.good-instance.example.com'
+ # [string] password of the user
+ pass: null
- # uncomment the lines below to prevent certain activitypub software from posting
- # to the relay (all known relays by default). this uses the software name in nodeinfo
- #blocked_software:
- #- 'activityrelay'
- #- 'aoderelay'
- #- 'social.seattle.wa.us-relay'
- #- 'unciarelay'
+ # [string] name of the database to use
+ name: activityrelay
diff --git a/relay/application.py b/relay/application.py
index a01aaec..9440098 100644
--- a/relay/application.py
+++ b/relay/application.py
@@ -8,52 +8,41 @@ import traceback
import typing
from aiohttp import web
+from aputils.signer import Signer
from datetime import datetime, timedelta
from . import logger as logging
-from .config import RelayConfig
-from .database import RelayDatabase
+from .config import Config
+from .database import get_database
from .http_client import HttpClient
from .misc import check_open_port
from .views import VIEWS
if typing.TYPE_CHECKING:
+ from tinysql import Database
from typing import Any
from .misc import Message
# pylint: disable=unsubscriptable-object
-
class Application(web.Application):
+ DEFAULT: Application = None
+
def __init__(self, cfgpath: str):
web.Application.__init__(self)
+ Application.DEFAULT = self
+
+ self['signer'] = None
+ self['config'] = Config(cfgpath, load = True)
+ self['database'] = get_database(self.config)
+ self['client'] = HttpClient()
+
self['workers'] = []
self['last_worker'] = 0
self['start_time'] = None
self['running'] = False
- self['config'] = RelayConfig(cfgpath)
-
- if not self.config.load():
- self.config.save()
-
- if self.config.is_docker:
- self.config.update({
- 'db': '/data/relay.jsonld',
- 'listen': '0.0.0.0',
- 'port': 8080
- })
-
- self['database'] = RelayDatabase(self.config)
- self.database.load()
-
- self['client'] = HttpClient(
- database = self.database,
- limit = self.config.push_limit,
- timeout = self.config.timeout,
- cache_size = self.config.json_cache
- )
for path, view in VIEWS:
self.router.add_view(path, view)
@@ -65,15 +54,29 @@ class Application(web.Application):
@property
- def config(self) -> RelayConfig:
+ def config(self) -> Config:
return self['config']
@property
- def database(self) -> RelayDatabase:
+ def database(self) -> Database:
return self['database']
+ @property
+ def signer(self) -> Signer:
+ return self['signer']
+
+
+ @signer.setter
+ def signer(self, value: Signer | str) -> None:
+ if isinstance(value, Signer):
+ self['signer'] = value
+ return
+
+ self['signer'] = Signer(value, self.config.keyid)
+
+
@property
def uptime(self) -> timedelta:
if not self['start_time']:
@@ -118,7 +121,7 @@ class Application(web.Application):
logging.info(
'Starting webserver at %s (%s:%i)',
- self.config.host,
+ self.config.domain,
self.config.listen,
self.config.port
)
@@ -179,12 +182,7 @@ class PushWorker(threading.Thread):
async def handle_queue(self) -> None:
- self.client = HttpClient(
- database = self.app.database,
- limit = self.app.config.push_limit,
- timeout = self.app.config.timeout,
- cache_size = self.app.config.json_cache
- )
+ self.client = HttpClient()
while self.app['running']:
try:
diff --git a/relay/database.py b/relay/compat.py
similarity index 64%
rename from relay/database.py
rename to relay/compat.py
index 5d059dd..16d6461 100644
--- a/relay/database.py
+++ b/relay/compat.py
@@ -1,17 +1,128 @@
from __future__ import annotations
import json
+import os
import typing
+import yaml
-from aputils.signer import Signer
+from functools import cached_property
+from pathlib import Path
from urllib.parse import urlparse
from . import logger as logging
+from .misc import Message, boolean
if typing.TYPE_CHECKING:
- from typing import Iterator, Optional
- from .config import RelayConfig
- from .misc import Message
+ from typing import Any, Iterator, Optional
+
+
+# pylint: disable=duplicate-code
+
+class RelayConfig(dict):
+ def __init__(self, path: str):
+ dict.__init__(self, {})
+
+ if self.is_docker:
+ path = '/data/config.yaml'
+
+ self._path = Path(path).expanduser().resolve()
+ self.reset()
+
+
+ def __setitem__(self, key: str, value: Any) -> None:
+ if key in ['blocked_instances', 'blocked_software', 'whitelist']:
+ assert isinstance(value, (list, set, tuple))
+
+ elif key in ['port', 'workers', 'json_cache', 'timeout']:
+ if not isinstance(value, int):
+ value = int(value)
+
+ elif key == 'whitelist_enabled':
+ if not isinstance(value, bool):
+ value = boolean(value)
+
+ super().__setitem__(key, value)
+
+
+ @property
+ def db(self) -> RelayDatabase:
+ return Path(self['db']).expanduser().resolve()
+
+
+ @property
+ def actor(self) -> str:
+ return f'https://{self["host"]}/actor'
+
+
+ @property
+ def inbox(self) -> str:
+ return f'https://{self["host"]}/inbox'
+
+
+ @property
+ def keyid(self) -> str:
+ return f'{self.actor}#main-key'
+
+
+ @cached_property
+ def is_docker(self) -> bool:
+ return bool(os.environ.get('DOCKER_RUNNING'))
+
+
+ def reset(self) -> None:
+ self.clear()
+ self.update({
+ 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
+ 'listen': '0.0.0.0',
+ 'port': 8080,
+ 'note': 'Make a note about your instance here.',
+ 'push_limit': 512,
+ 'json_cache': 1024,
+ 'timeout': 10,
+ 'workers': 0,
+ 'host': 'relay.example.com',
+ 'whitelist_enabled': False,
+ 'blocked_software': [],
+ 'blocked_instances': [],
+ 'whitelist': []
+ })
+
+
+ def load(self) -> None:
+ self.reset()
+
+ options = {}
+
+ try:
+ options['Loader'] = yaml.FullLoader
+
+ except AttributeError:
+ pass
+
+ try:
+ with self._path.open('r', encoding = 'UTF-8') as fd:
+ config = yaml.load(fd, **options)
+
+ except FileNotFoundError:
+ return
+
+ if not config:
+ return
+
+ for key, value in config.items():
+ if key in ['ap']:
+ for k, v in value.items():
+ if k not in self:
+ continue
+
+ self[k] = v
+
+ continue
+
+ if key not in self:
+ continue
+
+ self[key] = value
class RelayDatabase(dict):
@@ -37,9 +148,7 @@ class RelayDatabase(dict):
return tuple(data['inbox'] for data in self['relay-list'].values())
- def load(self) -> bool:
- new_db = True
-
+ def load(self) -> None:
try:
with self.config.db.open() as fd:
data = json.load(fd)
@@ -65,17 +174,9 @@ class RelayDatabase(dict):
self['relay-list'] = data.get('relay-list', {})
for domain, instance in self['relay-list'].items():
- if self.config.is_banned(domain) or \
- (self.config.whitelist_enabled and not self.config.is_whitelisted(domain)):
-
- self.del_inbox(domain)
- continue
-
if not instance.get('domain'):
instance['domain'] = domain
- new_db = False
-
except FileNotFoundError:
pass
@@ -83,17 +184,6 @@ class RelayDatabase(dict):
if self.config.db.stat().st_size > 0:
raise e from None
- if not self['private-key']:
- logging.info('No actor keys present, generating 4096-bit RSA keypair.')
- self.signer = Signer.new(self.config.keyid, size=4096)
- self['private-key'] = self.signer.export()
-
- else:
- self.signer = Signer(self['private-key'], self.config.keyid)
-
- self.save()
- return not new_db
-
def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd:
diff --git a/relay/config.py b/relay/config.py
index e684ead..937372f 100644
--- a/relay/config.py
+++ b/relay/config.py
@@ -1,76 +1,76 @@
from __future__ import annotations
+import getpass
import os
import typing
import yaml
-from functools import cached_property
from pathlib import Path
-from urllib.parse import urlparse
-from .misc import DotDict, boolean
+from .misc import IS_DOCKER
if typing.TYPE_CHECKING:
- from typing import Any
- from .database import RelayDatabase
+ from typing import Any, Optional
-RELAY_SOFTWARE = [
- 'activityrelay', # https://git.pleroma.social/pleroma/relay
- 'aoderelay', # https://git.asonix.dog/asonix/relay
- 'feditools-relay' # https://git.ptzo.gdn/feditools/relay
-]
+DEFAULTS: dict[str, Any] = {
+ 'listen': '0.0.0.0',
+ 'port': 8080,
+ 'domain': 'relay.example.com',
+ 'workers': len(os.sched_getaffinity(0)),
+ 'db_type': 'sqlite',
+ 'sq_path': 'relay.sqlite3',
+ 'pg_host': '/var/run/postgresql',
+ 'pg_port': 5432,
+ 'pg_user': getpass.getuser(),
+ 'pg_pass': None,
+ 'pg_name': 'activityrelay'
+}
-APKEYS = [
- 'host',
- 'whitelist_enabled',
- 'blocked_software',
- 'blocked_instances',
- 'whitelist'
-]
+if IS_DOCKER:
+ DEFAULTS['sq_path'] = '/data/relay.jsonld'
-class RelayConfig(DotDict):
- __slots__ = ('path', )
+class Config:
+ def __init__(self, path: str, load: Optional[bool] = False):
+ self.path = Path(path).expanduser().resolve()
- def __init__(self, path: str | Path):
- DotDict.__init__(self, {})
+ self.listen = None
+ self.port = None
+ self.domain = None
+ self.workers = None
+ self.db_type = None
+ self.sq_path = None
+ self.pg_host = None
+ self.pg_port = None
+ self.pg_user = None
+ self.pg_pass = None
+ self.pg_name = None
- if self.is_docker:
- path = '/data/config.yaml'
+ if load:
+ try:
+ self.load()
- self._path = Path(path).expanduser().resolve()
- self.reset()
-
-
- def __setitem__(self, key: str, value: Any) -> None:
- if key in ['blocked_instances', 'blocked_software', 'whitelist']:
- assert isinstance(value, (list, set, tuple))
-
- elif key in ['port', 'workers', 'json_cache', 'timeout']:
- if not isinstance(value, int):
- value = int(value)
-
- elif key == 'whitelist_enabled':
- if not isinstance(value, bool):
- value = boolean(value)
-
- super().__setitem__(key, value)
+ except FileNotFoundError:
+ self.save()
@property
- def db(self) -> RelayDatabase:
- return Path(self['db']).expanduser().resolve()
+ def sqlite_path(self) -> Path:
+ if not os.path.isabs(self.sq_path):
+ return self.path.parent.joinpath(self.sq_path).resolve()
+
+ return Path(self.sq_path).expanduser().resolve()
@property
def actor(self) -> str:
- return f'https://{self.host}/actor'
+ return f'https://{self.domain}/actor'
@property
def inbox(self) -> str:
- return f'https://{self.host}/inbox'
+ return f'https://{self.domain}/inbox'
@property
@@ -78,115 +78,7 @@ class RelayConfig(DotDict):
return f'{self.actor}#main-key'
- @cached_property
- def is_docker(self) -> bool:
- return bool(os.environ.get('DOCKER_RUNNING'))
-
-
- def reset(self) -> None:
- self.clear()
- self.update({
- 'db': str(self._path.parent.joinpath(f'{self._path.stem}.jsonld')),
- 'listen': '0.0.0.0',
- 'port': 8080,
- 'note': 'Make a note about your instance here.',
- 'push_limit': 512,
- 'json_cache': 1024,
- 'timeout': 10,
- 'workers': 0,
- 'host': 'relay.example.com',
- 'whitelist_enabled': False,
- 'blocked_software': [],
- 'blocked_instances': [],
- 'whitelist': []
- })
-
-
- def ban_instance(self, instance: str) -> bool:
- if instance.startswith('http'):
- instance = urlparse(instance).hostname
-
- if self.is_banned(instance):
- return False
-
- self.blocked_instances.append(instance)
- return True
-
-
- def unban_instance(self, instance: str) -> bool:
- if instance.startswith('http'):
- instance = urlparse(instance).hostname
-
- try:
- self.blocked_instances.remove(instance)
- return True
-
- except ValueError:
- return False
-
-
- def ban_software(self, software: str) -> bool:
- if self.is_banned_software(software):
- return False
-
- self.blocked_software.append(software)
- return True
-
-
- def unban_software(self, software: str) -> bool:
- try:
- self.blocked_software.remove(software)
- return True
-
- except ValueError:
- return False
-
-
- def add_whitelist(self, instance: str) -> bool:
- if instance.startswith('http'):
- instance = urlparse(instance).hostname
-
- if self.is_whitelisted(instance):
- return False
-
- self.whitelist.append(instance)
- return True
-
-
- def del_whitelist(self, instance: str) -> bool:
- if instance.startswith('http'):
- instance = urlparse(instance).hostname
-
- try:
- self.whitelist.remove(instance)
- return True
-
- except ValueError:
- return False
-
-
- def is_banned(self, instance: str) -> bool:
- if instance.startswith('http'):
- instance = urlparse(instance).hostname
-
- return instance in self.blocked_instances
-
-
- def is_banned_software(self, software: str) -> bool:
- if not software:
- return False
-
- return software.lower() in self.blocked_software
-
-
- def is_whitelisted(self, instance: str) -> bool:
- if instance.startswith('http'):
- instance = urlparse(instance).hostname
-
- return instance in self.whitelist
-
-
- def load(self) -> bool:
+ def load(self) -> None:
self.reset()
options = {}
@@ -197,50 +89,69 @@ class RelayConfig(DotDict):
except AttributeError:
pass
- try:
- with self._path.open('r', encoding = 'UTF-8') as fd:
- config = yaml.load(fd, **options)
-
- except FileNotFoundError:
- return False
+ with self.path.open('r', encoding = 'UTF-8') as fd:
+ config = yaml.load(fd, **options)
+ pgcfg = config.get('postgresql', {})
if not config:
- return False
+ raise ValueError('Config is empty')
- for key, value in config.items():
- if key in ['ap']:
- for k, v in value.items():
- if k not in self:
- continue
+ if IS_DOCKER:
+ self.listen = '0.0.0.0'
+ self.port = 8080
+ self.sq_path = '/data/relay.jsonld'
- self[k] = v
+ else:
+ self.set('listen', config.get('listen', DEFAULTS['listen']))
+ self.set('port', config.get('port', DEFAULTS['port']))
+ self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
+ self.set('domain', config.get('domain', DEFAULTS['domain']))
+ self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
+
+ for key in DEFAULTS:
+ if not key.startswith('pg'):
continue
- if key not in self:
+ try:
+ self.set(key, pgcfg[key[3:]])
+
+ except KeyError:
continue
- self[key] = value
- if self.host.endswith('example.com'):
- return False
-
- return True
+ def reset(self) -> None:
+ for key, value in DEFAULTS.items():
+ setattr(self, key, value)
def save(self) -> None:
+ self.path.parent.mkdir(exist_ok = True, parents = True)
+
config = {
- # just turning config.db into a string is good enough for now
- 'db': str(self.db),
'listen': self.listen,
'port': self.port,
- 'note': self.note,
- 'push_limit': self.push_limit,
- 'workers': self.workers,
- 'json_cache': self.json_cache,
- 'timeout': self.timeout,
- 'ap': {key: self[key] for key in APKEYS}
+ 'domain': self.domain,
+ 'database_type': self.db_type,
+ 'sqlite_path': self.sq_path,
+ 'postgres': {
+ 'host': self.pg_host,
+ 'port': self.pg_port,
+ 'user': self.pg_user,
+ 'pass': self.pg_pass,
+ 'name': self.pg_name
+ }
}
- with self._path.open('w', encoding = 'utf-8') as fd:
- yaml.dump(config, fd, sort_keys=False)
+ with self.path.open('w', encoding = 'utf-8') as fd:
+ yaml.dump(config, fd, sort_keys = False)
+
+
+ def set(self, key: str, value: Any) -> None:
+ if key not in DEFAULTS:
+ raise KeyError(key)
+
+ if key in ('port', 'pg_port', 'workers') and not isinstance(value, int):
+ value = int(value)
+
+ setattr(self, key, value)
diff --git a/relay/data/statements.sql b/relay/data/statements.sql
new file mode 100644
index 0000000..a262feb
--- /dev/null
+++ b/relay/data/statements.sql
@@ -0,0 +1,79 @@
+-- name: get-config
+SELECT * FROM config WHERE key = :key
+
+
+-- name: get-config-all
+SELECT * FROM config
+
+
+-- name: put-config
+INSERT INTO config (key, value, type)
+VALUES (:key, :value, :type)
+ON CONFLICT (key) DO UPDATE SET value = :value
+RETURNING *
+
+
+-- name: del-config
+DELETE FROM config
+WHERE key = :key
+
+
+-- name: get-inbox
+SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value
+
+
+-- name: put-inbox
+INSERT INTO inboxes (domain, actor, inbox, followid, software, created)
+VALUES (:domain, :actor, :inbox, :followid, :software, :created)
+ON CONFLICT (domain) DO UPDATE SET followid = :followid
+RETURNING *
+
+
+-- name: del-inbox
+DELETE FROM inboxes
+WHERE domain = :value or inbox = :value or actor = :value
+
+
+-- name: get-software-ban
+SELECT * FROM software_bans WHERE name = :name
+
+
+-- name: put-software-ban
+INSERT INTO software_bans (name, reason, note, created)
+VALUES (:name, :reason, :note, :created)
+RETURNING *
+
+
+-- name: del-software-ban
+DELETE FROM software_bans
+WHERE name = :name
+
+
+-- name: get-domain-ban
+SELECT * FROM domain_bans WHERE domain = :domain
+
+
+-- name: put-domain-ban
+INSERT INTO domain_bans (domain, reason, note, created)
+VALUES (:domain, :reason, :note, :created)
+RETURNING *
+
+
+-- name: del-domain-ban
+DELETE FROM domain_bans
+WHERE domain = :domain
+
+
+-- name: get-domain-whitelist
+SELECT * FROM whitelist WHERE domain = :domain
+
+
+-- name: put-domain-whitelist
+INSERT INTO whitelist (domain, created)
+VALUES (:domain, :created)
+RETURNING *
+
+
+-- name: del-domain-whitelist
+DELETE FROM whitelist
+WHERE domain = :domain
diff --git a/relay/database/__init__.py b/relay/database/__init__.py
new file mode 100644
index 0000000..925c5e0
--- /dev/null
+++ b/relay/database/__init__.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+import tinysql
+import typing
+
+from importlib.resources import files as pkgfiles
+
+from .config import get_default_value
+from .connection import Connection
+from .schema import VERSIONS, migrate_0
+
+from .. import logger as logging
+
+if typing.TYPE_CHECKING:
+ from typing import Optional
+ from .config import Config
+
+
+def get_database(config: Config, migrate: Optional[bool] = True) -> tinysql.Database:
+ if config.db_type == "sqlite":
+ db = tinysql.Database.sqlite(config.sqlite_path, connection_class = Connection)
+
+ elif config.db_type == "postgres":
+ db = tinysql.Database.postgres(
+ config.pg_name,
+ config.pg_host,
+ config.pg_port,
+ config.pg_user,
+ config.pg_pass,
+ connection_class = Connection
+ )
+
+ db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql"))
+
+ if not migrate:
+ return db
+
+ with db.connection() as conn:
+ if 'config' not in conn.get_tables():
+ logging.info("Creating database tables")
+ migrate_0(conn)
+ return db
+
+ schema_ver = conn.get_config('schema-version')
+
+ if schema_ver < get_default_value('schema-version'):
+ logging.info("Migrating database from version '%i'", schema_ver)
+
+ for ver, func in VERSIONS:
+ if schema_ver < ver:
+ conn.begin()
+
+ func(conn)
+
+ conn.put_config('schema-version', ver)
+ conn.commit()
+
+ if (privkey := conn.get_config('private-key')):
+ conn.app.signer = privkey
+
+ logging.set_level(conn.get_config('log-level'))
+
+ return db
diff --git a/relay/database/config.py b/relay/database/config.py
new file mode 100644
index 0000000..e132647
--- /dev/null
+++ b/relay/database/config.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+import typing
+
+from .. import logger as logging
+from ..misc import boolean
+
+if typing.TYPE_CHECKING:
+ from typing import Any, Callable
+
+
+CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
+ 'schema-version': ('int', 20240119),
+ 'log-level': ('loglevel', logging.LogLevel.INFO),
+ 'note': ('str', 'Make a note about your instance here.'),
+ 'private-key': ('str', None),
+ 'whitelist-enabled': ('bool', False)
+}
+
+# serializer | deserializer
+CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
+ 'str': (str, str),
+ 'int': (str, int),
+ 'bool': (str, boolean),
+ 'loglevel': (lambda x: x.name, logging.LogLevel.parse)
+}
+
+
+def get_default_value(key: str) -> Any:
+ return CONFIG_DEFAULTS[key][1]
+
+
+def get_default_type(key: str) -> str:
+ return CONFIG_DEFAULTS[key][0]
+
+
+def serialize(key: str, value: Any) -> str:
+ type_name = get_default_type(key)
+ return CONFIG_CONVERT[type_name][0](value)
+
+
+def deserialize(key: str, value: str) -> Any:
+ type_name = get_default_type(key)
+ return CONFIG_CONVERT[type_name][1](value)
diff --git a/relay/database/connection.py b/relay/database/connection.py
new file mode 100644
index 0000000..43bbb7e
--- /dev/null
+++ b/relay/database/connection.py
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+import tinysql
+import typing
+
+from datetime import datetime, timezone
+from urllib.parse import urlparse
+
+from .config import CONFIG_DEFAULTS, get_default_type, get_default_value, serialize, deserialize
+
+from .. import logger as logging
+from ..misc import get_app
+
+if typing.TYPE_CHECKING:
+ from tinysql import Cursor, Row
+ from typing import Any, Iterator, Optional
+ from .application import Application
+ from ..misc import Message
+
+
+RELAY_SOFTWARE = [
+ 'activityrelay', # https://git.pleroma.social/pleroma/relay
+ 'activity-relay', # https://github.com/yukimochi/Activity-Relay
+ 'aoderelay', # https://git.asonix.dog/asonix/relay
+ 'feditools-relay' # https://git.ptzo.gdn/feditools/relay
+]
+
+
+class Connection(tinysql.Connection):
+ @property
+ def app(self) -> Application:
+ return get_app()
+
+
+ def distill_inboxes(self, message: Message) -> Iterator[str]:
+ src_domains = {
+ message.domain,
+ urlparse(message.object_id).netloc
+ }
+
+ for inbox in self.execute('SELECT * FROM inboxes'):
+ if inbox['domain'] not in src_domains:
+ yield inbox['inbox']
+
+
+ def exec_statement(self, name: str, params: Optional[dict[str, Any]] = None) -> Cursor:
+ return self.execute(self.database.prepared_statements[name], params)
+
+
+ def get_config(self, key: str) -> Any:
+ if key not in CONFIG_DEFAULTS:
+ raise KeyError(key)
+
+ with self.exec_statement('get-config', {'key': key}) as cur:
+ if not (row := cur.one()):
+ return get_default_value(key)
+
+ if row['value']:
+ return deserialize(row['key'], row['value'])
+
+ return None
+
+
+ def get_config_all(self) -> dict[str, Any]:
+ with self.exec_statement('get-config-all') as cur:
+ db_config = {row['key']: row['value'] for row in cur}
+
+ config = {}
+
+ for key, data in CONFIG_DEFAULTS.items():
+ try:
+ config[key] = deserialize(key, db_config[key])
+
+ except KeyError:
+ if key == 'schema-version':
+ config[key] = 0
+
+ else:
+ config[key] = data[1]
+
+ return config
+
+
+ def put_config(self, key: str, value: Any) -> Any:
+ if key not in CONFIG_DEFAULTS:
+ raise KeyError(key)
+
+ if key == 'private-key':
+ self.app.signer = value
+
+ elif key == 'log-level':
+ value = logging.LogLevel.parse(value)
+ logging.set_level(value)
+
+ params = {
+ 'key': key,
+ 'value': serialize(key, value) if value is not None else None,
+ 'type': get_default_type(key)
+ }
+
+ with self.exec_statement('put-config', params):
+ return value
+
+
+ def get_inbox(self, value: str) -> Row:
+ with self.exec_statement('get-inbox', {'value': value}) as cur:
+ return cur.one()
+
+
+ def put_inbox(self,
+ domain: str,
+ inbox: str,
+ actor: Optional[str] = None,
+ followid: Optional[str] = None,
+ software: Optional[str] = None) -> Row:
+
+ params = {
+ 'domain': domain,
+ 'inbox': inbox,
+ 'actor': actor,
+ 'followid': followid,
+ 'software': software,
+ 'created': datetime.now(tz = timezone.utc)
+ }
+
+ with self.exec_statement('put-inbox', params) as cur:
+ return cur.one()
+
+
+ def update_inbox(self,
+ inbox: str,
+ actor: Optional[str] = None,
+ followid: Optional[str] = None,
+ software: Optional[str] = None) -> Row:
+
+ if not (actor or followid or software):
+ raise ValueError('Missing "actor", "followid", and/or "software"')
+
+ data = {}
+
+ if actor:
+ data['actor'] = actor
+
+ if followid:
+ data['followid'] = followid
+
+ if software:
+ data['software'] = software
+
+ statement = tinysql.Update('inboxes', data, inbox = inbox)
+
+ with self.query(statement):
+ return self.get_inbox(inbox)
+
+
+ def del_inbox(self, value: str) -> bool:
+ with self.exec_statement('del-inbox', {'value': value}) as cur:
+ if cur.modified_row_count > 1:
+ raise ValueError('More than one row was modified')
+
+ return cur.modified_row_count == 1
+
+
+ def get_domain_ban(self, domain: str) -> Row:
+ if domain.startswith('http'):
+ domain = urlparse(domain).netloc
+
+ with self.exec_statement('get-domain-ban', {'domain': domain}) as cur:
+ return cur.one()
+
+
+ def put_domain_ban(self,
+ domain: str,
+ reason: Optional[str] = None,
+ note: Optional[str] = None) -> Row:
+
+ params = {
+ 'domain': domain,
+ 'reason': reason,
+ 'note': note,
+ 'created': datetime.now(tz = timezone.utc)
+ }
+
+ with self.exec_statement('put-domain-ban', params) as cur:
+ return cur.one()
+
+
+ def update_domain_ban(self,
+ domain: str,
+ reason: Optional[str] = None,
+ note: Optional[str] = None) -> tinysql.Row:
+
+ if not (reason or note):
+ raise ValueError('"reason" and/or "note" must be specified')
+
+ params = {}
+
+ if reason:
+ params['reason'] = reason
+
+ if note:
+ params['note'] = note
+
+ statement = tinysql.Update('domain_bans', params, domain = domain)
+
+ with self.query(statement) as cur:
+ if cur.modified_row_count > 1:
+ raise ValueError('More than one row was modified')
+
+ return self.get_domain_ban(domain)
+
+
+ def del_domain_ban(self, domain: str) -> bool:
+ with self.exec_statement('del-domain-ban', {'domain': domain}) as cur:
+ if cur.modified_row_count > 1:
+ raise ValueError('More than one row was modified')
+
+ return cur.modified_row_count == 1
+
+
+ def get_software_ban(self, name: str) -> Row:
+ with self.exec_statement('get-software-ban', {'name': name}) as cur:
+ return cur.one()
+
+
+ def put_software_ban(self,
+ name: str,
+ reason: Optional[str] = None,
+ note: Optional[str] = None) -> Row:
+
+ params = {
+ 'name': name,
+ 'reason': reason,
+ 'note': note,
+ 'created': datetime.now(tz = timezone.utc)
+ }
+
+ with self.exec_statement('put-software-ban', params) as cur:
+ return cur.one()
+
+
+ def update_software_ban(self,
+ name: str,
+ reason: Optional[str] = None,
+ note: Optional[str] = None) -> tinysql.Row:
+
+ if not (reason or note):
+ raise ValueError('"reason" and/or "note" must be specified')
+
+ params = {}
+
+ if reason:
+ params['reason'] = reason
+
+ if note:
+ params['note'] = note
+
+ statement = tinysql.Update('software_bans', params, name = name)
+
+ with self.query(statement) as cur:
+ if cur.modified_row_count > 1:
+ raise ValueError('More than one row was modified')
+
+ return self.get_software_ban(name)
+
+
+ def del_software_ban(self, name: str) -> bool:
+ with self.exec_statement('del-software-ban', {'name': name}) as cur:
+ if cur.modified_row_count > 1:
+ raise ValueError('More than one row was modified')
+
+ return cur.modified_row_count == 1
+
+
+ def get_domain_whitelist(self, domain: str) -> Row:
+ with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur:
+ return cur.one()
+
+
+ def put_domain_whitelist(self, domain: str) -> Row:
+ params = {
+ 'domain': domain,
+ 'created': datetime.now(tz = timezone.utc)
+ }
+
+ with self.exec_statement('put-domain-whitelist', params) as cur:
+ return cur.one()
+
+
+ def del_domain_whitelist(self, domain: str) -> bool:
+ with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur:
+ if cur.modified_row_count > 1:
+ raise ValueError('More than one row was modified')
+
+ return cur.modified_row_count == 1
diff --git a/relay/database/schema.py b/relay/database/schema.py
new file mode 100644
index 0000000..15a1fae
--- /dev/null
+++ b/relay/database/schema.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import typing
+
+from tinysql import Column, Connection, Table
+
+from .config import get_default_value
+
+if typing.TYPE_CHECKING:
+ from typing import Callable
+
+
+VERSIONS: list[Callable] = []
+TABLES: list[Table] = [
+ Table(
+ 'config',
+ Column('key', 'text', primary_key = True, unique = True, nullable = False),
+ Column('value', 'text'),
+ Column('type', 'text', default = 'str')
+ ),
+ Table(
+ 'inboxes',
+ Column('domain', 'text', primary_key = True, unique = True, nullable = False),
+ Column('actor', 'text', unique = True),
+ Column('inbox', 'text', unique = True, nullable = False),
+ Column('followid', 'text'),
+ Column('software', 'text'),
+ Column('created', 'timestamp', nullable = False)
+ ),
+ Table(
+ 'whitelist',
+ Column('domain', 'text', primary_key = True, unique = True, nullable = True),
+ Column('created', 'timestamp')
+ ),
+ Table(
+ 'instance_bans',
+ Column('domain', 'text', primary_key = True, unique = True, nullable = True),
+ Column('reason', 'text'),
+ Column('note', 'text'),
+ Column('created', 'timestamp', nullable = False)
+ ),
+ Table(
+ 'software_bans',
+ Column('name', 'text', primary_key = True, unique = True, nullable = True),
+ Column('reason', 'text'),
+ Column('note', 'text'),
+ Column('created', 'timestamp', nullable = False)
+ )
+]
+
+
+def version(func: Callable) -> Callable:
+ ver = int(func.replace('migrate_', ''))
+ VERSIONS[ver] = func
+ return func
+
+
+def migrate_0(conn: Connection) -> None:
+ conn.create_tables(TABLES)
+ conn.put_config('schema-version', get_default_value('schema-version'))
diff --git a/relay/http_client.py b/relay/http_client.py
index 6f2a044..52176b7 100644
--- a/relay/http_client.py
+++ b/relay/http_client.py
@@ -13,11 +13,10 @@ from urllib.parse import urlparse
from . import __version__
from . import logger as logging
-from .misc import MIMETYPES, Message
+from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from typing import Any, Callable, Optional
- from .database import RelayDatabase
HEADERS = {
@@ -28,12 +27,10 @@ HEADERS = {
class HttpClient:
def __init__(self,
- database: RelayDatabase,
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
- self.database = database
self.cache = LRUCache(cache_size)
self.limit = limit
self.timeout = timeout
@@ -98,7 +95,7 @@ class HttpClient:
headers = {}
if sign_headers:
- headers.update(self.database.signer.sign_headers('GET', url, algorithm='original'))
+ get_app().signer.sign_headers('GET', url, algorithm = 'original')
try:
logging.debug('Fetching resource: %s', url)
@@ -150,23 +147,24 @@ class HttpClient:
async def post(self, url: str, message: Message) -> None:
await self.open()
- instance = self.database.get_inbox(url)
+ with get_app().database.connection() as conn:
+ instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now
- if instance and instance.get('software') in {'mastodon'}:
+ if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019'
else:
algorithm = 'original'
headers = {'Content-Type': 'application/activity+json'}
- headers.update(self.database.signer.sign_headers('POST', url, message, algorithm=algorithm))
+ headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
try:
logging.verbose('Sending "%s" to %s', message.type, url)
async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
- ## Not expecting a response, so just return
+ # Not expecting a response, so just return
if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url)
return
@@ -181,7 +179,7 @@ class HttpClient:
except (AsyncTimeoutError, ClientConnectionError):
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
- ## prevent workers from being brought down
+ # prevent workers from being brought down
except Exception:
traceback.print_exc()
@@ -211,16 +209,16 @@ class HttpClient:
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
-async def get(database: RelayDatabase, *args: Any, **kwargs: Any) -> Message | dict | None:
- async with HttpClient(database) as client:
+async def get(*args: Any, **kwargs: Any) -> Message | dict | None:
+ async with HttpClient() as client:
return await client.get(*args, **kwargs)
-async def post(database: RelayDatabase, *args: Any, **kwargs: Any) -> None:
- async with HttpClient(database) as client:
+async def post(*args: Any, **kwargs: Any) -> None:
+ async with HttpClient() as client:
return await client.post(*args, **kwargs)
-async def fetch_nodeinfo(database: RelayDatabase, *args: Any, **kwargs: Any) -> Nodeinfo | None:
- async with HttpClient(database) as client:
+async def fetch_nodeinfo(*args: Any, **kwargs: Any) -> Nodeinfo | None:
+ async with HttpClient() as client:
return await client.fetch_nodeinfo(*args, **kwargs)
diff --git a/relay/logger.py b/relay/logger.py
index 0d1d451..e822cb4 100644
--- a/relay/logger.py
+++ b/relay/logger.py
@@ -4,20 +4,62 @@ import logging
import os
import typing
+from enum import IntEnum
from pathlib import Path
if typing.TYPE_CHECKING:
- from typing import Any, Callable
+ from typing import Any, Callable, Type
-LOG_LEVELS: dict[str, int] = {
- 'DEBUG': logging.DEBUG,
- 'VERBOSE': 15,
- 'INFO': logging.INFO,
- 'WARNING': logging.WARNING,
- 'ERROR': logging.ERROR,
- 'CRITICAL': logging.CRITICAL
-}
+class LogLevel(IntEnum):
+ DEBUG = logging.DEBUG
+ VERBOSE = 15
+ INFO = logging.INFO
+ WARNING = logging.WARNING
+ ERROR = logging.ERROR
+ CRITICAL = logging.CRITICAL
+
+
+ def __str__(self) -> str:
+ return self.name
+
+
+ @classmethod
+ def parse(cls: Type[IntEnum], data: object) -> IntEnum:
+ if isinstance(data, cls):
+ return data
+
+ if isinstance(data, str):
+ data = data.upper()
+
+ try:
+ return cls[data]
+
+ except KeyError:
+ pass
+
+ try:
+ return cls(data)
+
+ except ValueError:
+ pass
+
+ raise AttributeError(f'Invalid enum property for {cls.__name__}: {data}')
+
+
+def get_level() -> LogLevel:
+ return LogLevel.parse(logging.root.level)
+
+
+def set_level(level: LogLevel | str) -> None:
+ logging.root.setLevel(LogLevel.parse(level))
+
+
+def verbose(message: str, *args: Any, **kwargs: Any) -> None:
+ if not logging.root.isEnabledFor(LogLevel['VERBOSE']):
+ return
+
+ logging.log(LogLevel['VERBOSE'], message, *args, **kwargs)
debug: Callable = logging.debug
@@ -27,14 +69,7 @@ error: Callable = logging.error
critical: Callable = logging.critical
-def verbose(message: str, *args: Any, **kwargs: Any) -> None:
- if not logging.root.isEnabledFor(LOG_LEVELS['VERBOSE']):
- return
-
- logging.log(LOG_LEVELS['VERBOSE'], message, *args, **kwargs)
-
-
-logging.addLevelName(LOG_LEVELS['VERBOSE'], 'VERBOSE')
+logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
try:
@@ -45,11 +80,11 @@ except KeyError:
try:
- log_level = LOG_LEVELS[env_log_level]
+ log_level = LogLevel[env_log_level]
except KeyError:
- logging.warning('Invalid log level: %s', env_log_level)
- log_level = logging.INFO
+ print('Invalid log level:', env_log_level)
+ log_level = LogLevel['INFO']
handlers = [logging.StreamHandler()]
diff --git a/relay/manage.py b/relay/manage.py
index b0c5cb3..c04235f 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -6,22 +6,49 @@ import click
import platform
import typing
+from aputils.signer import Signer
+from pathlib import Path
+from shutil import copyfile
from urllib.parse import urlparse
-from . import misc, __version__
+from . import __version__
from . import http_client as http
+from . import logger as logging
from .application import Application
-from .config import RELAY_SOFTWARE
+from .compat import RelayConfig, RelayDatabase
+from .database import get_database
+from .database.connection import RELAY_SOFTWARE
+from .misc import IS_DOCKER, Message, check_open_port
if typing.TYPE_CHECKING:
- from typing import Any
+ from tinysql import Row
+ from typing import Any, Optional
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
-app = None
-CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
+CONFIG_IGNORE = (
+ 'schema-version',
+ 'private-key'
+)
+
+ACTOR_FORMATS = {
+ 'mastodon': 'https://{domain}/actor',
+ 'akkoma': 'https://{domain}/relay',
+ 'pleroma': 'https://{domain}/relay'
+}
+
+SOFTWARE = (
+ 'mastodon',
+ 'akkoma',
+ 'pleroma',
+ 'misskey',
+ 'friendica',
+ 'hubzilla',
+ 'firefish',
+ 'gotosocial'
+)
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@@ -29,11 +56,10 @@ CONFIG_IGNORE = {'blocked_software', 'blocked_instances', 'whitelist'}
@click.version_option(version=__version__, prog_name='ActivityRelay')
@click.pass_context
def cli(ctx: click.Context, config: str) -> None:
- global app
- app = Application(config)
+ ctx.obj = Application(config)
if not ctx.invoked_subcommand:
- if app.config.host.endswith('example.com'):
+ if ctx.obj.config.domain.endswith('example.com'):
cli_setup.callback()
else:
@@ -41,46 +67,92 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup')
-def cli_setup() -> None:
+@click.pass_context
+def cli_setup(ctx: click.Context) -> None:
'Generate a new config'
while True:
- app.config.host = click.prompt(
+ ctx.obj.config.domain = click.prompt(
'What domain will the relay be hosted on?',
- default = app.config.host
+ default = ctx.obj.config.domain
)
- if not app.config.host.endswith('example.com'):
+ if not ctx.obj.config.domain.endswith('example.com'):
break
- click.echo('The domain must not be example.com')
+ click.echo('The domain must not end with "example.com"')
- if not app.config.is_docker:
- app.config.listen = click.prompt(
+ if not IS_DOCKER:
+ ctx.obj.config.listen = click.prompt(
'Which address should the relay listen on?',
- default = app.config.listen
+ default = ctx.obj.config.listen
)
- while True:
- app.config.port = click.prompt(
- 'What TCP port should the relay listen on?',
- default = app.config.port,
- type = int
- )
+ ctx.obj.config.port = click.prompt(
+ 'What TCP port should the relay listen on?',
+ default = ctx.obj.config.port,
+ type = int
+ )
- break
+ ctx.obj.config.db_type = click.prompt(
+ 'Which database backend will be used?',
+ default = ctx.obj.config.db_type,
+ type = click.Choice(['postgres', 'sqlite'], case_sensitive = False)
+ )
- app.config.save()
+ if ctx.obj.config.db_type == 'sqlite':
+ ctx.obj.config.sq_path = click.prompt(
+ 'Where should the database be stored?',
+ default = ctx.obj.config.sq_path
+ )
- if not app.config.is_docker and click.confirm('Relay all setup! Would you like to run it now?'):
+ elif ctx.obj.config.db_type == 'postgres':
+ ctx.obj.config.pg_name = click.prompt(
+ 'What is the name of the database?',
+ default = ctx.obj.config.pg_name
+ )
+
+ ctx.obj.config.pg_host = click.prompt(
+ 'What IP address or hostname does the server listen on?',
+ default = ctx.obj.config.pg_host
+ )
+
+ ctx.obj.config.pg_port = click.prompt(
+ 'What port does the server listen on?',
+ default = ctx.obj.config.pg_port,
+ type = int
+ )
+
+ ctx.obj.config.pg_user = click.prompt(
+ 'Which user will authenticate with the server?',
+ default = ctx.obj.config.pg_user
+ )
+
+ ctx.obj.config.pg_pass = click.prompt(
+ 'User password: ',
+ hide_input = True
+ ) or None
+
+ ctx.obj.config.save()
+
+ config = {
+ 'private-key': Signer.new('n/a').export()
+ }
+
+ with ctx.obj.database.connection() as conn:
+ for key, value in config.items():
+ conn.put_config(key, value)
+
+ if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback()
@cli.command('run')
-def cli_run() -> None:
+@click.pass_context
+def cli_run(ctx: click.Context) -> None:
'Run the relay'
- if app.config.host.endswith('example.com'):
+ if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
)
@@ -104,40 +176,142 @@ def cli_run() -> None:
click.echo(pip_command)
return
- if not misc.check_open_port(app.config.listen, app.config.port):
- click.echo(f'Error: A server is already running on port {app.config.port}')
+ if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port):
+ click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
return
- app.run()
+ ctx.obj.run()
+
+
+@cli.command('convert')
+@click.option('--old-config', '-o', help = 'Path to the new config file')
+@click.pass_context
+def cli_convert(ctx: click.Context, old_config: str) -> None:
+ 'Convert an old config and jsonld database to the new format.'
+
+ old_config = Path(old_config).expanduser().resolve()
+ backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml')
+
+ if str(old_config) == str(ctx.obj.config.path) and not backup.exists():
+ logging.info('Created backup config @ %s', backup)
+ copyfile(ctx.obj.config.path, backup)
+
+ config = RelayConfig(old_config)
+ config.load()
+
+ database = RelayDatabase(config)
+ database.load()
+
+ ctx.obj.config.set('listen', config['listen'])
+ ctx.obj.config.set('port', config['port'])
+ ctx.obj.config.set('workers', config['workers'])
+ ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3'))
+
+ with get_database(ctx.obj.config) as db:
+ with db.connection() as conn:
+ conn.put_config('private-key', database['private-key'])
+ conn.put_config('note', config['note'])
+ conn.put_config('whitelist-enabled', config['whitelist_enabled'])
+
+ with click.progressbar(
+ database['relay-list'].values(),
+ label = 'Inboxes'.ljust(15),
+ width = 0
+ ) as inboxes:
+
+ for inbox in inboxes:
+ if inbox['software'] in ('akkoma', 'pleroma'):
+ actor = f'https://{inbox["domain"]}/relay'
+
+ elif inbox['software'] == 'mastodon':
+ actor = f'https://{inbox["domain"]}/actor'
+
+ else:
+ actor = None
+
+ conn.put_inbox(
+ inbox['domain'],
+ inbox['inbox'],
+ actor = actor,
+ followid = inbox['followid'],
+ software = inbox['software']
+ )
+
+ with click.progressbar(
+ config['blocked_software'],
+ label = 'Banned software'.ljust(15),
+ width = 0
+ ) as banned_software:
+
+ for software in banned_software:
+ conn.put_software_ban(
+ software,
+ reason = 'relay' if software in RELAY_SOFTWARE else None
+ )
+
+ with click.progressbar(
+ config['blocked_instances'],
+ label = 'Banned domains'.ljust(15),
+ width = 0
+ ) as banned_software:
+
+ for domain in banned_software:
+ conn.put_domain_ban(domain)
+
+ with click.progressbar(
+ config['whitelist'],
+ label = 'Whitelist'.ljust(15),
+ width = 0
+ ) as whitelist:
+
+ for instance in whitelist:
+ conn.put_domain_whitelist(instance)
+
+ click.echo('Finished converting old config and database :3')
+
+
+@cli.command('edit-config')
+@click.option('--editor', '-e', help = 'Text editor to use')
+@click.pass_context
+def cli_editconfig(ctx: click.Context, editor: str) -> None:
+ 'Edit the config file'
+
+ click.edit(
+ editor = editor,
+ filename = str(ctx.obj.config.path)
+ )
@cli.group('config')
def cli_config() -> None:
- 'Manage the relay config'
+ 'Manage the relay settings stored in the database'
@cli_config.command('list')
-def cli_config_list() -> None:
+@click.pass_context
+def cli_config_list(ctx: click.Context) -> None:
'List the current relay config'
click.echo('Relay Config:')
- for key, value in app.config.items():
- if key not in CONFIG_IGNORE:
- key = f'{key}:'.ljust(20)
- click.echo(f'- {key} {value}')
+ with ctx.obj.database.connection() as conn:
+ for key, value in conn.get_config_all().items():
+ if key not in CONFIG_IGNORE:
+ key = f'{key}:'.ljust(20)
+ click.echo(f'- {key} {value}')
@cli_config.command('set')
@click.argument('key')
@click.argument('value')
-def cli_config_set(key: str, value: Any) -> None:
+@click.pass_context
+def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value'
- app.config[key] = value
- app.config.save()
+ with ctx.obj.database.connection() as conn:
+ new_value = conn.put_config(key, value)
- print(f'{key}: {app.config[key]}')
+ print(f'{key}: {repr(new_value)}')
@cli.group('inbox')
@@ -146,127 +320,145 @@ def cli_inbox() -> None:
@cli_inbox.command('list')
-def cli_inbox_list() -> None:
+@click.pass_context
+def cli_inbox_list(ctx: click.Context) -> None:
'List the connected instances or relays'
click.echo('Connected to the following instances or relays:')
- for inbox in app.database.inboxes:
- click.echo(f'- {inbox}')
+ with ctx.obj.database.connection() as conn:
+ for inbox in conn.execute('SELECT * FROM inboxes'):
+ click.echo(f'- {inbox["inbox"]}')
@cli_inbox.command('follow')
@click.argument('actor')
-def cli_inbox_follow(actor: str) -> None:
+@click.pass_context
+def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)'
- if app.config.is_banned(actor):
- click.echo(f'Error: Refusing to follow banned actor: {actor}')
- return
-
- if not actor.startswith('http'):
- domain = actor
- actor = f'https://{actor}/actor'
-
- else:
- domain = urlparse(actor).hostname
-
- try:
- inbox_data = app.database['relay-list'][domain]
- inbox = inbox_data['inbox']
-
- except KeyError:
- actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
-
- if not actor_data:
- click.echo(f'Failed to fetch actor: {actor}')
+ with ctx.obj.database.connection() as conn:
+ if conn.get_domain_ban(actor):
+ click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
- inbox = actor_data.shared_inbox
+ if (inbox_data := conn.get_inbox(actor)):
+ inbox = inbox_data['inbox']
- message = misc.Message.new_follow(
- host = app.config.host,
+ else:
+ if not actor.startswith('http'):
+ actor = f'https://{actor}/actor'
+
+ actor_data = asyncio.run(http.get(actor, sign_headers = True))
+
+ if not actor_data:
+ click.echo(f'Failed to fetch actor: {actor}')
+ return
+
+ inbox = actor_data.shared_inbox
+
+ message = Message.new_follow(
+ host = ctx.obj.config.domain,
actor = actor
)
- asyncio.run(http.post(app.database, inbox, message))
+ asyncio.run(http.post(inbox, message))
click.echo(f'Sent follow message to actor: {actor}')
@cli_inbox.command('unfollow')
@click.argument('actor')
-def cli_inbox_unfollow(actor: str) -> None:
+@click.pass_context
+def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)'
- if not actor.startswith('http'):
- domain = actor
- actor = f'https://{actor}/actor'
+ inbox_data: Row = None
- else:
- domain = urlparse(actor).hostname
+ with ctx.obj.database.connection() as conn:
+ if conn.get_domain_ban(actor):
+ click.echo(f'Error: Refusing to follow banned actor: {actor}')
+ return
- try:
- inbox_data = app.database['relay-list'][domain]
- inbox = inbox_data['inbox']
- message = misc.Message.new_unfollow(
- host = app.config.host,
- actor = actor,
- follow = inbox_data['followid']
- )
+ if (inbox_data := conn.get_inbox(actor)):
+ inbox = inbox_data['inbox']
+ message = Message.new_unfollow(
+ host = ctx.obj.config.domain,
+ actor = actor,
+ follow = inbox_data['followid']
+ )
- except KeyError:
- actor_data = asyncio.run(http.get(app.database, actor, sign_headers=True))
- inbox = actor_data.shared_inbox
- message = misc.Message.new_unfollow(
- host = app.config.host,
- actor = actor,
- follow = {
- 'type': 'Follow',
- 'object': actor,
- 'actor': f'https://{app.config.host}/actor'
- }
- )
+ else:
+ if not actor.startswith('http'):
+ actor = f'https://{actor}/actor'
- asyncio.run(http.post(app.database, inbox, message))
+ actor_data = asyncio.run(http.get(actor, sign_headers = True))
+ inbox = actor_data.shared_inbox
+ message = Message.new_unfollow(
+ host = ctx.obj.config.domain,
+ actor = actor,
+ follow = {
+ 'type': 'Follow',
+ 'object': actor,
+ 'actor': f'https://{ctx.obj.config.domain}/actor'
+ }
+ )
+
+ asyncio.run(http.post(inbox, message))
click.echo(f'Sent unfollow message to: {actor}')
@cli_inbox.command('add')
@click.argument('inbox')
-def cli_inbox_add(inbox: str) -> None:
+@click.option('--actor', '-a', help = 'Actor url for the inbox')
+@click.option('--followid', '-f', help = 'Url for the follow activity')
+@click.option('--software', '-s', type = click.Choice(SOFTWARE))
+@click.pass_context
+def cli_inbox_add(
+ ctx: click.Context,
+ inbox: str,
+ actor: Optional[str] = None,
+ followid: Optional[str] = None,
+ software: Optional[str] = None) -> None:
'Add an inbox to the database'
if not inbox.startswith('http'):
+ domain = inbox
inbox = f'https://{inbox}/inbox'
- if app.config.is_banned(inbox):
- click.echo(f'Error: Refusing to add banned inbox: {inbox}')
- return
+ else:
+ domain = urlparse(inbox).netloc
- if app.database.get_inbox(inbox):
- click.echo(f'Error: Inbox already in database: {inbox}')
- return
+ if not actor and software:
+ try:
+ actor = ACTOR_FORMATS[software].format(domain = domain)
- app.database.add_inbox(inbox)
- app.database.save()
+ except KeyError:
+ pass
+
+ with ctx.obj.database.connection() as conn:
+ if conn.get_domain_ban(domain):
+ click.echo(f'Refusing to add banned inbox: {inbox}')
+ return
+
+ if conn.get_inbox(inbox):
+ click.echo(f'Error: Inbox already in database: {inbox}')
+ return
+
+ conn.put_inbox(domain, inbox, actor, followid, software)
click.echo(f'Added inbox to the database: {inbox}')
@cli_inbox.command('remove')
@click.argument('inbox')
-def cli_inbox_remove(inbox: str) -> None:
+@click.pass_context
+def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
'Remove an inbox from the database'
- try:
- dbinbox = app.database.get_inbox(inbox, fail=True)
-
- except KeyError:
- click.echo(f'Error: Inbox does not exist: {inbox}')
- return
-
- app.database.del_inbox(dbinbox['domain'])
- app.database.save()
+ with ctx.obj.database.connection() as conn:
+ if not conn.del_inbox(inbox):
+ click.echo(f'Inbox not in database: {inbox}')
+ return
click.echo(f'Removed inbox from the database: {inbox}')
@@ -277,47 +469,76 @@ def cli_instance() -> None:
@cli_instance.command('list')
-def cli_instance_list() -> None:
+@click.pass_context
+def cli_instance_list(ctx: click.Context) -> None:
'List all banned instances'
- click.echo('Banned instances or relays:')
+ click.echo('Banned domains:')
- for domain in app.config.blocked_instances:
- click.echo(f'- {domain}')
+ with ctx.obj.database.connection() as conn:
+ for instance in conn.execute('SELECT * FROM domain_bans'):
+ if instance['reason']:
+ click.echo(f'- {instance["domain"]} ({instance["reason"]})')
+
+ else:
+ click.echo(f'- {instance["domain"]}')
@cli_instance.command('ban')
-@click.argument('target')
-def cli_instance_ban(target: str) -> None:
+@click.argument('domain')
+@click.option('--reason', '-r', help = 'Public note about why the domain is banned')
+@click.option('--note', '-n', help = 'Internal note that will only be seen by admins and mods')
+@click.pass_context
+def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Ban an instance and remove the associated inbox if it exists'
- if target.startswith('http'):
- target = urlparse(target).hostname
+ with ctx.obj.database.connection() as conn:
+ if conn.get_domain_ban(domain):
+ click.echo(f'Domain already banned: {domain}')
+ return
- if app.config.ban_instance(target):
- app.config.save()
-
- if app.database.del_inbox(target):
- app.database.save()
-
- click.echo(f'Banned instance: {target}')
- return
-
- click.echo(f'Instance already banned: {target}')
+ conn.put_domain_ban(domain, reason, note)
+ conn.del_inbox(domain)
+ click.echo(f'Banned instance: {domain}')
@cli_instance.command('unban')
-@click.argument('target')
-def cli_instance_unban(target: str) -> None:
+@click.argument('domain')
+@click.pass_context
+def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance'
- if app.config.unban_instance(target):
- app.config.save()
+ with ctx.obj.database.connection() as conn:
+ if not conn.del_domain_ban(domain):
+ click.echo(f'Instance wasn\'t banned: {domain}')
+ return
- click.echo(f'Unbanned instance: {target}')
- return
+ click.echo(f'Unbanned instance: {domain}')
- click.echo(f'Instance wasn\'t banned: {target}')
+
+@cli_instance.command('update')
+@click.argument('domain')
+@click.option('--reason', '-r')
+@click.option('--note', '-n')
+@click.pass_context
+def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str) -> None:
+ 'Update the public reason or internal note for a domain ban'
+
+ if not (reason or note):
+ ctx.fail('Must pass --reason or --note')
+
+ with ctx.obj.database.connection() as conn:
+ if not (row := conn.update_domain_ban(domain, reason, note)):
+ click.echo(f'Failed to update domain ban: {domain}')
+ return
+
+ click.echo(f'Updated domain ban: {domain}')
+
+ if row['reason']:
+ click.echo(f'- {row["domain"]} ({row["reason"]})')
+
+ else:
+ click.echo(f'- {row["domain"]}')
@cli.group('software')
@@ -326,79 +547,131 @@ def cli_software() -> None:
@cli_software.command('list')
-def cli_software_list() -> None:
+@click.pass_context
+def cli_software_list(ctx: click.Context) -> None:
'List all banned software'
click.echo('Banned software:')
- for software in app.config.blocked_software:
- click.echo(f'- {software}')
+ with ctx.obj.database.connection() as conn:
+ for software in conn.execute('SELECT * FROM software_bans'):
+ if software['reason']:
+ click.echo(f'- {software["name"]} ({software["reason"]})')
+
+ else:
+ click.echo(f'- {software["name"]}')
@cli_software.command('ban')
-@click.option(
- '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
- help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
-)
@click.argument('name')
-def cli_software_ban(name: str, fetch_nodeinfo: bool) -> None:
+@click.option('--reason', '-r')
+@click.option('--note', '-n')
+@click.option(
+ '--fetch-nodeinfo', '-f',
+ is_flag = True,
+ help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
+)
+@click.pass_context
+def cli_software_ban(ctx: click.Context,
+ name: str,
+ reason: str,
+ note: str,
+ fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays'
- if name == 'RELAYS':
- for software in RELAY_SOFTWARE:
- app.config.ban_software(software)
+ with ctx.obj.database.connection() as conn:
+ if name == 'RELAYS':
+ for software in RELAY_SOFTWARE:
+ if conn.get_software_ban(software):
+ click.echo(f'Relay already banned: {software}')
+ continue
- app.config.save()
- click.echo('Banned all relay software')
- return
+ conn.put_software_ban(software, reason or 'relay', note)
- if fetch_nodeinfo:
- nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
+ click.echo('Banned all relay software')
+ return
- if not nodeinfo:
- click.echo(f'Failed to fetch software name from domain: {name}')
+ if fetch_nodeinfo:
+ nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
- name = nodeinfo.sw_name
+ if not nodeinfo:
+ click.echo(f'Failed to fetch software name from domain: {name}')
+ return
+
+ name = nodeinfo.sw_name
+
+ if conn.get_software_ban(name):
+ click.echo(f'Software already banned: {name}')
+ return
+
+ if not conn.put_software_ban(name, reason, note):
+ click.echo(f'Failed to ban software: {name}')
+ return
- if app.config.ban_software(name):
- app.config.save()
click.echo(f'Banned software: {name}')
- return
-
- click.echo(f'Software already banned: {name}')
@cli_software.command('unban')
-@click.option(
- '--fetch-nodeinfo/--ignore-nodeinfo', '-f', 'fetch_nodeinfo', default = False,
- help = 'Treat NAME like a domain and try to fet the software name from nodeinfo'
-)
@click.argument('name')
-def cli_software_unban(name: str, fetch_nodeinfo: bool) -> None:
+@click.option('--reason', '-r')
+@click.option('--note', '-n')
+@click.option(
+ '--fetch-nodeinfo', '-f',
+ is_flag = True,
+ help = 'Treat NAME like a domain and try to fetch the software name from nodeinfo'
+)
+@click.pass_context
+def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays'
- if name == 'RELAYS':
- for software in RELAY_SOFTWARE:
- app.config.unban_software(software)
+ with ctx.obj.database.connection() as conn:
+ if name == 'RELAYS':
+ for software in RELAY_SOFTWARE:
+ if not conn.del_software_ban(software):
+ click.echo(f'Relay was not banned: {software}')
- app.config.save()
- click.echo('Unbanned all relay software')
- return
+ click.echo('Unbanned all relay software')
+ return
- if fetch_nodeinfo:
- nodeinfo = asyncio.run(http.fetch_nodeinfo(app.database, name))
+ if fetch_nodeinfo:
+ nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
- if not nodeinfo:
- click.echo(f'Failed to fetch software name from domain: {name}')
+ if not nodeinfo:
+ click.echo(f'Failed to fetch software name from domain: {name}')
+ return
- name = nodeinfo.sw_name
+ name = nodeinfo.sw_name
+
+ if not conn.del_software_ban(name):
+ click.echo(f'Software was not banned: {name}')
+ return
- if app.config.unban_software(name):
- app.config.save()
click.echo(f'Unbanned software: {name}')
- return
- click.echo(f'Software wasn\'t banned: {name}')
+
+@cli_software.command('update')
+@click.argument('name')
+@click.option('--reason', '-r')
+@click.option('--note', '-n')
+@click.pass_context
+def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -> None:
+ 'Update the public reason or internal note for a software ban'
+
+ if not (reason or note):
+ ctx.fail('Must pass --reason or --note')
+
+ with ctx.obj.database.connection() as conn:
+ if not (row := conn.update_software_ban(name, reason, note)):
+ click.echo(f'Failed to update software ban: {name}')
+ return
+
+ click.echo(f'Updated software ban: {name}')
+
+ if row['reason']:
+ click.echo(f'- {row["name"]} ({row["reason"]})')
+
+ else:
+ click.echo(f'- {row["name"]}')
@cli.group('whitelist')
@@ -407,52 +680,64 @@ def cli_whitelist() -> None:
@cli_whitelist.command('list')
-def cli_whitelist_list() -> None:
+@click.pass_context
+def cli_whitelist_list(ctx: click.Context) -> None:
'List all the instances in the whitelist'
- click.echo('Current whitelisted domains')
+ click.echo('Current whitelisted domains:')
- for domain in app.config.whitelist:
- click.echo(f'- {domain}')
+ with ctx.obj.database.connection() as conn:
+ for domain in conn.execute('SELECT * FROM whitelist'):
+ click.echo(f'- {domain["domain"]}')
@cli_whitelist.command('add')
-@click.argument('instance')
-def cli_whitelist_add(instance: str) -> None:
- 'Add an instance to the whitelist'
+@click.argument('domain')
+@click.pass_context
+def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
+ 'Add a domain to the whitelist'
- if not app.config.add_whitelist(instance):
- click.echo(f'Instance already in the whitelist: {instance}')
- return
+ with ctx.obj.database.connection() as conn:
+ if conn.get_domain_whitelist(domain):
+ click.echo(f'Instance already in the whitelist: {domain}')
+ return
- app.config.save()
- click.echo(f'Instance added to the whitelist: {instance}')
+ conn.put_domain_whitelist(domain)
+ click.echo(f'Instance added to the whitelist: {domain}')
@cli_whitelist.command('remove')
-@click.argument('instance')
-def cli_whitelist_remove(instance: str) -> None:
+@click.argument('domain')
+@click.pass_context
+def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
'Remove an instance from the whitelist'
- if not app.config.del_whitelist(instance):
- click.echo(f'Instance not in the whitelist: {instance}')
- return
+ with ctx.obj.database.connection() as conn:
+ if not conn.del_domain_whitelist(domain):
+ click.echo(f'Domain not in the whitelist: {domain}')
+ return
- app.config.save()
+ if conn.get_config('whitelist-enabled'):
+ if conn.del_inbox(domain):
+ click.echo(f'Removed inbox for domain: {domain}')
- if app.config.whitelist_enabled:
- if app.database.del_inbox(instance):
- app.database.save()
-
- click.echo(f'Removed instance from the whitelist: {instance}')
+ click.echo(f'Removed domain from the whitelist: {domain}')
@cli_whitelist.command('import')
-def cli_whitelist_import() -> None:
+@click.pass_context
+def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist'
- for domain in app.database.hostnames:
- cli_whitelist_add.callback(domain)
+ with ctx.obj.database.connection() as conn:
+ for inbox in conn.execute('SELECT * FROM inboxes').all():
+ if conn.get_domain_whitelist(inbox['domain']):
+ click.echo(f'Domain already in whitelist: {inbox["domain"]}')
+ continue
+
+ conn.put_domain_whitelist(inbox['domain'])
+
+ click.echo('Imported whitelist from inboxes')
def main() -> None:
diff --git a/relay/misc.py b/relay/misc.py
index 7244eaa..2d7117d 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -1,32 +1,27 @@
from __future__ import annotations
import json
+import os
import socket
-import traceback
import typing
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
-from aiohttp.web import Request as AiohttpRequest, Response as AiohttpResponse
+from aiohttp.web import Response as AiohttpResponse
from aiohttp.web_exceptions import HTTPMethodNotAllowed
-from aputils.errors import SignatureFailureError
-from aputils.misc import Digest, HttpDate, Signature
from aputils.message import Message as ApMessage
from functools import cached_property
-from json.decoder import JSONDecodeError
from uuid import uuid4
-from . import logger as logging
-
if typing.TYPE_CHECKING:
from typing import Any, Coroutine, Generator, Optional, Type
- from aputils.signer import Signer
from .application import Application
- from .config import RelayConfig
- from .database import RelayDatabase
+ from .config import Config
+ from .database import Database
from .http_client import HttpClient
+IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = {
'activity': 'application/activity+json',
'html': 'text/html',
@@ -77,91 +72,13 @@ def check_open_port(host: str, port: int) -> bool:
return False
-class DotDict(dict):
- def __init__(self, _data: dict[str, Any], **kwargs: Any):
- dict.__init__(self)
+def get_app() -> Application:
+ from .application import Application # pylint: disable=import-outside-toplevel
- self.update(_data, **kwargs)
+ if not Application.DEFAULT:
+ raise ValueError('No default application set')
-
- def __getattr__(self, key: str) -> str:
- try:
- return self[key]
-
- except KeyError:
- raise AttributeError(
- f'{self.__class__.__name__} object has no attribute {key}'
- ) from None
-
-
- def __setattr__(self, key: str, value: Any) -> None:
- if key.startswith('_'):
- super().__setattr__(key, value)
-
- else:
- self[key] = value
-
-
- def __setitem__(self, key: str, value: Any) -> None:
- if type(value) is dict: # pylint: disable=unidiomatic-typecheck
- value = DotDict(value)
-
- super().__setitem__(key, value)
-
-
- def __delattr__(self, key: str) -> None:
- try:
- dict.__delitem__(self, key)
-
- except KeyError:
- raise AttributeError(
- f'{self.__class__.__name__} object has no attribute {key}'
- ) from None
-
-
- @classmethod
- def new_from_json(cls: Type[DotDict], data: dict[str, Any]) -> DotDict[str, Any]:
- if not data:
- raise JSONDecodeError('Empty body', data, 1)
-
- try:
- return cls(json.loads(data))
-
- except ValueError:
- raise JSONDecodeError('Invalid body', data, 1) from None
-
-
- @classmethod
- def new_from_signature(cls: Type[DotDict], sig: str) -> DotDict[str, Any]:
- data = cls({})
-
- for chunk in sig.strip().split(','):
- key, value = chunk.split('=', 1)
- value = value.strip('\"')
-
- if key == 'headers':
- value = value.split()
-
- data[key.lower()] = value
-
- return data
-
-
- def to_json(self, indent: Optional[int | str] = None) -> str:
- return json.dumps(self, indent=indent)
-
-
- def update(self, _data: dict[str, Any], **kwargs: Any) -> None:
- if isinstance(_data, dict):
- for key, value in _data.items():
- self[key] = value
-
- elif isinstance(_data, (list, tuple, set)):
- for key, value in _data:
- self[key] = value
-
- for key, value in kwargs.items():
- self[key] = value
+ return Application.DEFAULT
class Message(ApMessage):
@@ -181,7 +98,7 @@ class Message(ApMessage):
'followers': f'https://{host}/followers',
'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox',
- 'url': f'https://{host}/inbox',
+ 'url': f'https://{host}/',
'endpoints': {
'sharedInbox': f'https://{host}/inbox'
},
@@ -310,16 +227,6 @@ class Response(AiohttpResponse):
class View(AbstractView):
- def __init__(self, request: AiohttpRequest):
- AbstractView.__init__(self, request)
-
- self.signature: Signature = None
- self.message: Message = None
- self.actor: Message = None
- self.instance: dict[str, str] = None
- self.signer: Signer = None
-
-
def __await__(self) -> Generator[Response]:
method = self.request.method.upper()
@@ -363,94 +270,10 @@ class View(AbstractView):
@property
- def config(self) -> RelayConfig:
+ def config(self) -> Config:
return self.app.config
@property
- def database(self) -> RelayDatabase:
+ def database(self) -> Database:
return self.app.database
-
-
- # todo: move to views.ActorView
- async def get_post_data(self) -> Response | None:
- try:
- self.signature = Signature.new_from_signature(self.request.headers['signature'])
-
- except KeyError:
- logging.verbose('Missing signature header')
- return Response.new_error(400, 'missing signature header', 'json')
-
- try:
- self.message = await self.request.json(loads = Message.parse)
-
- except Exception:
- traceback.print_exc()
- logging.verbose('Failed to parse inbox message')
- return Response.new_error(400, 'failed to parse message', 'json')
-
- if self.message is None:
- logging.verbose('empty message')
- return Response.new_error(400, 'missing message', 'json')
-
- if 'actor' not in self.message:
- logging.verbose('actor not in message')
- return Response.new_error(400, 'no actor in message', 'json')
-
- self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
-
- if self.actor is None:
- # ld signatures aren't handled atm, so just ignore it
- if self.message.type == 'Delete':
- logging.verbose('Instance sent a delete which cannot be handled')
- return Response.new(status=202)
-
- logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
- return Response.new_error(400, 'failed to fetch actor', 'json')
-
- try:
- self.signer = self.actor.signer
-
- except KeyError:
- logging.verbose('Actor missing public key: %s', self.signature.keyid)
- return Response.new_error(400, 'actor missing public key', 'json')
-
- try:
- self.validate_signature(await self.request.read())
-
- except SignatureFailureError as e:
- logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
- return Response.new_error(401, str(e), 'json')
-
- self.instance = self.database.get_inbox(self.actor.inbox)
-
-
- def validate_signature(self, body: bytes) -> None:
- headers = {key.lower(): value for key, value in self.request.headers.items()}
- headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
-
- if (digest := Digest.new_from_digest(headers.get("digest"))):
- if not body:
- raise SignatureFailureError("Missing body for digest verification")
-
- if not digest.validate(body):
- raise SignatureFailureError("Body digest does not match")
-
- if self.signature.algorithm_type == "hs2019":
- if "(created)" not in self.signature.headers:
- raise SignatureFailureError("'(created)' header not used")
-
- current_timestamp = HttpDate.new_utc().timestamp()
-
- if self.signature.created > current_timestamp:
- raise SignatureFailureError("Creation date after current date")
-
- if current_timestamp > self.signature.expires:
- raise SignatureFailureError("Expiration date before current date")
-
- headers["(created)"] = self.signature.created
- headers["(expires)"] = self.signature.expires
-
- # pylint: disable=protected-access
- if not self.signer._validate_signature(headers, self.signature):
- raise SignatureFailureError("Signature does not match")
diff --git a/relay/processors.py b/relay/processors.py
index b9b32bc..4d85607 100644
--- a/relay/processors.py
+++ b/relay/processors.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import tinysql
import typing
from cachetools import LRUCache
@@ -8,7 +9,7 @@ from . import logger as logging
from .misc import Message
if typing.TYPE_CHECKING:
- from .misc import View
+ from .views import ActorView
cache = LRUCache(1024)
@@ -16,128 +17,141 @@ cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
+ # akkoma changed this in a 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
return False
- ## make sure the actor is an application
+ # make sure the actor is an application
if actor.type != 'Application':
return True
return False
-async def handle_relay(view: View) -> None:
+async def handle_relay(view: ActorView) -> None:
if view.message.object_id in cache:
logging.verbose('already relayed %s', view.message.object_id)
return
- message = Message.new_announce(view.config.host, view.message.object_id)
+ message = Message.new_announce(view.config.domain, view.message.object_id)
cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message)
- inboxes = view.database.distill_inboxes(view.message)
-
- for inbox in inboxes:
- view.app.push_message(inbox, message)
+ with view.database.connection() as conn:
+ for inbox in conn.distill_inboxes(view.message):
+ view.app.push_message(inbox, message)
-async def handle_forward(view: View) -> None:
+async def handle_forward(view: ActorView) -> None:
if view.message.id in cache:
logging.verbose('already forwarded %s', view.message.id)
return
- message = Message.new_announce(view.config.host, view.message)
+ message = Message.new_announce(view.config.domain, view.message)
cache[view.message.id] = message.id
logging.debug('>> forward: %s', message)
- inboxes = view.database.distill_inboxes(view.message)
-
- for inbox in inboxes:
- view.app.push_message(inbox, message)
+ with view.database.connection() as conn:
+ for inbox in conn.distill_inboxes(view.message):
+ view.app.push_message(inbox, message)
-async def handle_follow(view: View) -> None:
+async def handle_follow(view: ActorView) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
- ## reject if software used by actor is banned
- if view.config.is_banned_software(software):
+ with view.database.connection() as conn:
+ # reject if software used by actor is banned
+ if view.config.is_banned_software(software):
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = False
+ )
+ )
+
+ logging.verbose(
+ 'Rejected follow from actor for using specific software: actor=%s, software=%s',
+ view.actor.id,
+ software
+ )
+
+ return
+
+ ## reject if the actor is not an instance actor
+ if person_check(view.actor, software):
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = False
+ )
+ )
+
+ logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
+ return
+
+ if conn.get_inbox(view.actor.shared_inbox):
+ data = {'followid': view.message.id}
+ statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
+
+ with conn.query(statement):
+ pass
+
+ else:
+ conn.put_inbox(
+ view.actor.domain,
+ view.actor.shared_inbox,
+ view.actor.id,
+ view.message.id,
+ software
+ )
+
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
- host = view.config.host,
+ host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
- accept = False
+ accept = True
)
)
- return logging.verbose(
- 'Rejected follow from actor for using specific software: actor=%s, software=%s',
- view.actor.id,
- software
- )
-
- ## reject if the actor is not an instance actor
- if person_check(view.actor, software):
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_response(
- host = view.config.host,
- actor = view.actor.id,
- followid = view.message.id,
- accept = False
+ # Are Akkoma and Pleroma the only two that expect a follow back?
+ # Ignoring only Mastodon for now
+ if software != 'mastodon':
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_follow(
+ host = view.config.domain,
+ actor = view.actor.id
+ )
)
- )
-
- logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
- return
-
- view.database.add_inbox(view.actor.shared_inbox, view.message.id, software)
- view.database.save()
-
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_response(
- host = view.config.host,
- actor = view.actor.id,
- followid = view.message.id,
- accept = True
- )
- )
-
- # Are Akkoma and Pleroma the only two that expect a follow back?
- # Ignoring only Mastodon for now
- if software != 'mastodon':
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_follow(
- host = view.config.host,
- actor = view.actor.id
- )
- )
-async def handle_undo(view: View) -> None:
+async def handle_undo(view: ActorView) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
- return await handle_forward(view)
-
- if not view.database.del_inbox(view.actor.domain, view.message.object['id']):
- logging.verbose(
- 'Failed to delete "%s" with follow ID "%s"',
- view.actor.id,
- view.message.object['id']
- )
-
+ await handle_forward(view)
return
- view.database.save()
+ with view.database.connection() as conn:
+ if not conn.del_inbox(view.actor.inbox):
+ logging.verbose(
+ 'Failed to delete "%s" with follow ID "%s"',
+ view.actor.id,
+ view.message.object['id']
+ )
view.app.push_message(
view.actor.shared_inbox,
Message.new_unfollow(
- host = view.config.host,
+ host = view.config.domain,
actor = view.actor.id,
follow = view.message
)
@@ -154,7 +168,7 @@ processors = {
}
-async def run_processor(view: View) -> None:
+async def run_processor(view: ActorView) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@@ -164,12 +178,21 @@ async def run_processor(view: View) -> None:
return
- if view.instance and not view.instance.get('software'):
- nodeinfo = await view.client.fetch_nodeinfo(view.instance['domain'])
+ if view.instance:
+ if not view.instance['software']:
+ if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
+ with view.database.connection() as conn:
+ view.instance = conn.update_inbox(
+ view.instance['inbox'],
+ software = nodeinfo.sw_name
+ )
- if nodeinfo:
- view.instance['software'] = nodeinfo.sw_name
- view.database.save()
+ if not view.instance['actor']:
+ with view.database.connection() as conn:
+ view.instance = conn.update_inbox(
+ view.instance['inbox'],
+ actor = view.actor.id
+ )
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view)
diff --git a/relay/views.py b/relay/views.py
index e1bed64..df06e81 100644
--- a/relay/views.py
+++ b/relay/views.py
@@ -2,8 +2,11 @@ from __future__ import annotations
import asyncio
import subprocess
+import traceback
import typing
+from aputils.errors import SignatureFailureError
+from aputils.misc import Digest, HttpDate, Signature
from aputils.objects import Nodeinfo, Webfinger, WellKnownNodeinfo
from pathlib import Path
@@ -14,6 +17,7 @@ from .processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
+ from aputils.signer import Signer
from typing import Callable
@@ -71,12 +75,16 @@ def register_route(*paths: str) -> Callable:
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
- text = HOME_TEMPLATE.format(
- host = self.config.host,
- note = self.config.note,
- count = len(self.database.hostnames),
- targets = '
'.join(self.database.hostnames)
- )
+ with self.database.connection() as conn:
+ config = conn.get_config_all()
+ inboxes = conn.execute('SELECT * FROM inboxes').all()
+
+ text = HOME_TEMPLATE.format(
+ host = self.config.domain,
+ note = config['note'],
+ count = len(inboxes),
+ targets = '
'.join(inbox['domain'] for inbox in inboxes)
+ )
return Response.new(text, ctype='html')
@@ -84,44 +92,137 @@ class HomeView(View):
@register_route('/actor', '/inbox')
class ActorView(View):
+ def __init__(self, request: Request):
+ View.__init__(self, request)
+
+ self.signature: Signature = None
+ self.message: Message = None
+ self.actor: Message = None
+ self.instance: dict[str, str] = None
+ self.signer: Signer = None
+
+
async def get(self, request: Request) -> Response:
data = Message.new_actor(
- host = self.config.host,
- pubkey = self.database.signer.pubkey
+ host = self.config.domain,
+ pubkey = self.app.signer.pubkey
)
return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response:
- response = await self.get_post_data()
-
- if response is not None:
+ if (response := await self.get_post_data()):
return response
- ## reject if the actor isn't whitelisted while the whiltelist is enabled
- if self.config.whitelist_enabled and not self.config.is_whitelisted(self.actor.domain):
- logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
- return Response.new_error(403, 'access denied', 'json')
+ with self.database.connection() as conn:
+ self.instance = conn.get_inbox(self.actor.inbox)
+ config = conn.get_config_all()
- ## reject if actor is banned
- if self.config.is_banned(self.actor.domain):
- logging.verbose('Ignored request from banned actor: %s', self.actor.id)
- return Response.new_error(403, 'access denied', 'json')
+ ## reject if the actor isn't whitelisted while the whiltelist is enabled
+ if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
+ logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
+ return Response.new_error(403, 'access denied', 'json')
- ## reject if activity type isn't 'Follow' and the actor isn't following
- if self.message.type != 'Follow' and not self.database.get_inbox(self.actor.domain):
- logging.verbose(
- 'Rejected actor for trying to post while not following: %s',
- self.actor.id
- )
+ ## reject if actor is banned
+ if conn.get_domain_ban(self.actor.domain):
+ logging.verbose('Ignored request from banned actor: %s', self.actor.id)
+ return Response.new_error(403, 'access denied', 'json')
- return Response.new_error(401, 'access denied', 'json')
+ ## reject if activity type isn't 'Follow' and the actor isn't following
+ if self.message.type != 'Follow' and not self.instance:
+ logging.verbose(
+ 'Rejected actor for trying to post while not following: %s',
+ self.actor.id
+ )
- logging.debug('>> payload %s', self.message.to_json(4))
+ return Response.new_error(401, 'access denied', 'json')
- asyncio.ensure_future(run_processor(self))
- return Response.new(status = 202)
+ logging.debug('>> payload %s', self.message.to_json(4))
+
+ asyncio.ensure_future(run_processor(self))
+ return Response.new(status = 202)
+
+
+ async def get_post_data(self) -> Response | None:
+ try:
+ self.signature = Signature.new_from_signature(self.request.headers['signature'])
+
+ except KeyError:
+ logging.verbose('Missing signature header')
+ return Response.new_error(400, 'missing signature header', 'json')
+
+ try:
+ self.message = await self.request.json(loads = Message.parse)
+
+ except Exception:
+ traceback.print_exc()
+ logging.verbose('Failed to parse inbox message')
+ return Response.new_error(400, 'failed to parse message', 'json')
+
+ if self.message is None:
+ logging.verbose('empty message')
+ return Response.new_error(400, 'missing message', 'json')
+
+ if 'actor' not in self.message:
+ logging.verbose('actor not in message')
+ return Response.new_error(400, 'no actor in message', 'json')
+
+ self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
+
+ if self.actor is None:
+ # ld signatures aren't handled atm, so just ignore it
+ if self.message.type == 'Delete':
+ logging.verbose('Instance sent a delete which cannot be handled')
+ return Response.new(status=202)
+
+ logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
+ return Response.new_error(400, 'failed to fetch actor', 'json')
+
+ try:
+ self.signer = self.actor.signer
+
+ except KeyError:
+ logging.verbose('Actor missing public key: %s', self.signature.keyid)
+ return Response.new_error(400, 'actor missing public key', 'json')
+
+ try:
+ self.validate_signature(await self.request.read())
+
+ except SignatureFailureError as e:
+ logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
+ return Response.new_error(401, str(e), 'json')
+
+
+ def validate_signature(self, body: bytes) -> None:
+ headers = {key.lower(): value for key, value in self.request.headers.items()}
+ headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
+
+ if (digest := Digest.new_from_digest(headers.get("digest"))):
+ if not body:
+ raise SignatureFailureError("Missing body for digest verification")
+
+ if not digest.validate(body):
+ raise SignatureFailureError("Body digest does not match")
+
+ if self.signature.algorithm_type == "hs2019":
+ if "(created)" not in self.signature.headers:
+ raise SignatureFailureError("'(created)' header not used")
+
+ current_timestamp = HttpDate.new_utc().timestamp()
+
+ if self.signature.created > current_timestamp:
+ raise SignatureFailureError("Creation date after current date")
+
+ if current_timestamp > self.signature.expires:
+ raise SignatureFailureError("Expiration date before current date")
+
+ headers["(created)"] = self.signature.created
+ headers["(expires)"] = self.signature.expires
+
+ # pylint: disable=protected-access
+ if not self.signer._validate_signature(headers, self.signature):
+ raise SignatureFailureError("Signature does not match")
@register_route('/.well-known/webfinger')
@@ -133,12 +234,12 @@ class WebfingerView(View):
except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json')
- if subject != f'acct:relay@{self.config.host}':
+ if subject != f'acct:relay@{self.config.domain}':
return Response.new_error(404, 'user not found', 'json')
data = Webfinger.new(
handle = 'relay',
- domain = self.config.host,
+ domain = self.config.domain,
actor = self.config.actor
)
@@ -148,14 +249,17 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response:
- data = {
- 'name': 'activityrelay',
- 'version': VERSION,
- 'protocols': ['activitypub'],
- 'open_regs': not self.config.whitelist_enabled,
- 'users': 1,
- 'metadata': {'peers': self.database.hostnames}
- }
+ with self.database.connection() as conn:
+ inboxes = conn.execute('SELECT * FROM inboxes').all()
+
+ data = {
+ 'name': 'activityrelay',
+ 'version': VERSION,
+ 'protocols': ['activitypub'],
+ 'open_regs': not conn.get_config('whitelist-enabled'),
+ 'users': 1,
+ 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
+ }
if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
@@ -166,5 +270,5 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response:
- data = WellKnownNodeinfo.new_template(self.config.host)
+ data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json')
diff --git a/requirements.txt b/requirements.txt
index 43bf45a..cc6fc4d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0
click>=8.1.2
pyyaml>=6.0
+tinysql@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.1.tar.gz
diff --git a/setup.cfg b/setup.cfg
index 8b807e3..65874ff 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,26 +10,33 @@ license_file = LICENSE
classifiers =
Environment :: Console
License :: OSI Approved :: AGPLv3 License
- Programming Language :: Python :: 3.6
- Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Programming Language :: Python :: 3.12
project_urls =
Source = https://git.pleroma.social/pleroma/relay
Tracker = https://git.pleroma.social/pleroma/relay/-/issues
[options]
zip_safe = False
-packages = find:
+packages =
+ relay
+ relay.database
+include_package_data = true
install_requires = file: requirements.txt
python_requires = >=3.8
[options.extras_require]
dev =
- flake8 = 3.1.0
- pyinstaller = 6.3.0
- pylint = 3.0
+ flake8 == 3.1.0
+ pyinstaller == 6.3.0
+ pylint == 3.0
+
+[options.package_data]
+relay =
+ data/statements.sql
[options.entry_points]
console_scripts =