diff --git a/dev.py b/dev.py index a6669fb..f72e346 100755 --- a/dev.py +++ b/dev.py @@ -11,11 +11,11 @@ from datetime import datetime, timedelta from pathlib import Path from relay import __version__, logger as logging from tempfile import TemporaryDirectory -from typing import Sequence +from typing import Any, Sequence try: from watchdog.observers import Observer - from watchdog.events import PatternMatchingEventHandler + from watchdog.events import FileSystemEvent, PatternMatchingEventHandler except ImportError: class PatternMatchingEventHandler: # type: ignore @@ -70,7 +70,7 @@ def cli_lint(path: Path, watch: bool) -> None: @cli.command('clean') -def cli_clean(): +def cli_clean() -> None: dirs = { 'dist', 'build', @@ -88,7 +88,7 @@ def cli_clean(): @cli.command('build') -def cli_build(): +def cli_build() -> None: with TemporaryDirectory() as tmp: arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' cmd = [ @@ -118,7 +118,7 @@ def cli_build(): @cli.command('run') @click.option('--dev', '-d', is_flag = True) -def cli_run(dev: bool): +def cli_run(dev: bool) -> None: print('Starting process watcher') cmd = [sys.executable, '-m', 'relay', 'run'] @@ -129,13 +129,13 @@ def cli_run(dev: bool): handle_run_watcher(cmd) -def handle_run_watcher(*commands: Sequence[str], wait: bool = False): +def handle_run_watcher(*commands: Sequence[str], wait: bool = False) -> None: handler = WatchHandler(*commands, wait = wait) handler.run_procs() watcher = Observer() - watcher.schedule(handler, str(REPO), recursive=True) - watcher.start() + watcher.schedule(handler, str(REPO), recursive=True) # type: ignore + watcher.start() # type: ignore try: while True: @@ -145,7 +145,7 @@ def handle_run_watcher(*commands: Sequence[str], wait: bool = False): pass handler.kill_procs() - watcher.stop() + watcher.stop() # type: ignore watcher.join() @@ -153,16 +153,16 @@ class WatchHandler(PatternMatchingEventHandler): patterns = ['*.py'] - def __init__(self, *commands: Sequence[str], wait: bool = False): - PatternMatchingEventHandler.__init__(self) + def __init__(self, *commands: Sequence[str], wait: bool = False) -> None: + PatternMatchingEventHandler.__init__(self) # type: ignore self.commands: Sequence[Sequence[str]] = commands self.wait: bool = wait - self.procs: list[subprocess.Popen] = [] + self.procs: list[subprocess.Popen[Any]] = [] self.last_restart: datetime = datetime.now() - def kill_procs(self): + def kill_procs(self) -> None: for proc in self.procs: if proc.poll() is not None: continue @@ -183,7 +183,7 @@ class WatchHandler(PatternMatchingEventHandler): logging.info('Process terminated') - def run_procs(self, restart: bool = False): + def run_procs(self, restart: bool = False) -> None: if restart: if datetime.now() - timedelta(seconds = 3) < self.last_restart: return @@ -205,7 +205,7 @@ class WatchHandler(PatternMatchingEventHandler): logging.info('Started processes with PIDs: %s', ', '.join(pids)) - def on_any_event(self, event): + def on_any_event(self, event: FileSystemEvent) -> None: if event.event_type not in ['modified', 'created', 'deleted']: return diff --git a/relay/application.py b/relay/application.py index cc87c22..25373b5 100644 --- a/relay/application.py +++ b/relay/application.py @@ -5,35 +5,31 @@ import multiprocessing import signal import time import traceback -import typing from aiohttp import web from aiohttp.web import StaticResource from aiohttp_swagger import setup_swagger from aputils.signer import Signer +from bsql import Database, Row +from collections.abc import Awaitable, Callable from datetime import datetime, timedelta from mimetypes import guess_type from pathlib import Path from queue import Empty from threading import Event, Thread +from typing import Any from . import logger as logging -from .cache import get_cache +from .cache import Cache, get_cache from .config import Config -from .database import get_database +from .database import Connection, get_database from .http_client import HttpClient -from .misc import IS_WINDOWS, check_open_port, get_resource +from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource from .template import Template from .views import VIEWS from .views.api import handle_api_path from .views.frontend import handle_frontend_path -if typing.TYPE_CHECKING: - from collections.abc import Callable - from bsql import Database, Row - from .cache import Cache - from .misc import Message, Response - def get_csp(request: web.Request) -> str: data = [ @@ -58,9 +54,9 @@ class Application(web.Application): def __init__(self, cfgpath: Path | None, dev: bool = False): web.Application.__init__(self, middlewares = [ - handle_api_path, - handle_frontend_path, - handle_response_headers + handle_api_path, # type: ignore[list-item] + handle_frontend_path, # type: ignore[list-item] + handle_response_headers # type: ignore[list-item] ] ) @@ -96,27 +92,27 @@ class Application(web.Application): @property def cache(self) -> Cache: - return self['cache'] + return self['cache'] # type: ignore[no-any-return] @property def client(self) -> HttpClient: - return self['client'] + return self['client'] # type: ignore[no-any-return] @property def config(self) -> Config: - return self['config'] + return self['config'] # type: ignore[no-any-return] @property - def database(self) -> Database: - return self['database'] + def database(self) -> Database[Connection]: + return self['database'] # type: ignore[no-any-return] @property def signer(self) -> Signer: - return self['signer'] + return self['signer'] # type: ignore[no-any-return] @signer.setter @@ -130,7 +126,7 @@ class Application(web.Application): @property def template(self) -> Template: - return self['template'] + return self['template'] # type: ignore[no-any-return] @property @@ -185,11 +181,11 @@ class Application(web.Application): pass - def stop(self, *_): + def stop(self, *_: Any) -> None: self['running'] = False - async def handle_run(self): + async def handle_run(self) -> None: self['running'] = True self.set_signal_handler(True) @@ -295,7 +291,7 @@ class CacheCleanupThread(Thread): class PushWorker(multiprocessing.Process): - def __init__(self, queue: multiprocessing.Queue): + def __init__(self, queue: multiprocessing.Queue[tuple[str, Message | bytes, Row]]) -> None: if Application.DEFAULT is None: raise RuntimeError('Application not setup yet') @@ -347,7 +343,10 @@ class PushWorker(multiprocessing.Process): @web.middleware -async def handle_response_headers(request: web.Request, handler: Callable) -> Response: +async def handle_response_headers( + request: web.Request, + handler: Callable[[web.Request], Awaitable[Response]]) -> Response: + resp = await handler(request) resp.headers['Server'] = 'ActivityRelay' diff --git a/relay/cache.py b/relay/cache.py index 0f273db..e9f261b 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -2,28 +2,27 @@ from __future__ import annotations import json import os -import typing from abc import ABC, abstractmethod +from bsql import Database +from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass from datetime import datetime, timedelta, timezone from redis import Redis +from typing import TYPE_CHECKING, Any -from .database import get_database +from .database import Connection, get_database from .misc import Message, boolean -if typing.TYPE_CHECKING: - from bsql import Database - from collections.abc import Callable, Iterator - from typing import Any +if TYPE_CHECKING: from .application import Application -# todo: implement more caching backends - +SerializerCallback = Callable[[Any], str] +DeserializerCallback = Callable[[str], Any] BACKENDS: dict[str, type[Cache]] = {} -CONVERTERS: dict[str, tuple[Callable, Callable]] = { +CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = { 'str': (str, str), 'int': (str, int), 'bool': (str, boolean), @@ -61,13 +60,13 @@ class Item: updated: datetime - def __post_init__(self): - if isinstance(self.updated, str): - self.updated = datetime.fromisoformat(self.updated) + def __post_init__(self) -> None: + if isinstance(self.updated, str): # type: ignore[unreachable] + self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable] @classmethod - def from_data(cls: type[Item], *args) -> Item: + def from_data(cls: type[Item], *args: Any) -> Item: data = cls(*args) data.value = deserialize_value(data.value, data.value_type) @@ -159,7 +158,7 @@ class SqlCache(Cache): def __init__(self, app: Application): Cache.__init__(self, app) - self._db: Database | None = None + self._db: Database[Connection] | None = None def get(self, namespace: str, key: str) -> Item: @@ -211,10 +210,10 @@ class SqlCache(Cache): } with self._db.session(True) as conn: - with conn.run('set-cache-item', params) as conn: - row = conn.one() - row.pop('id', None) - return Item.from_data(*tuple(row.values())) + with conn.run('set-cache-item', params) as cur: + row = cur.one() + row.pop('id', None) # type: ignore[union-attr] + return Item.from_data(*tuple(row.values())) # type: ignore[union-attr] def delete(self, namespace: str, key: str) -> None: @@ -381,5 +380,5 @@ class RedisCache(Cache): if not self._rd: return - self._rd.close() + self._rd.close() # type: ignore self._rd = None # type: ignore diff --git a/relay/compat.py b/relay/compat.py index 9884b25..54b6573 100644 --- a/relay/compat.py +++ b/relay/compat.py @@ -1,21 +1,16 @@ -from __future__ import annotations - import json import os -import typing import yaml from functools import cached_property from pathlib import Path +from typing import Any from urllib.parse import urlparse from .misc import boolean -if typing.TYPE_CHECKING: - from typing import Any - -class RelayConfig(dict): +class RelayConfig(dict[str, Any]): def __init__(self, path: str): dict.__init__(self, {}) @@ -122,7 +117,7 @@ class RelayConfig(dict): self[key] = value -class RelayDatabase(dict): +class RelayDatabase(dict[str, Any]): def __init__(self, config: RelayConfig): dict.__init__(self, { 'relay-list': {}, diff --git a/relay/config.py b/relay/config.py index dbfc0b4..ac2bbb6 100644 --- a/relay/config.py +++ b/relay/config.py @@ -1,25 +1,20 @@ -from __future__ import annotations - import getpass import os import platform -import typing import yaml from dataclasses import asdict, dataclass, fields from pathlib import Path from platformdirs import user_config_dir +from typing import Any from .misc import IS_DOCKER -if typing.TYPE_CHECKING: - from typing import Any +try: + from typing import Self - try: - from typing import Self - - except ImportError: - from typing_extensions import Self +except ImportError: + from typing_extensions import Self if platform.system() == 'Windows': diff --git a/relay/database/__init__.py b/relay/database/__init__.py index 08dbec6..becd456 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -1,20 +1,15 @@ -from __future__ import annotations - -import bsql -import typing +from bsql import Database from .config import THEMES, ConfigData from .connection import RELAY_SOFTWARE, Connection from .schema import TABLES, VERSIONS, migrate_0 from .. import logger as logging +from ..config import Config from ..misc import get_resource -if typing.TYPE_CHECKING: - from ..config import Config - -def get_database(config: Config, migrate: bool = True) -> bsql.Database: +def get_database(config: Config, migrate: bool = True) -> Database[Connection]: options = { 'connection_class': Connection, 'pool_size': 5, @@ -22,10 +17,10 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database: } if config.db_type == 'sqlite': - db = bsql.Database.sqlite(config.sqlite_path, **options) + db = Database.sqlite(config.sqlite_path, **options) elif config.db_type == 'postgres': - db = bsql.Database.postgresql( + db = Database.postgresql( config.pg_name, config.pg_host, config.pg_port, diff --git a/relay/database/config.py b/relay/database/config.py index 3922f62..6effbb9 100644 --- a/relay/database/config.py +++ b/relay/database/config.py @@ -1,22 +1,20 @@ from __future__ import annotations +# removing the above line turns annotations into types instead of str objects which messes with +# `Field.type` -import typing - +from bsql import Row +from collections.abc import Callable, Sequence from dataclasses import Field, asdict, dataclass, fields +from typing import Any from .. import logger as logging from ..misc import boolean -if typing.TYPE_CHECKING: - from bsql import Row - from collections.abc import Callable, Sequence - from typing import Any +try: + from typing import Self - try: - from typing import Self - - except ImportError: - from typing_extensions import Self +except ImportError: + from typing_extensions import Self THEMES = { @@ -120,7 +118,7 @@ class ConfigData: @classmethod - def FIELD(cls: type[Self], key: str) -> Field: + def FIELD(cls: type[Self], key: str) -> Field[Any]: for field in fields(cls): if field.name == key.replace('-', '_'): return field diff --git a/relay/database/connection.py b/relay/database/connection.py index f8de1c0..614f307 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -1,10 +1,10 @@ from __future__ import annotations -import typing - from argon2 import PasswordHasher -from bsql import Connection as SqlConnection, Update +from bsql import Connection as SqlConnection, Row, Update +from collections.abc import Iterator, Sequence from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse from uuid import uuid4 @@ -14,15 +14,10 @@ from .config import ( ) from .. import logger as logging -from ..misc import boolean, get_app +from ..misc import Message, boolean, get_app -if typing.TYPE_CHECKING: - from collections.abc import Iterator, Sequence - from bsql import Row - from typing import Any +if TYPE_CHECKING: from ..application import Application - from ..misc import Message - RELAY_SOFTWARE = [ 'activityrelay', # https://git.pleroma.social/pleroma/relay @@ -94,7 +89,7 @@ class Connection(SqlConnection): params = { 'key': key, 'value': data.get(key, serialize = True), - 'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type + 'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore } with self.run('put-config', params): diff --git a/relay/database/schema.py b/relay/database/schema.py index ba39ed2..409ee57 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -1,17 +1,11 @@ -from __future__ import annotations - -import typing - from bsql import Column, Table, Tables +from collections.abc import Callable from .config import ConfigData - -if typing.TYPE_CHECKING: - from collections.abc import Callable - from .connection import Connection +from .connection import Connection -VERSIONS: dict[int, Callable] = {} +VERSIONS: dict[int, Callable[[Connection], None]] = {} TABLES: Tables = Tables( Table( 'config', @@ -64,7 +58,7 @@ TABLES: Tables = Tables( ) -def migration(func: Callable) -> Callable: +def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: ver = int(func.__name__.replace('migrate_', '')) VERSIONS[ver] = func return func diff --git a/relay/http_client.py b/relay/http_client.py index 27c2dec..54cea3c 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -2,26 +2,23 @@ from __future__ import annotations import json import traceback -import typing from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError +from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo from asyncio.exceptions import TimeoutError as AsyncTimeoutError -from aputils import AlgorithmType, Nodeinfo, ObjectType, WellKnownNodeinfo from blib import JsonBase +from bsql import Row from json.decoder import JSONDecodeError +from typing import TYPE_CHECKING, Any, TypeVar from urllib.parse import urlparse -from . import __version__ -from . import logger as logging +from . import __version__, logger as logging +from .cache import Cache from .misc import MIMETYPES, Message, get_app -if typing.TYPE_CHECKING: - from aputils import Signer - from bsql import Row - from typing import Any +if TYPE_CHECKING: from .application import Application - from .cache import Cache SUPPORTS_HS2019 = { @@ -39,7 +36,7 @@ SUPPORTS_HS2019 = { 'sharkey' } -T = typing.TypeVar('T', bound = JsonBase) +T = TypeVar('T', bound = JsonBase) HEADERS = { 'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9', 'User-Agent': f'ActivityRelay/{__version__}' @@ -124,7 +121,7 @@ class HttpClient: if not force: try: if not (item := self.cache.get('request', url)).older_than(48): - return json.loads(item.value) + return json.loads(item.value) # type: ignore[no-any-return] except KeyError: logging.verbose('No cached data for url: %s', url) @@ -153,7 +150,7 @@ class HttpClient: self.cache.set('request', url, data, 'str') logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4)) - return json.loads(data) + return json.loads(data) # type: ignore [no-any-return] except JSONDecodeError: logging.verbose('Failed to parse JSON') diff --git a/relay/logger.py b/relay/logger.py index 916fa71..f1a1bd7 100644 --- a/relay/logger.py +++ b/relay/logger.py @@ -1,21 +1,19 @@ -from __future__ import annotations - import logging import os -import typing from enum import IntEnum from pathlib import Path +from typing import Any, Protocol -if typing.TYPE_CHECKING: - from collections.abc import Callable - from typing import Any +try: + from typing import Self - try: - from typing import Self +except ImportError: + from typing_extensions import Self - except ImportError: - from typing_extensions import Self + +class LoggingMethod(Protocol): + def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ... class LogLevel(IntEnum): @@ -75,11 +73,11 @@ def verbose(message: str, *args: Any, **kwargs: Any) -> None: logging.log(LogLevel.VERBOSE, message, *args, **kwargs) -debug: Callable = logging.debug -info: Callable = logging.info -warning: Callable = logging.warning -error: Callable = logging.error -critical: Callable = logging.critical +debug: LoggingMethod = logging.debug +info: LoggingMethod = logging.info +warning: LoggingMethod = logging.warning +error: LoggingMethod = logging.error +critical: LoggingMethod = logging.critical try: diff --git a/relay/manage.py b/relay/manage.py index c48e0f6..cb2b099 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -5,10 +5,11 @@ import asyncio import click import json import os -import typing +from bsql import Row from pathlib import Path from shutil import copyfile +from typing import Any from urllib.parse import urlparse from . import __version__ @@ -19,10 +20,6 @@ from .compat import RelayConfig, RelayDatabase from .database import RELAY_SOFTWARE, get_database from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message -if typing.TYPE_CHECKING: - from bsql import Row - from typing import Any - def check_alphanumeric(text: str) -> str: if not text.isalnum(): diff --git a/relay/misc.py b/relay/misc.py index feee97c..9e8f035 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -5,11 +5,12 @@ import json import os import platform import socket -import typing from aiohttp.web import Response as AiohttpResponse +from collections.abc import Sequence from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from uuid import uuid4 try: @@ -18,21 +19,20 @@ try: except ImportError: from importlib_resources import files as pkgfiles # type: ignore -if typing.TYPE_CHECKING: - from typing import Any +try: + from typing import Self + +except ImportError: + from typing_extensions import Self + +if TYPE_CHECKING: from .application import Application - try: - from typing import Self - except ImportError: - from typing_extensions import Self - - -T = typing.TypeVar('T') -ResponseType = typing.TypedDict('ResponseType', { +T = TypeVar('T') +ResponseType = TypedDict('ResponseType', { 'status': int, - 'headers': dict[str, typing.Any] | None, + 'headers': dict[str, Any] | None, 'content_type': str, 'body': bytes | None, 'text': str | None @@ -128,7 +128,7 @@ class JsonEncoder(json.JSONEncoder): if isinstance(o, datetime): return o.isoformat() - return json.JSONEncoder.default(self, o) + return json.JSONEncoder.default(self, o) # type: ignore[no-any-return] class Message(aputils.Message): @@ -214,7 +214,7 @@ class Response(AiohttpResponse): @classmethod def new(cls: type[Self], - body: str | bytes | dict | tuple | list | set = '', + body: str | bytes | dict[str, Any] | Sequence[Any] = '', status: int = 200, headers: dict[str, str] | None = None, ctype: str = 'text') -> Self: @@ -227,22 +227,22 @@ class Response(AiohttpResponse): 'text': None } - if isinstance(body, bytes): + if isinstance(body, str): + kwargs['text'] = body + + elif isinstance(body, bytes): kwargs['body'] = body - elif isinstance(body, (dict, list, tuple, set)) or ctype in {'json', 'activity'}: + elif isinstance(body, (dict, Sequence)): kwargs['text'] = json.dumps(body, cls = JsonEncoder) - else: - kwargs['text'] = body - return cls(**kwargs) @classmethod def new_error(cls: type[Self], status: int, - body: str | bytes | dict, + body: str | bytes | dict[str, Any], ctype: str = 'text') -> Self: if ctype == 'json': diff --git a/relay/template.py b/relay/template.py index 1335fab..ef25f92 100644 --- a/relay/template.py +++ b/relay/template.py @@ -1,25 +1,22 @@ from __future__ import annotations import textwrap -import typing from collections.abc import Callable from hamlish_jinja import HamlishExtension from jinja2 import Environment, FileSystemLoader from jinja2.ext import Extension -from jinja2.nodes import CallBlock +from jinja2.nodes import CallBlock, Node +from jinja2.parser import Parser from markdown import Markdown - +from typing import TYPE_CHECKING, Any from . import __version__ from .misc import get_resource +from .views.base import View -if typing.TYPE_CHECKING: - from jinja2.nodes import Node - from jinja2.parser import Parser - from typing import Any +if TYPE_CHECKING: from .application import Application - from .views.base import View class Template(Environment): diff --git a/relay/views/api.py b/relay/views/api.py index d744761..86382f4 100644 --- a/relay/views/api.py +++ b/relay/views/api.py @@ -1,9 +1,7 @@ -from __future__ import annotations - -import typing - -from aiohttp import web +from aiohttp.web import Request, middleware from argon2.exceptions import VerifyMismatchError +from collections.abc import Awaitable, Callable, Sequence +from typing import Any from urllib.parse import urlparse from .base import View, register_route @@ -12,11 +10,6 @@ from .. import __version__ from ..database import ConfigData from ..misc import Message, Response, boolean, get_app -if typing.TYPE_CHECKING: - from aiohttp.web import Request - from collections.abc import Callable, Sequence - from typing import Any - ALLOWED_HEADERS = { 'accept', @@ -37,8 +30,10 @@ def check_api_path(method: str, path: str) -> bool: return path.startswith('/api') -@web.middleware -async def handle_api_path(request: Request, handler: Callable) -> Response: +@middleware +async def handle_api_path( + request: Request, + handler: Callable[[Request], Awaitable[Response]]) -> Response: try: if (token := request.cookies.get('user-token')): request['token'] = token @@ -222,7 +217,7 @@ class Inbox(View): if nodeinfo is not None: data['software'] = nodeinfo.sw_name - row = conn.put_inbox(**data) + row = conn.put_inbox(**data) # type: ignore[arg-type] return Response.new(row, ctype = 'json') @@ -237,7 +232,7 @@ class Inbox(View): if not (instance := conn.get_inbox(data['domain'])): return Response.new_error(404, 'Instance with domain not found', 'json') - instance = conn.put_inbox(instance['domain'], **data) + instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type] return Response.new(instance, ctype = 'json') diff --git a/relay/views/base.py b/relay/views/base.py index 93b3e3b..350016c 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -1,33 +1,33 @@ from __future__ import annotations -import typing - from Crypto.Random import get_random_bytes from aiohttp.abc import AbstractView from aiohttp.hdrs import METH_ALL as METHODS -from aiohttp.web import HTTPMethodNotAllowed +from aiohttp.web import HTTPMethodNotAllowed, Request from base64 import b64encode +from bsql import Database +from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from functools import cached_property from json.decoder import JSONDecodeError +from typing import TYPE_CHECKING, Any +from ..cache import Cache +from ..config import Config +from ..database import Connection +from ..http_client import HttpClient from ..misc import Response, get_app -if typing.TYPE_CHECKING: - from aiohttp.web import Request - from collections.abc import Callable, Generator, Sequence, Mapping - from bsql import Database - from typing import Any +if TYPE_CHECKING: from ..application import Application - from ..cache import Cache - from ..config import Config - from ..http_client import HttpClient from ..template import Template - try: - from typing import Self +try: + from typing import Self - except ImportError: - from typing_extensions import Self +except ImportError: + from typing_extensions import Self + +HandlerCallback = Callable[[Request], Awaitable[Response]] VIEWS: list[tuple[str, type[View]]] = [] @@ -37,7 +37,7 @@ def convert_data(data: Mapping[str, Any]) -> dict[str, str]: return {key: str(value) for key, value in data.items()} -def register_route(*paths: str) -> Callable: +def register_route(*paths: str) -> Callable[[type[View]], type[View]]: def wrapper(view: type[View]) -> type[View]: for path in paths: VIEWS.append((path, view)) @@ -63,7 +63,7 @@ class View(AbstractView): return await view.handlers[method](request, **kwargs) - async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response: + async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response: self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii') return await handler(self.request, **self.request.match_info, **kwargs) @@ -78,7 +78,7 @@ class View(AbstractView): @cached_property - def handlers(self) -> dict[str, Callable[..., Any]]: + def handlers(self) -> dict[str, HandlerCallback]: data = {} for method in METHODS: @@ -112,13 +112,13 @@ class View(AbstractView): @property - def database(self) -> Database: + def database(self) -> Database[Connection]: return self.app.database @property def template(self) -> Template: - return self.app['template'] + return self.app['template'] # type: ignore[no-any-return] async def get_api_data(self, diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 2b5bec0..5dfb43a 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -1,8 +1,6 @@ -from __future__ import annotations - -import typing - from aiohttp import web +from collections.abc import Awaitable, Callable +from typing import Any from .base import View, register_route @@ -10,11 +8,6 @@ from ..database import THEMES from ..logger import LogLevel from ..misc import Response, get_app -if typing.TYPE_CHECKING: - from aiohttp.web import Request - from collections.abc import Callable - from typing import Any - UNAUTH_ROUTES = { '/', @@ -23,7 +16,10 @@ UNAUTH_ROUTES = { @web.middleware -async def handle_frontend_path(request: web.Request, handler: Callable) -> Response: +async def handle_frontend_path( + request: web.Request, + handler: Callable[[web.Request], Awaitable[Response]]) -> Response: + app = get_app() if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): @@ -52,7 +48,7 @@ async def handle_frontend_path(request: web.Request, handler: Callable) -> Respo @register_route('/') class HomeView(View): - async def get(self, request: Request) -> Response: + async def get(self, request: web.Request) -> Response: with self.database.session() as conn: context: dict[str, Any] = { 'instances': tuple(conn.get_inboxes()) @@ -64,14 +60,14 @@ class HomeView(View): @register_route('/login') class Login(View): - async def get(self, request: Request) -> Response: + async def get(self, request: web.Request) -> Response: data = self.template.render('page/login.haml', self) return Response.new(data, ctype = 'html') @register_route('/logout') class Logout(View): - async def get(self, request: Request) -> Response: + async def get(self, request: web.Request) -> Response: with self.database.session(True) as conn: conn.del_token(request['token']) @@ -82,14 +78,14 @@ class Logout(View): @register_route('/admin') class Admin(View): - async def get(self, request: Request) -> Response: + async def get(self, request: web.Request) -> Response: return Response.new('', 302, {'Location': '/admin/instances'}) @register_route('/admin/instances') class AdminInstances(View): async def get(self, - request: Request, + request: web.Request, error: str | None = None, message: str | None = None) -> Response: @@ -112,7 +108,7 @@ class AdminInstances(View): @register_route('/admin/whitelist') class AdminWhitelist(View): async def get(self, - request: Request, + request: web.Request, error: str | None = None, message: str | None = None) -> Response: @@ -134,7 +130,7 @@ class AdminWhitelist(View): @register_route('/admin/domain_bans') class AdminDomainBans(View): async def get(self, - request: Request, + request: web.Request, error: str | None = None, message: str | None = None) -> Response: @@ -156,7 +152,7 @@ class AdminDomainBans(View): @register_route('/admin/software_bans') class AdminSoftwareBans(View): async def get(self, - request: Request, + request: web.Request, error: str | None = None, message: str | None = None) -> Response: @@ -178,7 +174,7 @@ class AdminSoftwareBans(View): @register_route('/admin/users') class AdminUsers(View): async def get(self, - request: Request, + request: web.Request, error: str | None = None, message: str | None = None) -> Response: @@ -199,7 +195,7 @@ class AdminUsers(View): @register_route('/admin/config') class AdminConfig(View): - async def get(self, request: Request, message: str | None = None) -> Response: + async def get(self, request: web.Request, message: str | None = None) -> Response: context: dict[str, Any] = { 'themes': tuple(THEMES.keys()), 'levels': tuple(level.name for level in LogLevel), @@ -212,7 +208,7 @@ class AdminConfig(View): @register_route('/manifest.json') class ManifestJson(View): - async def get(self, request: Request) -> Response: + async def get(self, request: web.Request) -> Response: with self.database.session(False) as conn: config = conn.get_config_all() theme = THEMES[config.theme] @@ -235,7 +231,7 @@ class ManifestJson(View): @register_route('/theme/{theme}.css') class ThemeCss(View): - async def get(self, request: Request, theme: str) -> Response: + async def get(self, request: web.Request, theme: str) -> Response: try: context: dict[str, Any] = { 'theme': THEMES[theme] diff --git a/relay/views/misc.py b/relay/views/misc.py index f10a877..5e2be52 100644 --- a/relay/views/misc.py +++ b/relay/views/misc.py @@ -1,9 +1,7 @@ -from __future__ import annotations - import aputils import subprocess -import typing +from aiohttp.web import Request from pathlib import Path from .base import View, register_route @@ -11,9 +9,6 @@ from .base import View, register_route from .. import __version__ from ..misc import Response -if typing.TYPE_CHECKING: - from aiohttp.web import Request - VERSION = __version__