fix linter warnings

This commit is contained in:
Izalia Mae 2024-01-23 21:54:58 -05:00
parent 485d1cd23e
commit 7a9d346642
11 changed files with 86 additions and 90 deletions

View file

@ -13,7 +13,8 @@ from . import logger as logging
from .misc import Message, boolean from .misc import Message, boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Iterator, Optional from collections.abc import Iterator
from typing import Any
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@ -30,10 +31,10 @@ class RelayConfig(dict):
def __setitem__(self, key: str, value: Any) -> None: def __setitem__(self, key: str, value: Any) -> None:
if key in ['blocked_instances', 'blocked_software', 'whitelist']: if key in {'blocked_instances', 'blocked_software', 'whitelist'}:
assert isinstance(value, (list, set, tuple)) assert isinstance(value, (list, set, tuple))
elif key in ['port', 'workers', 'json_cache', 'timeout']: elif key in {'port', 'workers', 'json_cache', 'timeout'}:
if not isinstance(value, int): if not isinstance(value, int):
value = int(value) value = int(value)
@ -110,7 +111,7 @@ class RelayConfig(dict):
return return
for key, value in config.items(): for key, value in config.items():
if key in ['ap']: if key == 'ap':
for k, v in value.items(): for k, v in value.items():
if k not in self: if k not in self:
continue continue
@ -190,7 +191,7 @@ class RelayDatabase(dict):
json.dump(self, fd, indent=4) json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: Optional[bool] = False) -> dict[str, str] | None: def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None:
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).hostname domain = urlparse(domain).hostname
@ -205,14 +206,13 @@ class RelayDatabase(dict):
def add_inbox(self, def add_inbox(self,
inbox: str, inbox: str,
followid: Optional[str] = None, followid: str | None = None,
software: Optional[str] = None) -> dict[str, str]: software: str | None = None) -> dict[str, str]:
assert inbox.startswith('https'), 'Inbox must be a url' assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname domain = urlparse(inbox).hostname
instance = self.get_inbox(domain)
if instance: if (instance := self.get_inbox(domain)):
if followid: if followid:
instance['followid'] = followid instance['followid'] = followid
@ -234,12 +234,10 @@ class RelayDatabase(dict):
def del_inbox(self, def del_inbox(self,
domain: str, domain: str,
followid: Optional[str] = None, followid: str = None,
fail: Optional[bool] = False) -> bool: fail: bool = False) -> bool:
data = self.get_inbox(domain, fail=False) if not (data := self.get_inbox(domain, fail=False)):
if not data:
if fail: if fail:
raise KeyError(domain) raise KeyError(domain)

View file

