mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-22 22:48:00 +00:00
fix linter issues
This commit is contained in:
parent
a2b96d03dc
commit
15882f3e49
30
dev.py
30
dev.py
|
@ -11,11 +11,11 @@ from datetime import datetime, timedelta
|
|||
from pathlib import Path
|
||||
from relay import __version__, logger as logging
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Sequence
|
||||
from typing import Any, Sequence
|
||||
|
||||
try:
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import PatternMatchingEventHandler
|
||||
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
|
||||
|
||||
except ImportError:
|
||||
class PatternMatchingEventHandler: # type: ignore
|
||||
|
@ -70,7 +70,7 @@ def cli_lint(path: Path, watch: bool) -> None:
|
|||
|
||||
|
||||
@cli.command('clean')
|
||||
def cli_clean():
|
||||
def cli_clean() -> None:
|
||||
dirs = {
|
||||
'dist',
|
||||
'build',
|
||||
|
@ -88,7 +88,7 @@ def cli_clean():
|
|||
|
||||
|
||||
@cli.command('build')
|
||||
def cli_build():
|
||||
def cli_build() -> None:
|
||||
with TemporaryDirectory() as tmp:
|
||||
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
|
||||
cmd = [
|
||||
|
@ -118,7 +118,7 @@ def cli_build():
|
|||
|
||||
@cli.command('run')
|
||||
@click.option('--dev', '-d', is_flag = True)
|
||||
def cli_run(dev: bool):
|
||||
def cli_run(dev: bool) -> None:
|
||||
print('Starting process watcher')
|
||||
|
||||
cmd = [sys.executable, '-m', 'relay', 'run']
|
||||
|
@ -129,13 +129,13 @@ def cli_run(dev: bool):
|
|||
handle_run_watcher(cmd)
|
||||
|
||||
|
||||
def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
|
||||
def handle_run_watcher(*commands: Sequence[str], wait: bool = False) -> None:
|
||||
handler = WatchHandler(*commands, wait = wait)
|
||||
handler.run_procs()
|
||||
|
||||
watcher = Observer()
|
||||
watcher.schedule(handler, str(REPO), recursive=True)
|
||||
watcher.start()
|
||||
watcher.schedule(handler, str(REPO), recursive=True) # type: ignore
|
||||
watcher.start() # type: ignore
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
@ -145,7 +145,7 @@ def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
|
|||
pass
|
||||
|
||||
handler.kill_procs()
|
||||
watcher.stop()
|
||||
watcher.stop() # type: ignore
|
||||
watcher.join()
|
||||
|
||||
|
||||
|
@ -153,16 +153,16 @@ class WatchHandler(PatternMatchingEventHandler):
|
|||
patterns = ['*.py']
|
||||
|
||||
|
||||
def __init__(self, *commands: Sequence[str], wait: bool = False):
|
||||
PatternMatchingEventHandler.__init__(self)
|
||||
def __init__(self, *commands: Sequence[str], wait: bool = False) -> None:
|
||||
PatternMatchingEventHandler.__init__(self) # type: ignore
|
||||
|
||||
self.commands: Sequence[Sequence[str]] = commands
|
||||
self.wait: bool = wait
|
||||
self.procs: list[subprocess.Popen] = []
|
||||
self.procs: list[subprocess.Popen[Any]] = []
|
||||
self.last_restart: datetime = datetime.now()
|
||||
|
||||
|
||||
def kill_procs(self):
|
||||
def kill_procs(self) -> None:
|
||||
for proc in self.procs:
|
||||
if proc.poll() is not None:
|
||||
continue
|
||||
|
@ -183,7 +183,7 @@ class WatchHandler(PatternMatchingEventHandler):
|
|||
logging.info('Process terminated')
|
||||
|
||||
|
||||
def run_procs(self, restart: bool = False):
|
||||
def run_procs(self, restart: bool = False) -> None:
|
||||
if restart:
|
||||
if datetime.now() - timedelta(seconds = 3) < self.last_restart:
|
||||
return
|
||||
|
@ -205,7 +205,7 @@ class WatchHandler(PatternMatchingEventHandler):
|
|||
logging.info('Started processes with PIDs: %s', ', '.join(pids))
|
||||
|
||||
|
||||
def on_any_event(self, event):
|
||||
def on_any_event(self, event: FileSystemEvent) -> None:
|
||||
if event.event_type not in ['modified', 'created', 'deleted']:
|
||||
return
|
||||
|
||||
|
|
|
@ -5,35 +5,31 @@ import multiprocessing
|
|||
import signal
|
||||
import time
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web import StaticResource
|
||||
from aiohttp_swagger import setup_swagger
|
||||
from aputils.signer import Signer
|
||||
from bsql import Database, Row
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from mimetypes import guess_type
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Event, Thread
|
||||
from typing import Any
|
||||
|
||||
from . import logger as logging
|
||||
from .cache import get_cache
|
||||
from .cache import Cache, get_cache
|
||||
from .config import Config
|
||||
from .database import get_database
|
||||
from .database import Connection, get_database
|
||||
from .http_client import HttpClient
|
||||
from .misc import IS_WINDOWS, check_open_port, get_resource
|
||||
from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource
|
||||
from .template import Template
|
||||
from .views import VIEWS
|
||||
from .views.api import handle_api_path
|
||||
from .views.frontend import handle_frontend_path
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from bsql import Database, Row
|
||||
from .cache import Cache
|
||||
from .misc import Message, Response
|
||||
|
||||
|
||||
def get_csp(request: web.Request) -> str:
|
||||
data = [
|
||||
|
@ -58,9 +54,9 @@ class Application(web.Application):
|
|||
def __init__(self, cfgpath: Path | None, dev: bool = False):
|
||||
web.Application.__init__(self,
|
||||
middlewares = [
|
||||
handle_api_path,
|
||||
handle_frontend_path,
|
||||
handle_response_headers
|
||||
handle_api_path, # type: ignore[list-item]
|
||||
handle_frontend_path, # type: ignore[list-item]
|
||||
handle_response_headers # type: ignore[list-item]
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -96,27 +92,27 @@ class Application(web.Application):
|
|||
|
||||
@property
|
||||
def cache(self) -> Cache:
|
||||
return self['cache']
|
||||
return self['cache'] # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@property
|
||||
def client(self) -> HttpClient:
|
||||
return self['client']
|
||||
return self['client'] # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
return self['config']
|
||||
return self['config'] # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@property
|
||||
def database(self) -> Database:
|
||||
return self['database']
|
||||
def database(self) -> Database[Connection]:
|
||||
return self['database'] # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@property
|
||||
def signer(self) -> Signer:
|
||||
return self['signer']
|
||||
return self['signer'] # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@signer.setter
|
||||
|
@ -130,7 +126,7 @@ class Application(web.Application):
|
|||
|
||||
@property
|
||||
def template(self) -> Template:
|
||||
return self['template']
|
||||
return self['template'] # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@property
|
||||
|
@ -185,11 +181,11 @@ class Application(web.Application):
|
|||
pass
|
||||
|
||||
|
||||
def stop(self, *_):
|
||||
def stop(self, *_: Any) -> None:
|
||||
self['running'] = False
|
||||
|
||||
|
||||
async def handle_run(self):
|
||||
async def handle_run(self) -> None:
|
||||
self['running'] = True
|
||||
|
||||
self.set_signal_handler(True)
|
||||
|
@ -295,7 +291,7 @@ class CacheCleanupThread(Thread):
|
|||
|
||||
|
||||
class PushWorker(multiprocessing.Process):
|
||||
def __init__(self, queue: multiprocessing.Queue):
|
||||
def __init__(self, queue: multiprocessing.Queue[tuple[str, Message | bytes, Row]]) -> None:
|
||||
if Application.DEFAULT is None:
|
||||
raise RuntimeError('Application not setup yet')
|
||||
|
||||
|
@ -347,7 +343,10 @@ class PushWorker(multiprocessing.Process):
|
|||
|
||||
|
||||
@web.middleware
|
||||
async def handle_response_headers(request: web.Request, handler: Callable) -> Response:
|
||||
async def handle_response_headers(
|
||||
request: web.Request,
|
||||
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
|
||||
|
||||
resp = await handler(request)
|
||||
resp.headers['Server'] = 'ActivityRelay'
|
||||
|
||||
|
|
|
@ -2,28 +2,27 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from bsql import Database
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from redis import Redis
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .database import get_database
|
||||
from .database import Connection, get_database
|
||||
from .misc import Message, boolean
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from bsql import Database
|
||||
from collections.abc import Callable, Iterator
|
||||
from typing import Any
|
||||
if TYPE_CHECKING:
|
||||
from .application import Application
|
||||
|
||||
|
||||
# todo: implement more caching backends
|
||||
|
||||
SerializerCallback = Callable[[Any], str]
|
||||
DeserializerCallback = Callable[[str], Any]
|
||||
|
||||
BACKENDS: dict[str, type[Cache]] = {}
|
||||
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
|
||||
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
|
||||
'str': (str, str),
|
||||
'int': (str, int),
|
||||
'bool': (str, boolean),
|
||||
|
@ -61,13 +60,13 @@ class Item:
|
|||
updated: datetime
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.updated, str):
|
||||
self.updated = datetime.fromisoformat(self.updated)
|
||||
def __post_init__(self) -> None:
|
||||
if isinstance(self.updated, str): # type: ignore[unreachable]
|
||||
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_data(cls: type[Item], *args) -> Item:
|
||||
def from_data(cls: type[Item], *args: Any) -> Item:
|
||||
data = cls(*args)
|
||||
data.value = deserialize_value(data.value, data.value_type)
|
||||
|
||||
|
@ -159,7 +158,7 @@ class SqlCache(Cache):
|
|||
|
||||
def __init__(self, app: Application):
|
||||
Cache.__init__(self, app)
|
||||
self._db: Database | None = None
|
||||
self._db: Database[Connection] | None = None
|
||||
|
||||
|
||||
def get(self, namespace: str, key: str) -> Item:
|
||||
|
@ -211,10 +210,10 @@ class SqlCache(Cache):
|
|||
}
|
||||
|
||||
with self._db.session(True) as conn:
|
||||
with conn.run('set-cache-item', params) as conn:
|
||||
row = conn.one()
|
||||
row.pop('id', None)
|
||||
return Item.from_data(*tuple(row.values()))
|
||||
with conn.run('set-cache-item', params) as cur:
|
||||
row = cur.one()
|
||||
row.pop('id', None) # type: ignore[union-attr]
|
||||
return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
|
||||
|
||||
|
||||
def delete(self, namespace: str, key: str) -> None:
|
||||
|
@ -381,5 +380,5 @@ class RedisCache(Cache):
|
|||
if not self._rd:
|
||||
return
|
||||
|
||||
self._rd.close()
|
||||
self._rd.close() # type: ignore
|
||||
self._rd = None # type: ignore
|
||||
|
|
|
@ -1,21 +1,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
import yaml
|
||||
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .misc import boolean
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
class RelayConfig(dict):
|
||||
class RelayConfig(dict[str, Any]):
|
||||
def __init__(self, path: str):
|
||||
dict.__init__(self, {})
|
||||
|
||||
|
@ -122,7 +117,7 @@ class RelayConfig(dict):
|
|||
self[key] = value
|
||||
|
||||
|
||||
class RelayDatabase(dict):
|
||||
class RelayDatabase(dict[str, Any]):
|
||||
def __init__(self, config: RelayConfig):
|
||||
dict.__init__(self, {
|
||||
'relay-list': {},
|
||||
|
|
|
@ -1,25 +1,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import os
|
||||
import platform
|
||||
import typing
|
||||
import yaml
|
||||
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from pathlib import Path
|
||||
from platformdirs import user_config_dir
|
||||
from typing import Any
|
||||
|
||||
from .misc import IS_DOCKER
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Any
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
|
|
|
@ -1,20 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import bsql
|
||||
import typing
|
||||
from bsql import Database
|
||||
|
||||
from .config import THEMES, ConfigData
|
||||
from .connection import RELAY_SOFTWARE, Connection
|
||||
from .schema import TABLES, VERSIONS, migrate_0
|
||||
|
||||
from .. import logger as logging
|
||||
from ..config import Config
|
||||
from ..misc import get_resource
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..config import Config
|
||||
|
||||
|
||||
def get_database(config: Config, migrate: bool = True) -> bsql.Database:
|
||||
def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
|
||||
options = {
|
||||
'connection_class': Connection,
|
||||
'pool_size': 5,
|
||||
|
@ -22,10 +17,10 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database:
|
|||
}
|
||||
|
||||
if config.db_type == 'sqlite':
|
||||
db = bsql.Database.sqlite(config.sqlite_path, **options)
|
||||
db = Database.sqlite(config.sqlite_path, **options)
|
||||
|
||||
elif config.db_type == 'postgres':
|
||||
db = bsql.Database.postgresql(
|
||||
db = Database.postgresql(
|
||||
config.pg_name,
|
||||
config.pg_host,
|
||||
config.pg_port,
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
from __future__ import annotations
|
||||
# removing the above line turns annotations into types instead of str objects which messes with
|
||||
# `Field.type`
|
||||
|
||||
import typing
|
||||
|
||||
from bsql import Row
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import Field, asdict, dataclass, fields
|
||||
from typing import Any
|
||||
|
||||
from .. import logger as logging
|
||||
from ..misc import boolean
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from bsql import Row
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
THEMES = {
|
||||
|
@ -120,7 +118,7 @@ class ConfigData:
|
|||
|
||||
|
||||
@classmethod
|
||||
def FIELD(cls: type[Self], key: str) -> Field:
|
||||
def FIELD(cls: type[Self], key: str) -> Field[Any]:
|
||||
for field in fields(cls):
|
||||
if field.name == key.replace('-', '_'):
|
||||
return field
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from argon2 import PasswordHasher
|
||||
from bsql import Connection as SqlConnection, Update
|
||||
from bsql import Connection as SqlConnection, Row, Update
|
||||
from collections.abc import Iterator, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
|
@ -14,15 +14,10 @@ from .config import (
|
|||
)
|
||||
|
||||
from .. import logger as logging
|
||||
from ..misc import boolean, get_app
|
||||
from ..misc import Message, boolean, get_app
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
from bsql import Row
|
||||
from typing import Any
|
||||
if TYPE_CHECKING:
|
||||
from ..application import Application
|
||||
from ..misc import Message
|
||||
|
||||
|
||||
RELAY_SOFTWARE = [
|
||||
'activityrelay', # https://git.pleroma.social/pleroma/relay
|
||||
|
@ -94,7 +89,7 @@ class Connection(SqlConnection):
|
|||
params = {
|
||||
'key': key,
|
||||
'value': data.get(key, serialize = True),
|
||||
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type
|
||||
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type # type: ignore
|
||||
}
|
||||
|
||||
with self.run('put-config', params):
|
||||
|
|
|
@ -1,17 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from bsql import Column, Table, Tables
|
||||
from collections.abc import Callable
|
||||
|
||||
from .config import ConfigData
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from .connection import Connection
|
||||
from .connection import Connection
|
||||
|
||||
|
||||
VERSIONS: dict[int, Callable] = {}
|
||||
VERSIONS: dict[int, Callable[[Connection], None]] = {}
|
||||
TABLES: Tables = Tables(
|
||||
Table(
|
||||
'config',
|
||||
|
@ -64,7 +58,7 @@ TABLES: Tables = Tables(
|
|||
)
|
||||
|
||||
|
||||
def migration(func: Callable) -> Callable:
|
||||
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
|
||||
ver = int(func.__name__.replace('migrate_', ''))
|
||||
VERSIONS[ver] = func
|
||||
return func
|
||||
|
|
|
@ -2,26 +2,23 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
|
||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||
from aputils import AlgorithmType, Nodeinfo, ObjectType, WellKnownNodeinfo
|
||||
from blib import JsonBase
|
||||
from bsql import Row
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import __version__
|
||||
from . import logger as logging
|
||||
from . import __version__, logger as logging
|
||||
from .cache import Cache
|
||||
from .misc import MIMETYPES, Message, get_app
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from aputils import Signer
|
||||
from bsql import Row
|
||||
from typing import Any
|
||||
if TYPE_CHECKING:
|
||||
from .application import Application
|
||||
from .cache import Cache
|
||||
|
||||
|
||||
SUPPORTS_HS2019 = {
|
||||
|
@ -39,7 +36,7 @@ SUPPORTS_HS2019 = {
|
|||
'sharkey'
|
||||
}
|
||||
|
||||
T = typing.TypeVar('T', bound = JsonBase)
|
||||
T = TypeVar('T', bound = JsonBase)
|
||||
HEADERS = {
|
||||
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
|
||||
'User-Agent': f'ActivityRelay/{__version__}'
|
||||
|
@ -124,7 +121,7 @@ class HttpClient:
|
|||
if not force:
|
||||
try:
|
||||
if not (item := self.cache.get('request', url)).older_than(48):
|
||||
return json.loads(item.value)
|
||||
return json.loads(item.value) # type: ignore[no-any-return]
|
||||
|
||||
except KeyError:
|
||||
logging.verbose('No cached data for url: %s', url)
|
||||
|
@ -153,7 +150,7 @@ class HttpClient:
|
|||
self.cache.set('request', url, data, 'str')
|
||||
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
|
||||
|
||||
return json.loads(data)
|
||||
return json.loads(data) # type: ignore [no-any-return]
|
||||
|
||||
except JSONDecodeError:
|
||||
logging.verbose('Failed to parse JSON')
|
||||
|
|
|
@ -1,21 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import typing
|
||||
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
class LoggingMethod(Protocol):
|
||||
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
|
@ -75,11 +73,11 @@ def verbose(message: str, *args: Any, **kwargs: Any) -> None:
|
|||
logging.log(LogLevel.VERBOSE, message, *args, **kwargs)
|
||||
|
||||
|
||||
debug: Callable = logging.debug
|
||||
info: Callable = logging.info
|
||||
warning: Callable = logging.warning
|
||||
error: Callable = logging.error
|
||||
critical: Callable = logging.critical
|
||||
debug: LoggingMethod = logging.debug
|
||||
info: LoggingMethod = logging.info
|
||||
warning: LoggingMethod = logging.warning
|
||||
error: LoggingMethod = logging.error
|
||||
critical: LoggingMethod = logging.critical
|
||||
|
||||
|
||||
try:
|
||||
|
|
|
@ -5,10 +5,11 @@ import asyncio
|
|||
import click
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
|
||||
from bsql import Row
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import __version__
|
||||
|
@ -19,10 +20,6 @@ from .compat import RelayConfig, RelayDatabase
|
|||
from .database import RELAY_SOFTWARE, get_database
|
||||
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from bsql import Row
|
||||
from typing import Any
|
||||
|
||||
|
||||
def check_alphanumeric(text: str) -> str:
|
||||
if not text.isalnum():
|
||||
|
|
|
@ -5,11 +5,12 @@ import json
|
|||
import os
|
||||
import platform
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from aiohttp.web import Response as AiohttpResponse
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
|
||||
from uuid import uuid4
|
||||
|
||||
try:
|
||||
|
@ -18,21 +19,20 @@ try:
|
|||
except ImportError:
|
||||
from importlib_resources import files as pkgfiles # type: ignore
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Any
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .application import Application
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
T = typing.TypeVar('T')
|
||||
ResponseType = typing.TypedDict('ResponseType', {
|
||||
T = TypeVar('T')
|
||||
ResponseType = TypedDict('ResponseType', {
|
||||
'status': int,
|
||||
'headers': dict[str, typing.Any] | None,
|
||||
'headers': dict[str, Any] | None,
|
||||
'content_type': str,
|
||||
'body': bytes | None,
|
||||
'text': str | None
|
||||
|
@ -128,7 +128,7 @@ class JsonEncoder(json.JSONEncoder):
|
|||
if isinstance(o, datetime):
|
||||
return o.isoformat()
|
||||
|
||||
return json.JSONEncoder.default(self, o)
|
||||
return json.JSONEncoder.default(self, o) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class Message(aputils.Message):
|
||||
|
@ -214,7 +214,7 @@ class Response(AiohttpResponse):
|
|||
|
||||
@classmethod
|
||||
def new(cls: type[Self],
|
||||
body: str | bytes | dict | tuple | list | set = '',
|
||||
body: str | bytes | dict[str, Any] | Sequence[Any] = '',
|
||||
status: int = 200,
|
||||
headers: dict[str, str] | None = None,
|
||||
ctype: str = 'text') -> Self:
|
||||
|
@ -227,22 +227,22 @@ class Response(AiohttpResponse):
|
|||
'text': None
|
||||
}
|
||||
|
||||
if isinstance(body, bytes):
|
||||
if isinstance(body, str):
|
||||
kwargs['text'] = body
|
||||
|
||||
elif isinstance(body, bytes):
|
||||
kwargs['body'] = body
|
||||
|
||||
elif isinstance(body, (dict, list, tuple, set)) or ctype in {'json', 'activity'}:
|
||||
elif isinstance(body, (dict, Sequence)):
|
||||
kwargs['text'] = json.dumps(body, cls = JsonEncoder)
|
||||
|
||||
else:
|
||||
kwargs['text'] = body
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def new_error(cls: type[Self],
|
||||
status: int,
|
||||
body: str | bytes | dict,
|
||||
body: str | bytes | dict[str, Any],
|
||||
ctype: str = 'text') -> Self:
|
||||
|
||||
if ctype == 'json':
|
||||
|
|
|
@ -1,25 +1,22 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
import typing
|
||||
|
||||
from collections.abc import Callable
|
||||
from hamlish_jinja import HamlishExtension
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from jinja2.ext import Extension
|
||||
from jinja2.nodes import CallBlock
|
||||
from jinja2.nodes import CallBlock, Node
|
||||
from jinja2.parser import Parser
|
||||
from markdown import Markdown
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from . import __version__
|
||||
from .misc import get_resource
|
||||
from .views.base import View
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from jinja2.nodes import Node
|
||||
from jinja2.parser import Parser
|
||||
from typing import Any
|
||||
if TYPE_CHECKING:
|
||||
from .application import Application
|
||||
from .views.base import View
|
||||
|
||||
|
||||
class Template(Environment):
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web import Request, middleware
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .base import View, register_route
|
||||
|
@ -12,11 +10,6 @@ from .. import __version__
|
|||
from ..database import ConfigData
|
||||
from ..misc import Message, Response, boolean, get_app
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from aiohttp.web import Request
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
|
||||
ALLOWED_HEADERS = {
|
||||
'accept',
|
||||
|
@ -37,8 +30,10 @@ def check_api_path(method: str, path: str) -> bool:
|
|||
return path.startswith('/api')
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def handle_api_path(request: Request, handler: Callable) -> Response:
|
||||
@middleware
|
||||
async def handle_api_path(
|
||||
request: Request,
|
||||
handler: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
try:
|
||||
if (token := request.cookies.get('user-token')):
|
||||
request['token'] = token
|
||||
|
@ -222,7 +217,7 @@ class Inbox(View):
|
|||
if nodeinfo is not None:
|
||||
data['software'] = nodeinfo.sw_name
|
||||
|
||||
row = conn.put_inbox(**data)
|
||||
row = conn.put_inbox(**data) # type: ignore[arg-type]
|
||||
|
||||
return Response.new(row, ctype = 'json')
|
||||
|
||||
|
@ -237,7 +232,7 @@ class Inbox(View):
|
|||
if not (instance := conn.get_inbox(data['domain'])):
|
||||
return Response.new_error(404, 'Instance with domain not found', 'json')
|
||||
|
||||
instance = conn.put_inbox(instance['domain'], **data)
|
||||
instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type]
|
||||
|
||||
return Response.new(instance, ctype = 'json')
|
||||
|
||||
|
|
|
@ -1,33 +1,33 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from Crypto.Random import get_random_bytes
|
||||
from aiohttp.abc import AbstractView
|
||||
from aiohttp.hdrs import METH_ALL as METHODS
|
||||
from aiohttp.web import HTTPMethodNotAllowed
|
||||
from aiohttp.web import HTTPMethodNotAllowed, Request
|
||||
from base64 import b64encode
|
||||
from bsql import Database
|
||||
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
|
||||
from functools import cached_property
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..cache import Cache
|
||||
from ..config import Config
|
||||
from ..database import Connection
|
||||
from ..http_client import HttpClient
|
||||
from ..misc import Response, get_app
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from aiohttp.web import Request
|
||||
from collections.abc import Callable, Generator, Sequence, Mapping
|
||||
from bsql import Database
|
||||
from typing import Any
|
||||
if TYPE_CHECKING:
|
||||
from ..application import Application
|
||||
from ..cache import Cache
|
||||
from ..config import Config
|
||||
from ..http_client import HttpClient
|
||||
from ..template import Template
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
try:
|
||||
from typing import Self
|
||||
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
HandlerCallback = Callable[[Request], Awaitable[Response]]
|
||||
|
||||
|
||||
VIEWS: list[tuple[str, type[View]]] = []
|
||||
|
@ -37,7 +37,7 @@ def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
|
|||
return {key: str(value) for key, value in data.items()}
|
||||
|
||||
|
||||
def register_route(*paths: str) -> Callable:
|
||||
def register_route(*paths: str) -> Callable[[type[View]], type[View]]:
|
||||
def wrapper(view: type[View]) -> type[View]:
|
||||
for path in paths:
|
||||
VIEWS.append((path, view))
|
||||
|
@ -63,7 +63,7 @@ class View(AbstractView):
|
|||
return await view.handlers[method](request, **kwargs)
|
||||
|
||||
|
||||
async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response:
|
||||
async def _run_handler(self, handler: HandlerCallback, **kwargs: Any) -> Response:
|
||||
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
|
||||
return await handler(self.request, **self.request.match_info, **kwargs)
|
||||
|
||||
|
@ -78,7 +78,7 @@ class View(AbstractView):
|
|||
|
||||
|
||||
@cached_property
|
||||
def handlers(self) -> dict[str, Callable[..., Any]]:
|
||||
def handlers(self) -> dict[str, HandlerCallback]:
|
||||
data = {}
|
||||
|
||||
for method in METHODS:
|
||||
|
@ -112,13 +112,13 @@ class View(AbstractView):
|
|||
|
||||
|
||||
@property
|
||||
def database(self) -> Database:
|
||||
def database(self) -> Database[Connection]:
|
||||
return self.app.database
|
||||
|
||||
|
||||
@property
|
||||
def template(self) -> Template:
|
||||
return self.app['template']
|
||||
return self.app['template'] # type: ignore[no-any-return]
|
||||
|
||||
|
||||
async def get_api_data(self,
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from aiohttp import web
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from .base import View, register_route
|
||||
|
||||
|
@ -10,11 +8,6 @@ from ..database import THEMES
|
|||
from ..logger import LogLevel
|
||||
from ..misc import Response, get_app
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from aiohttp.web import Request
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
UNAUTH_ROUTES = {
|
||||
'/',
|
||||
|
@ -23,7 +16,10 @@ UNAUTH_ROUTES = {
|
|||
|
||||
|
||||
@web.middleware
|
||||
async def handle_frontend_path(request: web.Request, handler: Callable) -> Response:
|
||||
async def handle_frontend_path(
|
||||
request: web.Request,
|
||||
handler: Callable[[web.Request], Awaitable[Response]]) -> Response:
|
||||
|
||||
app = get_app()
|
||||
|
||||
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
|
||||
|
@ -52,7 +48,7 @@ async def handle_frontend_path(request: web.Request, handler: Callable) -> Respo
|
|||
|
||||
@register_route('/')
|
||||
class HomeView(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
async def get(self, request: web.Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
context: dict[str, Any] = {
|
||||
'instances': tuple(conn.get_inboxes())
|
||||
|
@ -64,14 +60,14 @@ class HomeView(View):
|
|||
|
||||
@register_route('/login')
|
||||
class Login(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
async def get(self, request: web.Request) -> Response:
|
||||
data = self.template.render('page/login.haml', self)
|
||||
return Response.new(data, ctype = 'html')
|
||||
|
||||
|
||||
@register_route('/logout')
|
||||
class Logout(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
async def get(self, request: web.Request) -> Response:
|
||||
with self.database.session(True) as conn:
|
||||
conn.del_token(request['token'])
|
||||
|
||||
|
@ -82,14 +78,14 @@ class Logout(View):
|
|||
|
||||
@register_route('/admin')
|
||||
class Admin(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
async def get(self, request: web.Request) -> Response:
|
||||
return Response.new('', 302, {'Location': '/admin/instances'})
|
||||
|
||||
|
||||
@register_route('/admin/instances')
|
||||
class AdminInstances(View):
|
||||
async def get(self,
|
||||
request: Request,
|
||||
request: web.Request,
|
||||
error: str | None = None,
|
||||
message: str | None = None) -> Response:
|
||||
|
||||
|
@ -112,7 +108,7 @@ class AdminInstances(View):
|
|||
@register_route('/admin/whitelist')
|
||||
class AdminWhitelist(View):
|
||||
async def get(self,
|
||||
request: Request,
|
||||
request: web.Request,
|
||||
error: str | None = None,
|
||||
message: str | None = None) -> Response:
|
||||
|
||||
|
@ -134,7 +130,7 @@ class AdminWhitelist(View):
|
|||
@register_route('/admin/domain_bans')
|
||||
class AdminDomainBans(View):
|
||||
async def get(self,
|
||||
request: Request,
|
||||
request: web.Request,
|
||||
error: str | None = None,
|
||||
message: str | None = None) -> Response:
|
||||
|
||||
|
@ -156,7 +152,7 @@ class AdminDomainBans(View):
|
|||
@register_route('/admin/software_bans')
|
||||
class AdminSoftwareBans(View):
|
||||
async def get(self,
|
||||
request: Request,
|
||||
request: web.Request,
|
||||
error: str | None = None,
|
||||
message: str | None = None) -> Response:
|
||||
|
||||
|
@ -178,7 +174,7 @@ class AdminSoftwareBans(View):
|
|||
@register_route('/admin/users')
|
||||
class AdminUsers(View):
|
||||
async def get(self,
|
||||
request: Request,
|
||||
request: web.Request,
|
||||
error: str | None = None,
|
||||
message: str | None = None) -> Response:
|
||||
|
||||
|
@ -199,7 +195,7 @@ class AdminUsers(View):
|
|||
|
||||
@register_route('/admin/config')
|
||||
class AdminConfig(View):
|
||||
async def get(self, request: Request, message: str | None = None) -> Response:
|
||||
async def get(self, request: web.Request, message: str | None = None) -> Response:
|
||||
context: dict[str, Any] = {
|
||||
'themes': tuple(THEMES.keys()),
|
||||
'levels': tuple(level.name for level in LogLevel),
|
||||
|
@ -212,7 +208,7 @@ class AdminConfig(View):
|
|||
|
||||
@register_route('/manifest.json')
|
||||
class ManifestJson(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
async def get(self, request: web.Request) -> Response:
|
||||
with self.database.session(False) as conn:
|
||||
config = conn.get_config_all()
|
||||
theme = THEMES[config.theme]
|
||||
|
@ -235,7 +231,7 @@ class ManifestJson(View):
|
|||
|
||||
@register_route('/theme/{theme}.css')
|
||||
class ThemeCss(View):
|
||||
async def get(self, request: Request, theme: str) -> Response:
|
||||
async def get(self, request: web.Request, theme: str) -> Response:
|
||||
try:
|
||||
context: dict[str, Any] = {
|
||||
'theme': THEMES[theme]
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import aputils
|
||||
import subprocess
|
||||
import typing
|
||||
|
||||
from aiohttp.web import Request
|
||||
from pathlib import Path
|
||||
|
||||
from .base import View, register_route
|
||||
|
@ -11,9 +9,6 @@ from .base import View, register_route
|
|||
from .. import __version__
|
||||
from ..misc import Response
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from aiohttp.web import Request
|
||||
|
||||
|
||||
VERSION = __version__
|
||||
|
||||
|
|
Loading…
Reference in a new issue