Compare commits

..

No commits in common. "7a9d346642263623748bc3a37490df040af657f6" and "57d7d257438224297e69e8383d7eb547fb2ab4b7" have entirely different histories.

14 changed files with 100 additions and 112 deletions

View file

@ -1,3 +0,0 @@
flake8 == 7.0.0
pyinstaller == 6.3.0
pylint == 3.0

View file

@ -6,23 +6,6 @@ build-backend = 'setuptools.build_meta'
[tool.pylint.main]
jobs = 0
persistent = true
load-plugins = [
"pylint.extensions.code_style",
"pylint.extensions.comparison_placement",
"pylint.extensions.confusing_elif",
"pylint.extensions.for_any_all",
"pylint.extensions.consider_ternary_expression",
"pylint.extensions.bad_builtin",
"pylint.extensions.dict_init_mutate",
"pylint.extensions.check_elif",
"pylint.extensions.empty_comment",
"pylint.extensions.private_import",
"pylint.extensions.redefined_variable_type",
"pylint.extensions.no_self_use",
"pylint.extensions.overlapping_exceptions",
"pylint.extensions.set_membership",
"pylint.extensions.typing"
]
[tool.pylint.design]
@ -39,7 +22,6 @@ single-line-if-stmt = true
[tool.pylint.messages_control]
disable = [
"fixme",
"broad-exception-caught",
"cyclic-import",
"global-statement",
@ -49,8 +31,7 @@ disable = [
"too-many-public-methods",
"too-many-return-statements",
"wrong-import-order",
"wrong-import-position",
"missing-function-docstring",
"missing-class-docstring",
"consider-using-namedtuple-or-dataclass",
"confusing-consecutive-elif"
"missing-class-docstring"
]

View file

@ -13,8 +13,7 @@ from . import logger as logging
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any
from typing import Any, Iterator, Optional
# pylint: disable=duplicate-code
@ -31,10 +30,10 @@ class RelayConfig(dict):
def __setitem__(self, key: str, value: Any) -> None:
if key in {'blocked_instances', 'blocked_software', 'whitelist'}:
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
assert isinstance(value, (list, set, tuple))
elif key in {'port', 'workers', 'json_cache', 'timeout'}:
elif key in ['port', 'workers', 'json_cache', 'timeout']:
if not isinstance(value, int):
value = int(value)
@ -111,7 +110,7 @@ class RelayConfig(dict):
return
for key, value in config.items():
if key == 'ap':
if key in ['ap']:
for k, v in value.items():
if k not in self:
continue
@ -191,7 +190,7 @@ class RelayDatabase(dict):
json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None:
def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
@ -206,13 +205,14 @@ class RelayDatabase(dict):
def add_inbox(self,
inbox: str,
followid: str | None = None,
software: str | None = None) -> dict[str, str]:
followid: Optional[str] = None,
software: Optional[str] = None) -> dict[str, str]:
assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname
instance = self.get_inbox(domain)
if (instance := self.get_inbox(domain)):
if instance:
if followid:
instance['followid'] = followid
@ -234,10 +234,12 @@ class RelayDatabase(dict):
def del_inbox(self,
domain: str,
followid: str = None,
fail: bool = False) -> bool:
followid: Optional[str] = None,
fail: Optional[bool] = False) -> bool:
if not (data := self.get_inbox(domain, fail=False)):
data = self.get_inbox(domain, fail=False)
if not data:
if fail:
raise KeyError(domain)

View file

