Compare commits

..

1 commit

Author SHA1 Message Date
Izalia Mae f4b30e3c6c Merge branch 'sql' into 'master'
Draft: switch database backend to sql

See merge request pleroma/relay!53
2024-01-22 11:50:58 +00:00
18 changed files with 176 additions and 199 deletions

View file

@ -1,3 +0,0 @@
flake8 == 7.0.0
pyinstaller == 6.3.0
pylint == 3.0

View file

@ -3,8 +3,11 @@
There are a number of commands to manage your relay's database and config. You can add `--help` to There are a number of commands to manage your relay's database and config. You can add `--help` to
any category or command to get help on that specific option (ex. `activityrelay inbox --help`). any category or command to get help on that specific option (ex. `activityrelay inbox --help`).
Note: `activityrelay` is only available via pip or pipx if `~/.local/bin` is in `$PATH`. If not, Note: Unless specified, it is recommended to run any commands while the relay is shutdown.
use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if installed via pipx.
Note 2: `activityrelay` is only available via pip or pipx if `~/.local/bin` is in `$PATH`. If it
isn't, use `python3 -m relay` if installed via pip or `~/.local/bin/activityrelay` if installed
via pipx
## Run ## Run
@ -21,22 +24,6 @@ Run the setup wizard to configure your relay.
activityrelay setup activityrelay setup
## Convert
Convert the old config and jsonld to the new config and SQL backend. If the old config filename is
not specified, the config will get backed up as `relay.backup.yaml` before converting.
activityrelay convert --old-config relaycfg.yaml
## Edit Config
Open the config file in a text editor. If an editor is not specified with `--editor`, the default
editor will be used.
activityrelay edit-config --editor micro
## Config ## Config
Manage the relay config Manage the relay config
@ -133,7 +120,7 @@ Remove a domain from the whitelist.
### Import ### Import
Add all current inboxes to the whitelist. Add all current inboxes to the whitelist
activityrelay whitelist import activityrelay whitelist import
@ -145,15 +132,15 @@ Manage the instance ban list.
### List ### List
List the currently banned instances. List the currently banned instances
activityrelay instance list activityrelay instance list
### Ban ### Ban
Add an instance to the ban list. If the instance is currently subscribed, it will be removed from Add an instance to the ban list. If the instance is currently subscribed, remove it from the
the inbox list. database.
activityrelay instance ban <domain> activityrelay instance ban <domain>
@ -165,17 +152,10 @@ Remove an instance from the ban list.
activityrelay instance unban <domain> activityrelay instance unban <domain>
### Update
Update the ban reason or note for an instance ban.
activityrelay instance update bad.example.com --reason "the baddest reason"
## Software ## Software
Manage the software ban list. To get the correct name, check the software's nodeinfo endpoint. Manage the software ban list. To get the correct name, check the software's nodeinfo endpoint.
You can find it at `nodeinfo['software']['name']`. You can find it at nodeinfo\['software']\['name'].
### List ### List
@ -206,12 +186,4 @@ name via nodeinfo.
If the name is `RELAYS` (case-sensitive), remove all known relay software names from the list. If the name is `RELAYS` (case-sensitive), remove all known relay software names from the list.
activityrelay software unban [-f/--fetch-nodeinfo] <name, domain, or RELAYS> activityrelay unban [-f/--fetch-nodeinfo] <name, domain, or RELAYS>
### Update
Update the ban reason or note for a software ban. Either `--reason` and/or `--note` must be
specified.
activityrelay software update relay.example.com --reason "begone relay"

View file