@ -10,7 +10,7 @@ from pathlib import Path
from .misc import IS_DOCKER from .misc import IS_DOCKER
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Optional from typing import Any
DEFAULTS: dict[str, Any] = { DEFAULTS: dict[str, Any] = {
@ -32,7 +32,7 @@ if IS_DOCKER:
class Config: class Config:
def __init__(self, path: str, load: Optional[bool] = False): def __init__(self, path: str, load: bool = False):
self.path = Path(path).expanduser().resolve() self.path = Path(path).expanduser().resolve()
self.listen = None self.listen = None
@ -151,7 +151,7 @@ class Config:
if key not in DEFAULTS: if key not in DEFAULTS:
raise KeyError(key) raise KeyError(key)
if key in ('port', 'pg_port', 'workers') and not isinstance(value, int): if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
value = int(value) value = int(value)
setattr(self, key, value) setattr(self, key, value)

View file

@ -12,11 +12,10 @@ from .schema import VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Optional
from .config import Config from .config import Config
def get_database(config: Config, migrate: Optional[bool] = True) -> tinysql.Database: def get_database(config: Config, migrate: bool = True) -> tinysql.Database:
if config.db_type == "sqlite": if config.db_type == "sqlite":
db = tinysql.Database.sqlite(config.sqlite_path, connection_class = Connection) db = tinysql.Database.sqlite(config.sqlite_path, connection_class = Connection)
@ -41,9 +40,7 @@ def get_database(config: Config, migrate: Optional[bool] = True) -> tinysql.Data
migrate_0(conn) migrate_0(conn)
return db return db
schema_ver = conn.get_config('schema-version') if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'):
if schema_ver < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver) logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS: for ver, func in VERSIONS:

View file

@ -6,7 +6,8 @@ from .. import logger as logging
from ..misc import boolean from ..misc import boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Callable from collections.abc import Callable
from typing import Any
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = { CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {

View file

@ -12,8 +12,9 @@ from .. import logger as logging
from ..misc import get_app from ..misc import get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator
from tinysql import Cursor, Row from tinysql import Cursor, Row
from typing import Any, Iterator, Optional from typing import Any
from .application import Application from .application import Application
from ..misc import Message from ..misc import Message
@ -43,7 +44,7 @@ class Connection(tinysql.Connection):
yield inbox['inbox'] yield inbox['inbox']
def exec_statement(self, name: str, params: Optional[dict[str, Any]] = None) -> Cursor: def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params) return self.execute(self.database.prepared_statements[name], params)
@ -110,9 +111,9 @@ class Connection(tinysql.Connection):
def put_inbox(self, def put_inbox(self,
domain: str, domain: str,
inbox: str, inbox: str,
actor: Optional[str] = None, actor: str | None = None,
followid: Optional[str] = None, followid: str | None = None,
software: Optional[str] = None) -> Row: software: str | None = None) -> Row:
params = { params = {
'domain': domain, 'domain': domain,
@ -129,9 +130,9 @@ class Connection(tinysql.Connection):
def update_inbox(self, def update_inbox(self,
inbox: str, inbox: str,
actor: Optional[str] = None, actor: str | None = None,
followid: Optional[str] = None, followid: str | None = None,
software: Optional[str] = None) -> Row: software: str | None = None) -> Row:
if not (actor or followid or software): if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"') raise ValueError('Missing "actor", "followid", and/or "software"')
@ -171,8 +172,8 @@ class Connection(tinysql.Connection):
def put_domain_ban(self, def put_domain_ban(self,
domain: str, domain: str,
reason: Optional[str] = None, reason: str | None = None,
note: Optional[str] = None) -> Row: note: str | None = None) -> Row:
params = { params = {
'domain': domain, 'domain': domain,
@ -187,8 +188,8 @@ class Connection(tinysql.Connection):
def update_domain_ban(self, def update_domain_ban(self,
domain: str, domain: str,
reason: Optional[str] = None, reason: str | None = None,
note: Optional[str] = None) -> tinysql.Row: note: str | None = None) -> tinysql.Row:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -225,8 +226,8 @@ class Connection(tinysql.Connection):
def put_software_ban(self, def put_software_ban(self,
name: str, name: str,
reason: Optional[str] = None, reason: str | None = None,
note: Optional[str] = None) -> Row: note: str | None = None) -> Row:
params = { params = {
'name': name, 'name': name,
@ -241,8 +242,8 @@ class Connection(tinysql.Connection):
def update_software_ban(self, def update_software_ban(self,
name: str, name: str,
reason: Optional[str] = None, reason: str | None = None,
note: Optional[str] = None) -> tinysql.Row: note: str | None = None) -> tinysql.Row:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')

View file

@ -7,7 +7,7 @@ from tinysql import Column, Connection, Table
from .config import get_default_value from .config import get_default_value
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Callable from collections.abc import Callable
VERSIONS: list[Callable] = [] VERSIONS: list[Callable] = []

View file

@ -16,7 +16,7 @@ from . import logger as logging
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Callable, Optional from typing import Any
HEADERS = { HEADERS = {
@ -26,11 +26,7 @@ HEADERS = {
class HttpClient: class HttpClient:
def __init__(self, def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
limit: Optional[int] = 100,
timeout: Optional[int] = 10,
cache_size: Optional[int] = 1024):
self.cache = LRUCache(cache_size) self.cache = LRUCache(cache_size)
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
@ -77,9 +73,9 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches async def get(self, # pylint: disable=too-many-branches
url: str, url: str,
sign_headers: Optional[bool] = False, sign_headers: bool = False,
loads: Optional[Callable] = None, loads: callable | None = None,
force: Optional[bool] = False) -> Message | dict | None: force: bool = False) -> Message | dict | None:
await self.open() await self.open()
@ -151,11 +147,13 @@ class HttpClient:
instance = conn.get_inbox(url) instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now ## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}: if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019' algorithm = 'hs2019'
else: else:
algorithm = 'original' algorithm = 'original'
# pylint: enable=consider-ternary-expression
headers = {'Content-Type': 'application/activity+json'} headers = {'Content-Type': 'application/activity+json'}
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm)) headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
@ -195,7 +193,7 @@ class HttpClient:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None return None
for version in ['20', '21']: for version in ('20', '21'):
try: try:
nodeinfo_url = wk_nodeinfo.get_url(version) nodeinfo_url = wk_nodeinfo.get_url(version)

View file

@ -8,7 +8,8 @@ from enum import IntEnum
from pathlib import Path from pathlib import Path
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Callable, Type from collections.abc import Callable
from typing import Any
class LogLevel(IntEnum): class LogLevel(IntEnum):
@ -25,7 +26,7 @@ class LogLevel(IntEnum):
@classmethod @classmethod
def parse(cls: Type[IntEnum], data: object) -> IntEnum: def parse(cls: type[IntEnum], data: object) -> IntEnum:
if isinstance(data, cls): if isinstance(data, cls):
return data return data

View file

@ -22,7 +22,7 @@ from .misc import IS_DOCKER, Message, check_open_port
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Row from tinysql import Row
from typing import Any, Optional from typing import Any
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation # pylint: disable=unsubscriptable-object,unsupported-assignment-operation
@ -69,7 +69,7 @@ def cli(ctx: click.Context, config: str) -> None:
@cli.command('setup') @cli.command('setup')
@click.pass_context @click.pass_context
def cli_setup(ctx: click.Context) -> None: def cli_setup(ctx: click.Context) -> None:
'Generate a new config' 'Generate a new config and create the database'
while True: while True:
ctx.obj.config.domain = click.prompt( ctx.obj.config.domain = click.prompt(
@ -184,7 +184,7 @@ def cli_run(ctx: click.Context) -> None:
@cli.command('convert') @cli.command('convert')
@click.option('--old-config', '-o', help = 'Path to the new config file') @click.option('--old-config', '-o', help = 'Path to the config file to convert from')
@click.pass_context @click.pass_context
def cli_convert(ctx: click.Context, old_config: str) -> None: def cli_convert(ctx: click.Context, old_config: str) -> None:
'Convert an old config and jsonld database to the new format.' 'Convert an old config and jsonld database to the new format.'
@ -220,7 +220,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
) as inboxes: ) as inboxes:
for inbox in inboxes: for inbox in inboxes:
if inbox['software'] in ('akkoma', 'pleroma'): if inbox['software'] in {'akkoma', 'pleroma'}:
actor = f'https://{inbox["domain"]}/relay' actor = f'https://{inbox["domain"]}/relay'
elif inbox['software'] == 'mastodon': elif inbox['software'] == 'mastodon':
@ -349,9 +349,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
if not actor.startswith('http'): if not actor.startswith('http'):
actor = f'https://{actor}/actor' actor = f'https://{actor}/actor'
actor_data = asyncio.run(http.get(actor, sign_headers = True)) if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
if not actor_data:
click.echo(f'Failed to fetch actor: {actor}') click.echo(f'Failed to fetch actor: {actor}')
return return
@ -411,14 +409,17 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
@click.argument('inbox') @click.argument('inbox')
@click.option('--actor', '-a', help = 'Actor url for the inbox') @click.option('--actor', '-a', help = 'Actor url for the inbox')
@click.option('--followid', '-f', help = 'Url for the follow activity') @click.option('--followid', '-f', help = 'Url for the follow activity')
@click.option('--software', '-s', type = click.Choice(SOFTWARE)) @click.option('--software', '-s',
type = click.Choice(SOFTWARE),
help = 'Nodeinfo software name of the instance'
) # noqa: E124
@click.pass_context @click.pass_context
def cli_inbox_add( def cli_inbox_add(
ctx: click.Context, ctx: click.Context,
inbox: str, inbox: str,
actor: Optional[str] = None, actor: str | None = None,
followid: Optional[str] = None, followid: str | None = None,
software: Optional[str] = None) -> None: software: str | None = None) -> None:
'Add an inbox to the database' 'Add an inbox to the database'
if not inbox.startswith('http'): if not inbox.startswith('http'):
@ -428,6 +429,10 @@ def cli_inbox_add(
else: else:
domain = urlparse(inbox).netloc domain = urlparse(inbox).netloc
if not software:
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(domain))):
software = nodeinfo.sw_name
if not actor and software: if not actor and software:
try: try:
actor = ACTOR_FORMATS[software].format(domain = domain) actor = ACTOR_FORMATS[software].format(domain = domain)
@ -592,9 +597,7 @@ def cli_software_ban(ctx: click.Context,
return return
if fetch_nodeinfo: if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(name)) if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))):
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}') click.echo(f'Failed to fetch software name from domain: {name}')
return return
@ -634,9 +637,7 @@ def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> N
return return
if fetch_nodeinfo: if fetch_nodeinfo:
nodeinfo = asyncio.run(http.fetch_nodeinfo(name)) if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(name))):
if not nodeinfo:
click.echo(f'Failed to fetch software name from domain: {name}') click.echo(f'Failed to fetch software name from domain: {name}')
return return

View file

@ -14,7 +14,8 @@ from functools import cached_property
from uuid import uuid4 from uuid import uuid4
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Coroutine, Generator, Optional, Type from collections.abc import Coroutine, Generator
from typing import Any
from .application import Application from .application import Application
from .config import Config from .config import Config
from .database import Database from .database import Database
@ -37,10 +38,10 @@ NODEINFO_NS = {
def boolean(value: Any) -> bool: def boolean(value: Any) -> bool:
if isinstance(value, str): if isinstance(value, str):
if value.lower() in ['on', 'y', 'yes', 'true', 'enable', 'enabled', '1']: if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}:
return True return True
if value.lower() in ['off', 'n', 'no', 'false', 'disable', 'disable', '0']: if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}:
return False return False
raise TypeError(f'Cannot parse string "{value}" as a boolean') raise TypeError(f'Cannot parse string "{value}" as a boolean')
@ -83,10 +84,10 @@ def get_app() -> Application:
class Message(ApMessage): class Message(ApMessage):
@classmethod @classmethod
def new_actor(cls: Type[Message], # pylint: disable=arguments-differ def new_actor(cls: type[Message], # pylint: disable=arguments-differ
host: str, host: str,
pubkey: str, pubkey: str,
description: Optional[str] = None) -> Message: description: str | None = None) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
@ -111,7 +112,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_announce(cls: Type[Message], host: str, obj: str) -> Message: def new_announce(cls: type[Message], host: str, obj: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -123,7 +124,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_follow(cls: Type[Message], host: str, actor: str) -> Message: def new_follow(cls: type[Message], host: str, actor: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow', 'type': 'Follow',
@ -135,7 +136,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_unfollow(cls: Type[Message], host: str, actor: str, follow: str) -> Message: def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -147,7 +148,7 @@ class Message(ApMessage):
@classmethod @classmethod
def new_response(cls: Type[Message], def new_response(cls: type[Message],
host: str, host: str,
actor: str, actor: str,
followid: str, followid: str,
@ -180,11 +181,11 @@ class Message(ApMessage):
class Response(AiohttpResponse): class Response(AiohttpResponse):
@classmethod @classmethod
def new(cls: Type[Response], def new(cls: type[Response],
body: Optional[str | bytes | dict] = '', body: str | bytes | dict = '',
status: Optional[int] = 200, status: int = 200,
headers: Optional[dict[str, str]] = None, headers: dict[str, str] | None = None,
ctype: Optional[str] = 'text') -> Response: ctype: str = 'text') -> Response:
kwargs = { kwargs = {
'status': status, 'status': status,
@ -205,7 +206,7 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_error(cls: Type[Response], def new_error(cls: type[Response],
status: int, status: int,
body: str | bytes | dict, body: str | bytes | dict,
ctype: str = 'text') -> Response: ctype: str = 'text') -> Response:
@ -228,12 +229,10 @@ class Response(AiohttpResponse):
class View(AbstractView): class View(AbstractView):
def __await__(self) -> Generator[Response]: def __await__(self) -> Generator[Response]:
method = self.request.method.upper() if (self.request.method) not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
if method not in METHODS: if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(method, self.allowed_methods)
if not (handler := self.handlers.get(method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__() return handler(self.request, **self.request.match_info).__await__()

View file

@ -18,7 +18,7 @@ from .processors import run_processor
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from aputils.signer import Signer from aputils.signer import Signer
from typing import Callable from collections.abc import Callable
VIEWS = [] VIEWS = []