@ -10,7 +10,7 @@ from pathlib import Path
from .misc import IS_DOCKER
if typing.TYPE_CHECKING:
from typing import Any
from typing import Any, Optional
DEFAULTS: dict[str, Any] = {
@ -32,7 +32,7 @@ if IS_DOCKER:
class Config:
def __init__(self, path: str, load: bool = False):
def __init__(self, path: str, load: Optional[bool] = False):
self.path = Path(path).expanduser().resolve()
self.listen = None
@ -151,7 +151,7 @@ class Config:
if key not in DEFAULTS:
raise KeyError(key)
if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
if key in ('port', 'pg_port', 'workers') and not isinstance(value, int):
value = int(value)
setattr(self, key, value)

View file

@ -12,10 +12,11 @@ from .schema import VERSIONS, migrate_0
from .. import logger as logging
if typing.TYPE_CHECKING:
from typing import Optional
from .config import Config
def get_database(config: Config, migrate: bool = True) -> tinysql.Database:
def get_database(config: Config, migrate: Optional[bool] = True) -> tinysql.Database:
if config.db_type == "sqlite":
db = tinysql.Database.sqlite(config.sqlite_path, connection_class = Connection)
@ -40,7 +41,9 @@ def get_database(config: Config, migrate: bool = True) -> tinysql.Database:
migrate_0(conn)
return db
if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'):
schema_ver = conn.get_config('schema-version')
if schema_ver < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS:

View file

@ -6,8 +6,7 @@ from .. import logger as logging
from ..misc import boolean
if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from typing import Any, Callable
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {

View file

@ -12,9 +12,8 @@ from .. import logger as logging
from ..misc import get_app
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from tinysql import Cursor, Row
from typing import Any
from typing import Any, Iterator, Optional
from .application import Application
from ..misc import Message
@ -44,7 +43,7 @@ class Connection(tinysql.Connection):
yield inbox['inbox']
def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor:
def exec_statement(self, name: str, params: Optional[dict[str, Any]] = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params)
@ -111,9 +110,9 @@ class Connection(tinysql.Connection):
def put_inbox(self,
domain: str,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> Row:
params = {
'domain': domain,
@ -130,9 +129,9 @@ class Connection(tinysql.Connection):
def update_inbox(self,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> Row:
if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"')
@ -172,8 +171,8 @@ class Connection(tinysql.Connection):
def put_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> Row:
reason: Optional[str] = None,
note: Optional[str] = None) -> Row:
params = {
'domain': domain,
@ -188,8 +187,8 @@ class Connection(tinysql.Connection):
def update_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
reason: Optional[str] = None,
note: Optional[str] = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
@ -226,8 +225,8 @@ class Connection(tinysql.Connection):
def put_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> Row:
reason: Optional[str] = None,
note: Optional[str] = None) -> Row:
params = {
'name': name,
@ -242,8 +241,8 @@ class Connection(tinysql.Connection):
def update_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
reason: Optional[str] = None,
note: Optional[str] = None) -> tinysql.Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')

View file

@ -7,7 +7,7 @@ from tinysql import Column, Connection, Table
from .config import get_default_value
if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Callable
VERSIONS: list[Callable] = []

View file

@ -16,7 +16,7 @@ from . import logger as logging
from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from typing import Any
from typing import Any, Callable, Optional
HEADERS = {
@ -26,7 +26,11 @@ HEADERS = {
class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
def __init__(self,
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
self.cache = LRUCache(cache_size)
self.limit = limit
self.timeout = timeout
@ -73,9 +77,9 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches
url: str,
sign_headers: bool = False,
loads: callable | None = None,
force: bool = False) -> Message | dict | None:
sign_headers: Optional[bool] = False,
loads: Optional[Callable] = None,
force: Optional[bool] = False) -> Message | dict | None:
await self.open()
@ -147,13 +151,11 @@ class HttpClient:
instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019'
else:
algorithm = 'original'
# pylint: enable=consider-ternary-expression
headers = {'Content-Type': 'application/activity+json'}
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
@ -193,7 +195,7 @@ class HttpClient:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None
for version in ('20', '21'):
for version in ['20', '21']:
try:
nodeinfo_url = wk_nodeinfo.get_url(version)

View file

@ -8,8 +8,7 @@ from enum import IntEnum
from pathlib import Path
if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from typing import Any, Callable, Type
class LogLevel(IntEnum):
@ -26,7 +25,7 @@ class LogLevel(IntEnum):
@classmethod
def parse(cls: type[IntEnum], data: object) -> IntEnum:
def parse(cls: Type[IntEnum], data: object) -> IntEnum:
if isinstance(data, cls):
return data

View file

@ -22,7 +22,7 @@ from .misc import IS_DOCKER, Message, check_open_port
if typing.TYPE_CHECKING:
from tinysql import Row
from typing import Any
from typing import Any, Optional
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
@ -69,7 +69,7 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup')
@click.pass_context
def cli_setup(ctx: click.Context) -> None:
'Generate a new config and create the database'
'Generate a new config'
while True:
ctx.obj.config.domain = click.prompt(
@ -184,7 +184,7 @@ def cli_run(ctx: click.Context) -> None:
@cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@click.option('--old-config', '-o', help = 'Path to the new config file')
@click.pass_context
def cli_convert(ctx: click.Context, old_config: str) -> None:
'Convert an old config and jsonld database to the new format.'
@ -220,7 +220,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
) as inboxes:
for inbox in inboxes:
if inbox['software'] in {'akkoma', 'pleroma'}:
if inbox['software'] in ('akkoma', 'pleroma'):
actor = f'https://{inbox["domain"]}/relay'
elif inbox['software'] == 'mastodon':
@ -349,7 +349,9 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
actor_data = asyncio.run(http.get(actor, sign_headers = True))
if not actor_data:
click.echo(f'Failed to fetch actor: {actor}')
return
@ -409,17 +411,14 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
@click.argument('inbox')
@click.option('--actor', '-a', help = 'Actor url for the inbox')
@click.option('--followid', '-f', help = 'Url for the follow activity')
@click.option('--software', '-s',
type = click.Choice(SOFTWARE),
help = 'Nodeinfo software name of the instance'
) # noqa: E124
@click.option('--software', '-s', type = click.Choice(SOFTWARE))
@click.pass_context
def cli_inbox_add(
ctx: click.Context,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> None:
actor: Optional[str] = None,
followid: Optional[str] = None,
software: Optional[str] = None) -> None:
'Add an inbox to the database'
if not inbox.startswith('http'):
@ -429,10 +428,6 @@ def cli_inbox_add(
else:
domain = urlparse(inbox).netloc
if not software:
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(domain))):
software = nodeinfo.sw_name
if not actor and software:
try:
actor = ACTOR_FORMATS[software].format(domain = domain)
@ -597,7 +592,9 @@ def cli_software_ban(ctx: click.Context,
return
if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))):
nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
return
@ -637,7 +634,9 @@ def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> N
return
if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))):
nodeinfo = asyncio.run(http.fetch_nodeinfo(name))
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}')
return