@ -2,23 +2,41 @@
## General ## General
### Domain ### DB
Hostname the relay will be hosted on. The path to the database. It contains the relay actor private key and all subscribed
instances. If the path is not absolute, it is relative to the working directory.
domain: relay.example.com db: relay.jsonld
### Listener ### Listener
The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc) The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc)
is running on the same host, it is recommended to change `listen` to `localhost` if the reverse is running on the same host, it is recommended to change `listen` to `localhost`
proxy is on the same host.
listen: 0.0.0.0 listen: 0.0.0.0
port: 8080 port: 8080
### Note
A small blurb to describe your relay instance. This will show up on the relay's home page.
note: "Make a note about your instance here."
### Post Limit
The maximum number of messages to send out at once. For each incoming message, a message will be
sent out to every subscribed instance minus the instance which sent the message. This limit
is to prevent too many outgoing connections from being made, so adjust if necessary.
Note: If the `workers` option is set to anything above 0, this limit will be per worker.
push_limit: 512
### Push Workers ### Push Workers
The relay can be configured to use threads to push messages out. For smaller relays, this isn't The relay can be configured to use threads to push messages out. For smaller relays, this isn't
@ -28,59 +46,60 @@ threads.
workers: 0 workers: 0
### Database type ### JSON GET cache limit
SQL database backend to use. Valid values are `sqlite` or `postgres`. JSON objects (actors, nodeinfo, etc) will get cached when fetched. This will set the max number of
objects to keep in the cache.
database_type: sqlite json_cache: 1024
### Sqlite File Path ## AP
Path to the sqlite database file. If the path is not absolute, it is relative to the config file. Various ActivityPub-related settings
directory.
sqlite_path: relay.jsonld
## Postgresql
In order to use the Postgresql backend, the user and database need to be created first.
sudo -u postgres psql -c "CREATE USER activityrelay"
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
### Database Name
Name of the database to use.
name: activityrelay
### Host ### Host
Hostname, IP address, or unix socket the server is hosted on. The domain your relay will use to identify itself.
host: /var/run/postgresql host: relay.example.com
### Port ### Whitelist Enabled
Port number the server is listening on. If set to `true`, only instances in the whitelist can follow the relay. Any subscribed instances
not in the whitelist will be removed from the inbox list on startup.
port: 5432 whitelist_enabled: false
### Username ### Whitelist
User to use when logging into the server. A list of domains of instances which are allowed to subscribe to your relay.
user: null whitelist:
- bad-instance.example.com
- another-bad-instance.example.com
### Password ### Blocked Instances
Password for the specified user. A list of instances which are unable to follow the instance. If a subscribed instance is added to
the block list, it will be removed from the inbox list on startup.
pass: null blocked_instances:
- bad-instance.example.com
- another-bad-instance.example.com
### Blocked Software
A list of ActivityPub software which cannot follow your relay. This list is empty by default, but
setting this to the below list will block all other relays and prevent relay chains
blocked_software:
- activityrelay
- aoderelay
- social.seattle.wa.us-relay
- unciarelay

View file

