fix linter issues

This commit is contained in:
Izalia Mae 2024-06-14 15:05:55 -04:00
parent a2b96d03dc
commit 15882f3e49
18 changed files with 186 additions and 241 deletions

30
dev.py
View file

@ -11,11 +11,11 @@ from datetime import datetime, timedelta
from pathlib import Path
from relay import __version__, logger as logging
from tempfile import TemporaryDirectory
from typing import Sequence
from typing import Any, Sequence
try:
from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
except ImportError:
class PatternMatchingEventHandler: # type: ignore
@ -70,7 +70,7 @@ def cli_lint(path: Path, watch: bool) -> None:
@cli.command('clean')
def cli_clean():
def cli_clean() -> None:
dirs = {
'dist',
'build',
@ -88,7 +88,7 @@ def cli_clean():
@cli.command('build')
def cli_build():
def cli_build() -> None:
with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [
@ -118,7 +118,7 @@ def cli_build():
@cli.command('run')
@click.option('--dev', '-d', is_flag = True)
def cli_run(dev: bool):
def cli_run(dev: bool) -> None:
print('Starting process watcher')
cmd = [sys.executable, '-m', 'relay', 'run']
@ -129,13 +129,13 @@ def cli_run(dev: bool):
handle_run_watcher(cmd)
def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
def handle_run_watcher(*commands: Sequence[str], wait: bool = False) -> None:
handler = WatchHandler(*commands, wait = wait)
handler.run_procs()
watcher = Observer()
watcher.schedule(handler, str(REPO), recursive=True)
watcher.start()
watcher.schedule(handler, str(REPO), recursive=True) # type: ignore
watcher.start() # type: ignore
try:
while True:
@ -145,7 +145,7 @@ def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
pass
handler.kill_procs()
watcher.stop()
watcher.stop() # type: ignore
watcher.join()
@ -153,16 +153,16 @@ class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py']
def __init__(self, *commands: Sequence[str], wait: bool = False):
PatternMatchingEventHandler.__init__(self)
def __init__(self, *commands: Sequence[str], wait: bool = False) -> None:
PatternMatchingEventHandler.__init__(self) # type: ignore
self.commands: Sequence[Sequence[str]] = commands
self.wait: bool = wait
self.procs: list[subprocess.Popen] = []
self.procs: list[subprocess.Popen[Any]] = []
self.last_restart: datetime = datetime.now()
def kill_procs(self):
def kill_procs(self) -> None:
for proc in self.procs:
if proc.poll() is not None:
continue
@ -183,7 +183,7 @@ class WatchHandler(PatternMatchingEventHandler):
logging.info('Process terminated')
def run_procs(self, restart: bool = False):
def run_procs(self, restart: bool = False) -> None:
if restart:
if datetime.now() - timedelta(seconds = 3) < self.last_restart:
return
@ -205,7 +205,7 @@ class WatchHandler(PatternMatchingEventHandler):
logging.info('Started processes with PIDs: %s', ', '.join(pids))
def on_any_event(self, event):
def on_any_event(self, event: FileSystemEvent) -> None:
if event.event_type not in ['modified', 'created', 'deleted']:
return

View file

@ -5,35 +5,31 @@ import multiprocessing
import signal
import time
import traceback
import typing
from aiohttp import web
from aiohttp.web import StaticResource
from aiohttp_swagger import setup_swagger
from aputils.signer import Signer
from bsql import Database, Row
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
from mimetypes import guess_type
from pathlib import Path
from queue import Empty
from threading import Event, Thread
from typing import Any
from . import logger as logging
from .cache import get_cache
from .cache import Cache, get_cache
from .config import Config
from .database import get_database
from .database import Connection, get_database
from .http_client import HttpClient
from .misc import IS_WINDOWS, check_open_port, get_resource
from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource
from .template import Template
from .views import VIEWS
from .views.api import handle_api_path
from .views.frontend import handle_frontend_path
if typing.TYPE_CHECKING:
from collections.abc import Callable
from bsql import Database, Row
from .cache import Cache
from .misc import Message, Response
def get_csp(request: web.Request) -> str:
data = [
@ -58,9 +54,9 @@ class Application(web.Application):
def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self,
middlewares = [
handle_api_path,
handle_frontend_path,
handle_response_headers
handle_api_path, # type: ignore[list-item]
handle_frontend_path, # type: ignore[list-item]
handle_response_headers # type: ignore[list-item]
]
)
@ -96,27 +92,27 @@ class Application(web.Application):
@property
def cache(self) -> Cache:
return self['cache']
return self['cache'] # type: ignore[no-any-return]
@property
def client(self) -> HttpClient:
return self['client']
return self['client'] # type: ignore[no-any-return]
@property
def config(self) -> Config:
return self['config']
return self['config'] # type: ignore[no-any-return]
@property
def database(self) -> Database:
return self['database']
def database(self) -> Database[Connection]:
return self['database'] # type: ignore[no-any-return]
@property
def signer(self) -> Signer:
return self['signer']
return self['signer'] # type: ignore[no-any-return]
@signer.setter
@ -130,7 +126,7 @@ class Application(web.Application):
@property
def template(self) -> Template:
return self['template']
return self['template'] # type: ignore[no-any-return]
@property
@ -185,11 +181,11 @@ class Application(web.Application):
pass
def stop(self, *_):
def stop(self, *_: Any) -> None:
self['running'] = False
async def handle_run(self):
async def handle_run(self) -> None:
self['running'] = True
self.set_signal_handler(True)
@ -295,7 +291,7 @@ class CacheCleanupThread(Thread):
class PushWorker(multiprocessing.Process):
def __init__(self, queue: multiprocessing.Queue):
def __init__(self, queue: multiprocessing.Queue[tuple[str, Message | bytes, Row]]) -> None:
if Application.DEFAULT is None:
raise RuntimeError('Application not setup yet')
@ -347,7 +343,10 @@ class PushWorker(multiprocessing.Process):
@web.middleware
async def handle_response_headers(request: web.Request, handler: Callable) -> Response:
async def handle_response_headers(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
resp = await handler(request)
resp.headers['Server'] = 'ActivityRelay'

View file

@ -2,28 +2,27 @@ from __future__ import annotations
import json
import os
import typing
from abc import ABC, abstractmethod
from bsql import Database
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone
from redis import Redis
from typing import TYPE_CHECKING, Any
from .database import get_database
from .database import Connection, get_database
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from bsql import Database
from collections.abc import Callable, Iterator
from typing import Any
if TYPE_CHECKING:
from .application import Application
# todo: implement more caching backends
SerializerCallback = Callable[[Any], str]
DeserializerCallback = Callable[[str], Any]
BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
@ -61,13 +60,13 @@ class Item:
updated: datetime
def __post_init__(self):
if isinstance(self.updated, str):
self.updated = datetime.fromisoformat(self.updated)
def __post_init__(self) -> None:
if isinstance(self.updated, str): # type: ignore[unreachable]
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
@classmethod
def from_data(cls: type[Item], *args) -> Item:
def from_data(cls: type[Item], *args: Any) -> Item:
data = cls(*args)
data.value = deserialize_value(data.value, data.value_type)
@ -159,7 +158,7 @@ class SqlCache(Cache):
def __init__(self, app: Application):
Cache.__init__(self, app)
self._db: Database | None = None
self._db: Database[Connection] | None = None
def get(self, namespace: str, key: str) -> Item:
@ -211,10 +210,10 @@ class SqlCache(Cache):
}
with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as conn:
row = conn.one()
row.pop('id', None)
return Item.from_data(*tuple(row.values()))
with conn.run('set-cache-item', params) as cur:
row = cur.one()
row.pop('id', None) # type: ignore[union-attr]
return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
def delete(self, namespace: str, key: str) -> None:
@ -381,5 +380,5 @@ class RedisCache(Cache):
if not self._rd:
return
self._rd.close()
self._rd.close() # type: ignore
self._rd = None # type: ignore

View file

@ -1,21 +1,16 @@
from __future__ import annotations
import json
import os
import typing
import yaml
from functools import cached_property
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
from .misc import boolean
if typing.TYPE_CHECKING:
from typing import Any
class RelayConfig(dict):
class RelayConfig(dict[str, Any]):
def __init__(self, path: str):
dict.__init__(self, {})
@ -122,7 +117,7 @@ class RelayConfig(dict):
self[key] = value
class RelayDatabase(dict):
class RelayDatabase(dict[str, Any]):
def __init__(self, config: RelayConfig):
dict.__init__(self, {
'relay-list': {},

View file

@ -1,20 +1,15 @@
from __future__ import annotations
import getpass
import os
import platform
import typing
import yaml
from dataclasses import asdict, dataclass, fields
from pathlib import Path
from platformdirs import user_config_dir
from typing import Any
from .misc import IS_DOCKER
if typing.TYPE_CHECKING:
from typing import Any
try:
from typing import Self

View file

@ -1,20 +1,15 @@
from __future__ import annotations
import bsql
import typing
from bsql import Database
from .config import THEMES, ConfigData
from .connection import RELAY_SOFTWARE, Connection
from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging
from ..config import Config
from ..misc import get_resource
if typing.TYPE_CHECKING:
from ..config import Config
def get_database(config: Config, migrate: bool = True) -> bsql.Database:
def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
options = {
'connection_class': Connection,
'pool_size': 5,
@ -22,10 +17,10 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database:
}
if config.db_type == 'sqlite':
db = bsql.Database.sqlite(config.sqlite_path, **options)
db = Database.sqlite(config.sqlite_path, **options)
elif config.db_type == 'postgres':
db = bsql.Database.postgresql(
db = Database.postgresql(
config.pg_name,
config.pg_host,
config.pg_port,

View file

@ -1,17 +1,15 @@
from __future__ import annotations
# removing the above line turns annotations into types instead of str objects which messes with
# `Field.type`
import typing
from bsql import Row
from collections.abc import Callable, Sequence
from dataclasses import Field, asdict, dataclass, fields
from typing import Any
from .. import logger as logging
from ..misc import boolean
if typing.TYPE_CHECKING:
from bsql import Row
from collections.abc import Callable, Sequence
from typing import Any
try:
from typing import Self
@ -120,7 +118,7 @@ class ConfigData:
@classmethod
def FIELD(cls: type[Self], key: str) -> Field:
def FIELD(cls: type[Self], key: str) -> Field[Any]:
for field in fields(cls):
if field.name == key.replace('-', '_'):
return field

View file

@ -1,10 +1,10 @@
from __future__ import annotations
import typing
from argon2 import PasswordHasher
from bsql import Connection as SqlConnection, Update
from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator, Sequence
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from uuid import uuid4
@ -14,15 +14,10 @@ from .config import (
)
from .. import logger as logging
from ..misc import boolean, get_app
from ..misc import Message, boolean, get_app
if typing.TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from bsql import Row
from typing import Any
if TYPE_CHECKING:
from ..application import Application
from ..misc import Message
RELAY_SOFTWARE = [
'activityrelay', # https://git.pleroma.social/pleroma/relay
@ -94,7 +89,7 @@ class Connection(SqlConnection):
params = {
'key': key,
'value': data.get(key, serialize = True),
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore
}
with self.run('put-config', params):

View file

@ -1,17 +1,11 @@
from __future__ import annotations
import typing
from bsql import Column, Table, Tables
from collections.abc import Callable
from .config import ConfigData
if typing.TYPE_CHECKING:
from collections.abc import Callable
from .connection import Connection
VERSIONS: dict[int, Callable] = {}
VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES: Tables = Tables(
Table(
'config',
@ -64,7 +58,7 @@ TABLES: Tables = Tables(
)
def migration(func: Callable) -> Callable:
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
ver = int(func.__name__.replace('migrate_', ''))
VERSIONS[ver] = func
return func

View file

@ -2,26 +2,23 @@ from __future__ import annotations
import json
import traceback
import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils import AlgorithmType, Nodeinfo, ObjectType, WellKnownNodeinfo
from blib import JsonBase
from bsql import Row
from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import urlparse
from . import __version__
from . import logger as logging
from . import __version__, logger as logging
from .cache import Cache
from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from aputils import Signer
from bsql import Row
from typing import Any
if TYPE_CHECKING:
from .application import Application
from .cache import Cache
SUPPORTS_HS2019 = {
@ -39,7 +36,7 @@ SUPPORTS_HS2019 = {
'sharkey'
}
T = typing.TypeVar('T', bound = JsonBase)
T = TypeVar('T', bound = JsonBase)
HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}'
@ -124,7 +121,7 @@ class HttpClient:
if not force:
try:
if not (item := self.cache.get('request', url)).older_than(48):
return json.loads(item.value)
return json.loads(item.value) # type: ignore[no-any-return]
except KeyError:
logging.verbose('No cached data for url: %s', url)
@ -153,7 +150,7 @@ class HttpClient:
self.cache.set('request', url, data, 'str')
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
return json.loads(data)
return json.loads(data) # type: ignore [no-any-return]
except JSONDecodeError:
logging.verbose('Failed to parse JSON')

View file

@ -1,15 +1,9 @@
from __future__ import annotations
import logging
import os
import typing
from enum import IntEnum
from pathlib import Path
if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from typing import Any, Protocol
try:
from typing import Self
@ -18,6 +12,10 @@ if typing.TYPE_CHECKING:
from typing_extensions import Self
class LoggingMethod(Protocol):
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
class LogLevel(IntEnum):
DEBUG = logging.DEBUG
VERBOSE = 15
@ -75,11 +73,11 @@ def verbose(message: str, *args: Any, **kwargs: Any) -> None:
logging.log(LogLevel.VERBOSE, message, *args, **kwargs)
debug: Callable = logging.debug
info: Callable = logging.info
warning: Callable = logging.warning
error: Callable = logging.error
critical: Callable = logging.critical
debug: LoggingMethod = logging.debug
info: LoggingMethod = logging.info
warning: LoggingMethod = logging.warning
error: LoggingMethod = logging.error
critical: LoggingMethod = logging.critical
try:

View file

@ -5,10 +5,11 @@ import asyncio
import click
import json
import os
import typing
from bsql import Row
from pathlib import Path
from shutil import copyfile
from typing import Any
from urllib.parse import urlparse
from . import __version__
@ -19,10 +20,6 @@ from .compat import RelayConfig, RelayDatabase
from .database import RELAY_SOFTWARE, get_database
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
if typing.TYPE_CHECKING:
from bsql import Row
from typing import Any
def check_alphanumeric(text: str) -> str:
if not text.isalnum():

View file

@ -5,11 +5,12 @@ import json
import os
import platform
import socket
import typing
from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4
try:
@ -18,21 +19,20 @@ try:
except ImportError:
from importlib_resources import files as pkgfiles # type: ignore
if typing.TYPE_CHECKING:
from typing import Any
from .application import Application
try:
from typing import Self
except ImportError:
from typing_extensions import Self
if TYPE_CHECKING:
from .application import Application
T = typing.TypeVar('T')
ResponseType = typing.TypedDict('ResponseType', {
T = TypeVar('T')
ResponseType = TypedDict('ResponseType', {
'status': int,
'headers': dict[str, typing.Any] | None,
'headers': dict[str, Any] | None,
'content_type': str,
'body': bytes | None,
'text': str | None
@ -128,7 +128,7 @@ class JsonEncoder(json.JSONEncoder):
if isinstance(o, datetime):
return o.isoformat()
return json.JSONEncoder.default(self, o)
return json.JSONEncoder.default(self, o) # type: ignore[no-any-return]
class Message(aputils.Message):
@ -214,7 +214,7 @@ class Response(AiohttpResponse):
@classmethod
def new(cls: type[Self],
body: str | bytes | dict | tuple | list | set = '',
body: str | bytes | dict[str, Any] | Sequence[Any] = '',
status: int = 200,
headers: dict[str, str] | None = None,
ctype: str = 'text') -> Self:
@ -227,22 +227,22 @@ class Response(AiohttpResponse):
'text': None
}
if isinstance(body, bytes):
if isinstance(body, str):
kwargs['text'] = body
elif isinstance(body, bytes):
kwargs['body'] = body
elif isinstance(body, (dict, list, tuple, set)) or ctype in {'json', 'activity'}:
elif isinstance(body, (dict, Sequence)):
kwargs['text'] = json.dumps(body, cls = JsonEncoder)
else:
kwargs['text'] = body
return cls(**kwargs)
@classmethod
def new_error(cls: type[Self],
status: int,
body: str | bytes | dict,
body: str | bytes | dict[str, Any],
ctype: str = 'text') -> Self:
if ctype == 'json':

View file

@ -1,26 +1,23 @@
from __future__ import annotations
import textwrap
import typing
from collections.abc import Callable
from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension
from jinja2.nodes import CallBlock
from jinja2.nodes import CallBlock, Node
from jinja2.parser import Parser
from markdown import Markdown
from typing import TYPE_CHECKING, Any
from . import __version__
from .misc import get_resource
if typing.TYPE_CHECKING:
from jinja2.nodes import Node
from jinja2.parser import Parser
from typing import Any
from .application import Application
from .views.base import View
if TYPE_CHECKING:
from .application import Application
class Template(Environment):
def __init__(self, app: Application):

View file

@ -1,9 +1,7 @@
from __future__ import annotations
import typing
from aiohttp import web
from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from urllib.parse import urlparse
from .base import View, register_route
@ -12,11 +10,6 @@ from .. import __version__
from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable, Sequence
from typing import Any
ALLOWED_HEADERS = {
'accept',
@ -37,8 +30,10 @@ def check_api_path(method: str, path: str) -> bool:
return path.startswith('/api')
@web.middleware
async def handle_api_path(request: Request, handler: Callable) -> Response:
@middleware
async def handle_api_path(
request: Request,
handler: Callable[[Request], Awaitable[Response]]) -> Response:
try:
if (token := request.cookies.get('user-token')):
request['token'] = token
@ -222,7 +217,7 @@ class Inbox(View):
if nodeinfo is not None:
data['software'] = nodeinfo.sw_name
row = conn.put_inbox(**data)
row = conn.put_inbox(**data) # type: ignore[arg-type]
return Response.new(row, ctype = 'json')
@ -237,7 +232,7 @@ class Inbox(View):
if not (instance := conn.get_inbox(data['domain'])):
return Response.new_error(404, 'Instance with domain not found', 'json')
instance = conn.put_inbox(instance['domain'], **data)
instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type]
return Response.new(instance, ctype = 'json')

View file

@ -1,26 +1,24 @@
from __future__ import annotations
import typing
from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import HTTPMethodNotAllowed
from aiohttp.web import HTTPMethodNotAllowed, Request
from base64 import b64encode
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property
from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any
from ..misc import Response, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable, Generator, Sequence, Mapping
from bsql import Database
from typing import Any
from ..application import Application
from ..cache import Cache
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
from ..misc import Response, get_app
if TYPE_CHECKING:
from ..application import Application
from ..template import Template
try:
@ -29,6 +27,8 @@ if typing.TYPE_CHECKING:
except ImportError:
from typing_extensions import Self
HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = []
@ -37,7 +37,7 @@ def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
return {key: str(value) for key, value in data.items()}
def register_route(*paths: str) -> Callable:
def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
def wrapper(view: type[View]) -> type[View]:
for path in paths:
VIEWS.append((path, view))
@ -63,7 +63,7 @@ class View(AbstractView):
return await view.handlers[method](request, **kwargs)
async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response:
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs)
@ -78,7 +78,7 @@ class View(AbstractView):
@cached_property
def handlers(self) -> dict[str, Callable[..., Any]]:
def handlers(self) -> dict[str, HandlerCallback]:
data = {}
for method in METHODS:
@ -112,13 +112,13 @@ class View(AbstractView):
@property
def database(self) -> Database:
def database(self) -> Database[Connection]:
return self.app.database
@property
def template(self) -> Template:
return self.app['template']
return self.app['template'] # type: ignore[no-any-return]
async def get_api_data(self,

View file

@ -1,8 +1,6 @@
from __future__ import annotations
import typing
from aiohttp import web
from collections.abc import Awaitable, Callable
from typing import Any
from .base import View, register_route
@ -10,11 +8,6 @@ from ..database import THEMES
from ..logger import LogLevel
from ..misc import Response, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable
from typing import Any
UNAUTH_ROUTES = {
'/',
@ -23,7 +16,10 @@ UNAUTH_ROUTES = {
@web.middleware
async def handle_frontend_path(request: web.Request, handler: Callable) -> Response:
async def handle_frontend_path(
request: web.Request,
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
app = get_app()
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
@ -52,7 +48,7 @@ async def handle_frontend_path(request: web.Request, handler: Callable) -> Respo
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
async def get(self, request: web.Request) -> Response:
with self.database.session() as conn:
context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes())
@ -64,14 +60,14 @@ class HomeView(View):
@register_route('/login')
class Login(View):
async def get(self, request: Request) -> Response:
async def get(self, request: web.Request) -> Response:
data = self.template.render('page/login.haml', self)
return Response.new(data, ctype = 'html')
@register_route('/logout')
class Logout(View):
async def get(self, request: Request) -> Response:
async def get(self, request: web.Request) -> Response:
with self.database.session(True) as conn:
conn.del_token(request['token'])
@ -82,14 +78,14 @@ class Logout(View):
@register_route('/admin')
class Admin(View):
async def get(self, request: Request) -> Response:
async def get(self, request: web.Request) -> Response:
return Response.new('', 302, {'Location': '/admin/instances'})
@register_route('/admin/instances')
class AdminInstances(View):
async def get(self,
request: Request,
request: web.Request,
error: str | None = None,
message: str | None = None) -> Response:
@ -112,7 +108,7 @@ class AdminInstances(View):
@register_route('/admin/whitelist')
class AdminWhitelist(View):
async def get(self,
request: Request,
request: web.Request,
error: str | None = None,
message: str | None = None) -> Response:
@ -134,7 +130,7 @@ class AdminWhitelist(View):
@register_route('/admin/domain_bans')
class AdminDomainBans(View):
async def get(self,
request: Request,
request: web.Request,
error: str | None = None,
message: str | None = None) -> Response:
@ -156,7 +152,7 @@ class AdminDomainBans(View):
@register_route('/admin/software_bans')
class AdminSoftwareBans(View):
async def get(self,
request: Request,
request: web.Request,
error: str | None = None,
message: str | None = None) -> Response:
@ -178,7 +174,7 @@ class AdminSoftwareBans(View):
@register_route('/admin/users')
class AdminUsers(View):
async def get(self,
request: Request,
request: web.Request,
error: str | None = None,
message: str | None = None) -> Response:
@ -199,7 +195,7 @@ class AdminUsers(View):
@register_route('/admin/config')
class AdminConfig(View):
async def get(self, request: Request, message: str | None = None) -> Response:
async def get(self, request: web.Request, message: str | None = None) -> Response:
context: dict[str, Any] = {
'themes': tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel),
@ -212,7 +208,7 @@ class AdminConfig(View):
@register_route('/manifest.json')
class ManifestJson(View):
async def get(self, request: Request) -> Response:
async def get(self, request: web.Request) -> Response:
with self.database.session(False) as conn:
config = conn.get_config_all()
theme = THEMES[config.theme]
@ -235,7 +231,7 @@ class ManifestJson(View):
@register_route('/theme/{theme}.css')
class ThemeCss(View):
async def get(self, request: Request, theme: str) -> Response:
async def get(self, request: web.Request, theme: str) -> Response:
try:
context: dict[str, Any] = {
'theme': THEMES[theme]

View file

@ -1,9 +1,7 @@
from __future__ import annotations
import aputils
import subprocess
import typing
from aiohttp.web import Request
from pathlib import Path
from .base import View, register_route
@ -11,9 +9,6 @@ from .base import View, register_route
from .. import __version__
from ..misc import Response
if typing.TYPE_CHECKING:
from aiohttp.web import Request
VERSION = __version__