View file

@ -14,8 +14,7 @@ from functools import cached_property
from uuid import uuid4
if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator
from typing import Any
from typing import Any, Coroutine, Generator, Optional, Type
from .application import Application
from .config import Config
from .database import Database
@ -38,10 +37,10 @@ NODEINFO_NS = {
def boolean(value: Any) -> bool:
if isinstance(value, str):
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}:
if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']:
return True
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}:
if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']:
return False
raise TypeError(f'Cannot parse string "{value}" as a boolean')
@ -84,10 +83,10 @@ def get_app() -> Application:
class Message(ApMessage):
@classmethod
def new_actor(cls: type[Message], # pylint: disable=arguments-differ
def new_actor(cls: Type[Message], # pylint: disable=arguments-differ
host: str,
pubkey: str,
description: str | None = None) -> Message:
description: Optional[str] = None) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
@ -112,7 +111,7 @@ class Message(ApMessage):
@classmethod
def new_announce(cls: type[Message], host: str, obj: str) -> Message:
def new_announce(cls: Type[Message], host: str, obj: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
@ -124,7 +123,7 @@ class Message(ApMessage):
@classmethod
def new_follow(cls: type[Message], host: str, actor: str) -> Message:
def new_follow(cls: Type[Message], host: str, actor: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow',
@ -136,7 +135,7 @@ class Message(ApMessage):
@classmethod
def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message:
def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}',
@ -148,7 +147,7 @@ class Message(ApMessage):
@classmethod
def new_response(cls: type[Message],
def new_response(cls: Type[Message],
host: str,
actor: str,
followid: str,
@ -181,11 +180,11 @@ class Message(ApMessage):
class Response(AiohttpResponse):
@classmethod
def new(cls: type[Response],
body: str | bytes | dict = '',
status: int = 200,
headers: dict[str, str] | None = None,
ctype: str = 'text') -> Response:
def new(cls: Type[Response],
body: Optional[str | bytes | dict] = '',
status: Optional[int] = 200,
headers: Optional[dict[str, str]] = None,
ctype: Optional[str] = 'text') -> Response:
kwargs = {
'status': status,
@ -206,7 +205,7 @@ class Response(AiohttpResponse):
@classmethod
def new_error(cls: type[Response],
def new_error(cls: Type[Response],
status: int,
body: str | bytes | dict,
ctype: str = 'text') -> Response:
@ -229,10 +228,12 @@ class Response(AiohttpResponse):
class View(AbstractView):
def __await__(self) -> Generator[Response]:
if (self.request.method) not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
method = self.request.method.upper()
if not (handler := self.handlers.get(self.request.method)):
if method not in METHODS:
raise HTTPMethodNotAllowed(method, self.allowed_methods)
if not (handler := self.handlers.get(method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__()

View file

@ -18,7 +18,7 @@ from .processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from aputils.signer import Signer
from collections.abc import Callable
from typing import Callable
VIEWS = []

View file

@ -29,7 +29,10 @@ install_requires = file: requirements.txt
python_requires = >=3.8
[options.extras_require]
dev = file: dev-requirements.txt
dev =
flake8 == 3.1.0
pyinstaller == 6.3.0
pylint == 3.0
[options.package_data]
relay =
@ -41,4 +44,7 @@ console_scripts =
[flake8]
select = F401
extend-ignore = ANN101,ANN204,E128,E251,E261,E266,E301,E303,W191
extend-exclude = docs, test*.py
max-line-length = 100
indent-size = 4