@ -28,14 +28,14 @@ server {
# logging, mostly for debug purposes. Disable if you wish. # logging, mostly for debug purposes. Disable if you wish.
access_log /srv/www/relay.<yourdomain>/logs/access.log; access_log /srv/www/relay.<yourdomain>/logs/access.log;
error_log /srv/www/relay.<yourdomain>/logs/error.log; error_log /srv/www/relay.<yourdomain>/logs/error.log;
ssl_protocols TLSv1.2; ssl_protocols TLSv1.2;
ssl_ciphers EECDH+AESGCM:EECDH+AES; ssl_ciphers EECDH+AESGCM:EECDH+AES;
ssl_ecdh_curve secp384r1; ssl_ecdh_curve secp384r1;
ssl_prefer_server_ciphers on; ssl_prefer_server_ciphers on;
ssl_session_cache shared:SSL:10m; ssl_session_cache shared:SSL:10m;
# ssl certs. # ssl certs.
ssl_certificate /usr/local/etc/letsencrypt/live/relay.<yourdomain>/fullchain.pem; ssl_certificate /usr/local/etc/letsencrypt/live/relay.<yourdomain>/fullchain.pem;
ssl_certificate_key /usr/local/etc/letsencrypt/live/relay.<yourdomain>/privkey.pem; ssl_certificate_key /usr/local/etc/letsencrypt/live/relay.<yourdomain>/privkey.pem;
@ -48,7 +48,7 @@ server {
# sts, change if you care. # sts, change if you care.
# add_header Strict-Transport-Security "max-age=31536000; includeSubDomains"; # add_header Strict-Transport-Security "max-age=31536000; includeSubDomains";
# uncomment this to use a static page in your webroot for your root page. # uncomment this to use a static page in your webroot for your root page.
#location = / { #location = / {
# index index.html; # index index.html;

View file

@ -3,7 +3,7 @@ Description=ActivityPub Relay
[Service] [Service]
WorkingDirectory=/home/relay/relay WorkingDirectory=/home/relay/relay
ExecStart=/usr/bin/python3 -m relay run ExecStart=/usr/bin/python3 -m relay
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target

View file

@ -6,23 +6,6 @@ build-backend = 'setuptools.build_meta'
[tool.pylint.main] [tool.pylint.main]
jobs = 0 jobs = 0
persistent = true persistent = true
load-plugins = [
"pylint.extensions.code_style",
"pylint.extensions.comparison_placement",
"pylint.extensions.confusing_elif",
"pylint.extensions.for_any_all",
"pylint.extensions.consider_ternary_expression",
"pylint.extensions.bad_builtin",
"pylint.extensions.dict_init_mutate",
"pylint.extensions.check_elif",
"pylint.extensions.empty_comment",
"pylint.extensions.private_import",
"pylint.extensions.redefined_variable_type",
"pylint.extensions.no_self_use",
"pylint.extensions.overlapping_exceptions",
"pylint.extensions.set_membership",
"pylint.extensions.typing"
]
[tool.pylint.design] [tool.pylint.design]
@ -39,7 +22,6 @@ single-line-if-stmt = true
[tool.pylint.messages_control] [tool.pylint.messages_control]
disable = [ disable = [
"fixme",
"broad-exception-caught", "broad-exception-caught",
"cyclic-import", "cyclic-import",
"global-statement", "global-statement",
@ -49,8 +31,7 @@ disable = [
"too-many-public-methods", "too-many-public-methods",
"too-many-return-statements", "too-many-return-statements",
"wrong-import-order", "wrong-import-order",
"wrong-import-position",
"missing-function-docstring", "missing-function-docstring",
"missing-class-docstring", "missing-class-docstring"
"consider-using-namedtuple-or-dataclass",
"confusing-consecutive-elif"
] ]

View file

@ -13,8 +13,7 @@ from . import logger as logging
from .misc import Message, boolean from .misc import Message, boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator from typing import Any, Iterator, Optional
from typing import Any
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@ -31,10 +30,10 @@ class RelayConfig(dict):
def __setitem__(self, key: str, value: Any) -> None: def __setitem__(self, key: str, value: Any) -> None:
if key in {'blocked_instances', 'blocked_software', 'whitelist'}: if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple)) assert isinstance(value, (list, set, tuple))
elif key in {'port', 'workers', 'json_cache', 'timeout'}: elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int): if not isinstance(value, int):
value = int(value) value = int(value)
@ -111,7 +110,7 @@ class RelayConfig(dict):
return return
for key, value in config.items(): for key, value in config.items():
if key == 'ap': if key in ['ap']:
for k, v in value.items(): for k, v in value.items():
if k not in self: if k not in self:
continue continue
@ -191,7 +190,7 @@ class RelayDatabase(dict):
json.dump(self, fd, indent=4) json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None: def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None:
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).hostname domain = urlparse(domain).hostname
@ -206,13 +205,14 @@ class RelayDatabase(dict):
def add_inbox(self, def add_inbox(self,
inbox: str, inbox: str,
followid: str | None = None, followid: Optional[str] = None,
software: str | None = None) -> dict[str, str]: software: Optional[str] = None) -> dict[str, str]:
assert inbox.startswith('https'), 'Inbox must be a url' assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname domain = urlparse(inbox).hostname
instance = self.get_inbox(domain)
if (instance := self.get_inbox(domain)): if instance:
if followid: if followid:
instance['followid'] = followid instance['followid'] = followid
@ -234,10 +234,12 @@ class RelayDatabase(dict):
def del_inbox(self, def del_inbox(self,
domain: str, domain: str,
followid: str = None, followid: Optional[str] = None,
fail: bool = False) -> bool: fail: Optional[bool] = False) -> bool:
if not (data := self.get_inbox(domain, fail=False)): data = self.get_inbox(domain, fail=False)
if not data:
if fail: if fail:
raise KeyError(domain) raise KeyError(domain)

View file

@ -10,7 +10,7 @@ from pathlib import Path
from .misc import IS_DOCKER from .misc import IS_DOCKER
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any from typing import Any, Optional
DEFAULTS: dict[str, Any] = { DEFAULTS: dict[str, Any] = {
@ -32,7 +32,7 @@ if IS_DOCKER:
class Config: class Config:
def __init__(self, path: str, load: bool = False): def __init__(self, path: str, load: Optional[bool] = False):
self.path = Path(path).expanduser().resolve() self.path = Path(path).expanduser().resolve()
self.listen = None self.listen = None
@ -151,7 +151,7 @@ class Config:
if key not in DEFAULTS: if key not in DEFAULTS:
raise KeyError(key) raise KeyError(key)
if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int): if key in ('port', 'pg_port', 'workers') and not isinstance(value, int):
value = int(value) value = int(value)
setattr(self, key, value) setattr(self, key, value)

View file

@ -12,10 +12,11 @@ from .schema import VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Optional
from .config import Config from .config import Config
def get_database(config: Config, migrate: bool = True) -> tinysql.Database: def get_database(config: Config, migrate: Optional[bool] = True) -> tinysql.Database:
if config.db_type == "sqlite": if config.db_type == "sqlite":
db = tinysql.Database.sqlite(config.sqlite_path, connection_class = Connection) db = tinysql.Database.sqlite(config.sqlite_path, connection_class = Connection)
@ -40,7 +41,9 @@ def get_database(config: Config, migrate: bool = True) -> tinysql.Database:
migrate_0(conn) migrate_0(conn)
return db return db
if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'): schema_ver = conn.get_config('schema-version')
if schema_ver < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver) logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS: for ver, func in VERSIONS:

View file

@ -6,8 +6,7 @@ from .. import logger as logging
from ..misc import boolean from ..misc import boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable from typing import Any, Callable
from typing import Any
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = { CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {

View file

@ -12,9 +12,8 @@ from .. import logger as logging
from ..misc import get_app from ..misc import get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator
from tinysql import Cursor, Row from tinysql import Cursor, Row
from typing import Any from typing import Any, Iterator, Optional
from .application import Application from .application import Application
from ..misc import Message from ..misc import Message
@ -44,7 +43,7 @@ class Connection(tinysql.Connection):
yield inbox['inbox'] yield inbox['inbox']
def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor: def exec_statement(self, name: str, params: Optional[dict[str, Any]] = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params) return self.execute(self.database.prepared_statements[name], params)
@ -111,9 +110,9 @@ class Connection(tinysql.Connection):
def put_inbox(self, def put_inbox(self,
domain: str, domain: str,
inbox: str, inbox: str,
actor: str | None = None, actor: Optional[str] = None,
followid: str | None = None, followid: Optional[str] = None,
software: str | None = None) -> Row: software: Optional[str] = None) -> Row:
params = { params = {
'domain': domain, 'domain': domain,
@ -130,9 +129,9 @@ class Connection(tinysql.Connection):
def update_inbox(self, def update_inbox(self,
inbox: str, inbox: str,
actor: str | None = None, actor: Optional[str] = None,
followid: str | None = None, followid: Optional[str] = None,
software: str | None = None) -> Row: software: Optional[str] = None) -> Row:
if not (actor or followid or software): if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"') raise ValueError('Missing "actor", "followid", and/or "software"')
@ -172,8 +171,8 @@ class Connection(tinysql.Connection):
def put_domain_ban(self, def put_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: Optional[str] = None,
note: str | None = None) -> Row: note: Optional[str] = None) -> Row:
params = { params = {
'domain': domain, 'domain': domain,
@ -188,8 +187,8 @@ class Connection(tinysql.Connection):
def update_domain_ban(self, def update_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: Optional[str] = None,
note: str | None = None) -> tinysql.Row: note: Optional[str] = None) -> tinysql.Row:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -226,8 +225,8 @@ class Connection(tinysql.Connection):
def put_software_ban(self, def put_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: Optional[str] = None,
note: str | None = None) -> Row: note: Optional[str] = None) -> Row:
params = { params = {
'name': name, 'name': name,
@ -242,8 +241,8 @@ class Connection(tinysql.Connection):
def update_software_ban(self, def update_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: Optional[str] = None,
note: str | None = None) -> tinysql.Row: note: Optional[str] = None) -> tinysql.Row:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')

View file

@ -7,7 +7,7 @@ from tinysql import Column, Connection, Table
from .config import get_default_value from .config import get_default_value
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable from typing import Callable
VERSIONS: list[Callable] = [] VERSIONS: list[Callable] = []
@ -33,7 +33,7 @@ TABLES: list[Table] = [
Column('created', 'timestamp') Column('created', 'timestamp')
), ),
Table( Table(
'domain_bans', 'instance_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True), Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'), Column('reason', 'text'),
Column('note', 'text'), Column('note', 'text'),

View file

@ -16,7 +16,7 @@ from . import logger as logging
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any from typing import Any, Callable, Optional
HEADERS = { HEADERS = {
@ -26,7 +26,11 @@ HEADERS = {
class HttpClient: class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024): def __init__(self,
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
self.cache = LRUCache(cache_size) self.cache = LRUCache(cache_size)
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
@ -73,9 +77,9 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches async def get(self, # pylint: disable=too-many-branches
url: str, url: str,
sign_headers: bool = False, sign_headers: Optional[bool] = False,
loads: callable | None = None, loads: Optional[Callable] = None,
force: bool = False) -> Message | dict | None: force: Optional[bool] = False) -> Message | dict | None:
await self.open() await self.open()
@ -147,13 +151,11 @@ class HttpClient:
instance = conn.get_inbox(url) instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now ## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}: if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019' algorithm = 'hs2019'
else: else:
algorithm = 'original' algorithm = 'original'
# pylint: enable=consider-ternary-expression
headers = {'Content-Type': 'application/activity+json'} headers = {'Content-Type': 'application/activity+json'}
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm)) headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
@ -193,7 +195,7 @@ class HttpClient:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None return None
for version in ('20', '21'): for version in ['20', '21']:
try: try:
nodeinfo_url = wk_nodeinfo.get_url(version) nodeinfo_url = wk_nodeinfo.get_url(version)

View file

@ -8,8 +8,7 @@ from enum import IntEnum
from pathlib import Path from pathlib import Path
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable from typing import Any, Callable, Type
from typing import Any
class LogLevel(IntEnum): class LogLevel(IntEnum):
@ -26,7 +25,7 @@ class LogLevel(IntEnum):
@classmethod @classmethod
def parse(cls: type[IntEnum], data: object) -> IntEnum: def parse(cls: Type[IntEnum], data: object) -> IntEnum:
if isinstance(data, cls): if isinstance(data, cls):
return data return data

View file

@ -22,7 +22,7 @@ from .misc import IS_DOCKER, Message, check_open_port
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Row from tinysql import Row
from typing import Any from typing import Any, Optional
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation # pylint: disable=unsubscriptable-object,unsupported-assignment-operation
@ -69,7 +69,7 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup') @cli.command('setup')
@click.pass_context @click.pass_context
def cli_setup(ctx: click.Context) -> None: def cli_setup(ctx: click.Context) -> None:
'Generate a new config and create the database' 'Generate a new config'
while True: while True:
ctx.obj.config.domain = click.prompt( ctx.obj.config.domain = click.prompt(
@ -184,12 +184,12 @@ def cli_run(ctx: click.Context) -> None:
@cli.command('convert') @cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from') @click.option('--old-config', '-o', help = 'Path to the new config file')
@click.pass_context @click.pass_context
def cli_convert(ctx: click.Context, old_config: str) -> None: def cli_convert(ctx: click.Context, old_config: str) -> None:
'Convert an old config and jsonld database to the new format.' 'Convert an old config and jsonld database to the new format.'
old_config = Path(old_config).expanduser().resolve() if old_config else ctx.obj.config.path old_config = Path(old_config).expanduser().resolve()
backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml') backup = ctx.obj.config.path.parent.joinpath(f'{ctx.obj.config.path.stem}.backup.yaml')
if str(old_config) == str(ctx.obj.config.path) and not backup.exists(): if str(old_config) == str(ctx.obj.config.path) and not backup.exists():
@ -206,8 +206,6 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.set('port', config['port']) ctx.obj.config.set('port', config['port'])
ctx.obj.config.set('workers', config['workers']) ctx.obj.config.set('workers', config['workers'])
ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3')) ctx.obj.config.set('sq_path', config['db'].replace('jsonld', 'sqlite3'))
ctx.obj.config.set('domain', config['host'])
ctx.obj.config.save()
with get_database(ctx.obj.config) as db: with get_database(ctx.obj.config) as db:
with db.connection() as conn: with db.connection() as conn:
@ -222,7 +220,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
) as inboxes: ) as inboxes:
for inbox in inboxes: for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}: if inbox['software'] in ('akkoma', 'pleroma'):
actor = f'https://{inbox["domain"]}/relay' actor = f'https://{inbox["domain"]}/relay'
elif inbox['software'] == 'mastodon': elif inbox['software'] == 'mastodon':
@ -351,7 +349,9 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
if not actor.startswith('http'): if not actor.startswith('http'):
actor = f'https://{actor}/actor' actor = f'https://{actor}/actor'
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))): actor_data = asyncio.run(http.get(actor, sign_headers = True))
if not actor_data:
click.echo(f'Failed to fetch actor: {actor}') click.echo(f'Failed to fetch actor: {actor}')
return return
@ -411,17 +411,14 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
@click.argument('inbox') @click.argument('inbox')
@click.option('--actor', '-a', help = 'Actor url for the inbox') @click.option('--actor', '-a', help = 'Actor url for the inbox')
@click.option('--followid', '-f', help = 'Url for the follow activity') @click.option('--followid', '-f', help = 'Url for the follow activity')
@click.option('--software', '-s', @click.option('--software', '-s', type = click.Choice(SOFTWARE))
type = click.Choice(SOFTWARE),
help = 'Nodeinfo software name of the instance'
) # noqa: E124
@click.pass_context @click.pass_context
def cli_inbox_add( def cli_inbox_add(
ctx: click.Context, ctx: click.Context,
inbox: str, inbox: str,
actor: str | None = None, actor: Optional[str] = None,
followid: str | None = None, followid: Optional[str] = None,
software: str | None = None) -> None: software: Optional[str] = None) -> None:
'Add an inbox to the database' 'Add an inbox to the database'
if not inbox.startswith('http'): if not inbox.startswith('http'):
@ -431,10 +428,6 @@ def cli_inbox_add(
else: else:
domain = urlparse(inbox).netloc domain = urlparse(inbox).netloc
if not software:
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(domain))):
software = nodeinfo.sw_name
if not actor and software: if not actor and software:
try: try:
actor = ACTOR_FORMATS[software].format(domain = domain) actor = ACTOR_FORMATS[software].format(domain = domain)
@ -599,7 +592,9 @@ def cli_software_ban(ctx: click.Context,
return return
if fetch_nodeinfo: if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))): nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}') click.echo(f'Failed to fetch software name from domain: {name}')
return return
@ -639,7 +634,9 @@ def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> N
return return
if fetch_nodeinfo: if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))): nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}') click.echo(f'Failed to fetch software name from domain: {name}')
return return

View file

@ -14,8 +14,7 @@ from functools import cached_property
from uuid import uuid4 from uuid import uuid4
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator from typing import Any, Coroutine, Generator, Optional, Type
from typing import Any
from .application import Application from .application import Application
from .config import Config from .config import Config
from .database import Database from .database import Database
@ -38,10 +37,10 @@ NODEINFO_NS = {
def boolean(value: Any) -> bool: def boolean(value: Any) -> bool:
if isinstance(value, str): if isinstance(value, str):
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}: if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']:
return True return True
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}: if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']:
return False return False
raise TypeError(f'Cannot parse string "{value}" as a boolean') raise TypeError(f'Cannot parse string "{value}" as a boolean')
@ -84,10 +83,10 @@ def get_app() -> Application:
class Message(ApMessage): class Message(ApMessage):
@classmethod @classmethod
def new_actor(cls: type[Message], # pylint: disable=arguments-differ def new_actor(cls: Type[Message], # pylint: disable=arguments-differ
host: str, host: str,
pubkey: str, pubkey: str,
description: str | None = None) -> Message: description: Optional[str] = None) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
@ -112,7 +111,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_announce(cls: type[Message], host: str, obj: str) -> Message: def new_announce(cls: Type[Message], host: str, obj: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -124,7 +123,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_follow(cls: type[Message], host: str, actor: str) -> Message: def new_follow(cls: Type[Message], host: str, actor: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow', 'type': 'Follow',
@ -136,7 +135,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message: def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -148,7 +147,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_response(cls: type[Message], def new_response(cls: Type[Message],
host: str, host: str,
actor: str, actor: str,
followid: str, followid: str,
@ -181,11 +180,11 @@ class Message(ApMessage):
class Response(AiohttpResponse): class Response(AiohttpResponse):
@classmethod @classmethod
def new(cls: type[Response], def new(cls: Type[Response],
body: str | bytes | dict = '', body: Optional[str | bytes | dict] = '',
status: int = 200, status: Optional[int] = 200,
headers: dict[str, str] | None = None, headers: Optional[dict[str, str]] = None,
ctype: str = 'text') -> Response: ctype: Optional[str] = 'text') -> Response:
kwargs = { kwargs = {
'status': status, 'status': status,
@ -206,7 +205,7 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_error(cls: type[Response], def new_error(cls: Type[Response],
status: int, status: int,
body: str | bytes | dict, body: str | bytes | dict,
ctype: str = 'text') -> Response: ctype: str = 'text') -> Response:
@ -229,10 +228,12 @@ class Response(AiohttpResponse):
class View(AbstractView): class View(AbstractView):
def __await__(self) -> Generator[Response]: def __await__(self) -> Generator[Response]:
if (self.request.method) not in METHODS: method = self.request.method.upper()
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
if not (handler := self.handlers.get(self.request.method)): if method not in METHODS:
raise HTTPMethodNotAllowed(method, self.allowed_methods)
if not (handler := self.handlers.get(method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__() return handler(self.request, **self.request.match_info).__await__()

View file

@ -18,7 +18,7 @@ from .processors import run_processor
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from aputils.signer import Signer from aputils.signer import Signer
from collections.abc import Callable from typing import Callable
VIEWS = [] VIEWS = []

View file

@ -29,7 +29,10 @@ install_requires = file: requirements.txt
python_requires = >=3.8 python_requires = >=3.8
[options.extras_require] [options.extras_require]
dev = file: dev-requirements.txt dev =
flake8 == 3.1.0
pyinstaller == 6.3.0
pylint == 3.0
[options.package_data] [options.package_data]
relay = relay =
@ -41,4 +44,7 @@ console_scripts =
[flake8] [flake8]
select = F401 extend-ignore = ANN101,ANN204,E128,E251,E261,E266,E301,E303,W191
extend-exclude = docs, test*.py
max-line-length = 100
indent-size = 4