ensure the relay can run on python >= 3.10

This commit is contained in:
Izalia Mae 2024-07-03 00:59:59 -04:00
parent e8b3a210a9
commit b22b5bbefa
10 changed files with 119 additions and 99 deletions

47
dev.py
View file

@ -1,25 +1,38 @@
#!/usr/bin/env python3
import click
import platform
import shutil
import subprocess
import sys
import time
import tomllib
from datetime import datetime, timedelta
from importlib.util import find_spec
from pathlib import Path
from relay import __version__, logger as logging
from tempfile import TemporaryDirectory
from typing import Any, Sequence
try:
from watchdog.observers import Observer
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
import tomllib
except ImportError:
class PatternMatchingEventHandler: # type: ignore
pass
if find_spec("toml") is None:
subprocess.run([sys.executable, "-m", "pip", "install", "toml"])
import toml as tomllib # type: ignore[no-redef]
if None in [find_spec("click"), find_spec("watchdog")]:
CMD = [sys.executable, "-m", "pip", "install", "click >= 8.1.0", "watchdog >= 4.0.0"]
PROC = subprocess.run(CMD, check = False)
if PROC.returncode != 0:
sys.exit()
print("Successfully installed dependencies")
import click
from watchdog.observers import Observer
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
REPO = Path(__file__).parent
@ -37,12 +50,10 @@ def cli() -> None:
@cli.command('install')
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
def cli_install(no_dev: bool) -> None:
with open('pyproject.toml', 'rb') as fd:
data = tomllib.load(fd)
with open('pyproject.toml', 'r', encoding = 'utf-8') as fd:
data = tomllib.loads(fd.read())
deps = data['project']['dependencies']
if not no_dev:
deps.extend(data['project']['optional-dependencies']['dev'])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
@ -60,7 +71,7 @@ def cli_lint(path: Path, watch: bool) -> None:
return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)]
click.echo('----- flake8 -----')
subprocess.run(flake8)
@ -89,6 +100,8 @@ def cli_clean() -> None:
@cli.command('build')
def cli_build() -> None:
from relay import __version__
with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [
@ -171,7 +184,7 @@ class WatchHandler(PatternMatchingEventHandler):
if proc.poll() is not None:
continue
logging.info(f'Terminating process {proc.pid}')
print(f'Terminating process {proc.pid}')
proc.terminate()
sec = 0.0
@ -180,11 +193,11 @@ class WatchHandler(PatternMatchingEventHandler):
sec += 0.1
if sec >= 5:
logging.error('Failed to terminate. Killing process...')
print('Failed to terminate. Killing process...')
proc.kill()
break
logging.info('Process terminated')
print('Process terminated')
def run_procs(self, restart: bool = False) -> None:
@ -200,13 +213,13 @@ class WatchHandler(PatternMatchingEventHandler):
self.procs = []
for cmd in self.commands:
logging.info('Running command: %s', ' '.join(cmd))
print('Running command:', ' '.join(cmd))
subprocess.run(cmd)
else:
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
pids = (str(proc.pid) for proc in self.procs)
logging.info('Started processes with PIDs: %s', ', '.join(pids))
print('Started processes with PIDs:', ', '.join(pids))
def on_any_event(self, event: FileSystemEvent) -> None:

View file

@ -9,30 +9,27 @@ license = {text = "AGPLv3"}
classifiers = [
"Environment :: Console",
"License :: OSI Approved :: GNU Affero General Public License v3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.12"
]
dependencies = [
"activitypub-utils >= 0.3.1, < 0.4.0",
"activitypub-utils >= 0.3.1.post1, < 0.4.0",
"aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0",
"barkshark-lib >= 0.1.4, < 0.2.0",
"barkshark-sql >= 0.2.0-rc1, < 0.3.0",
"barkshark-lib >= 0.1.5rc1, < 0.2.0",
"barkshark-sql >= 0.2.0rc2, < 0.3.0",
"click == 8.1.2",
"hiredis == 2.3.2",
"idna == 3.4",
"jinja2-haml == 0.3.5",
"markdown == 3.6",
"platformdirs == 4.2.2",
"pyyaml == 6.0",
"redis == 5.0.5",
"importlib-resources == 6.4.0; python_version < '3.9'"
"pyyaml == 6.0.1",
"redis == 5.0.7"
]
requires-python = ">=3.8"
requires-python = ">=3.10"
dynamic = ["version"]
[project.readme]
@ -49,11 +46,10 @@ activityrelay = "relay.manage:main"
[project.optional-dependencies]
dev = [
"flake8 == 7.0.0",
"mypy == 1.10.0",
"flake8 == 7.1.0",
"mypy == 1.10.1",
"pyinstaller == 6.8.0",
"watchdog == 4.0.1",
"typing-extensions == 4.12.2; python_version < '3.11.0'"
"watchdog == 4.0.1"
]
[tool.setuptools]

View file

@ -4,12 +4,13 @@ import json
import os
from abc import ABC, abstractmethod
from blib import Date
from bsql import Database, Row
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone
from datetime import timedelta
from redis import Redis
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypedDict
from .database import Connection, get_database
from .misc import Message, boolean
@ -31,6 +32,14 @@ CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
}
class RedisConnectType(TypedDict):
client_name: str
decode_responses: bool
username: str | None
password: str | None
db: int
def get_cache(app: Application) -> Cache:
return BACKENDS[app.config.ca_type](app)
@ -57,12 +66,11 @@ class Item:
key: str
value: Any
value_type: str
updated: datetime
updated: Date
def __post_init__(self) -> None:
if isinstance(self.updated, str): # type: ignore[unreachable]
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
self.updated = Date.parse(self.updated)
@classmethod
@ -70,14 +78,11 @@ class Item:
data = cls(*args)
data.value = deserialize_value(data.value, data.value_type)
if not isinstance(data.updated, datetime):
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore
return data
def older_than(self, hours: int) -> bool:
delta = datetime.now(tz = timezone.utc) - self.updated
delta = Date.new_utc() - self.updated
return (delta.total_seconds()) > hours * 3600
@ -206,7 +211,7 @@ class SqlCache(Cache):
'key': key,
'value': serialize_value(value, value_type),
'type': value_type,
'date': datetime.now(tz = timezone.utc)
'date': Date.new_utc()
}
with self._db.session(True) as conn:
@ -236,7 +241,7 @@ class SqlCache(Cache):
if self._db is None:
raise RuntimeError("Database has not been setup")
limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
limit = Date.new_utc() - timedelta(days = days)
params = {"limit": limit.timestamp()}
with self._db.session(True) as conn:
@ -280,7 +285,7 @@ class RedisCache(Cache):
def __init__(self, app: Application):
Cache.__init__(self, app)
self._rd: Redis = None # type: ignore
self._rd: Redis | None = None
@property
@ -293,28 +298,38 @@ class RedisCache(Cache):
def get(self, namespace: str, key: str) -> Item:
if self._rd is None:
raise ConnectionError("Not connected")
key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}')
value_type, updated, value = raw_value.split(':', 2) # type: ignore
value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr]
return Item.from_data(
namespace,
key,
value,
value_type,
datetime.fromtimestamp(float(updated), tz = timezone.utc)
Date.parse(float(updated))
)
def get_keys(self, namespace: str) -> Iterator[str]:
if self._rd is None:
raise ConnectionError("Not connected")
for key in self._rd.scan_iter(self.get_key_name(namespace, '*')):
*_, key_name = key.split(':', 2)
yield key_name
def get_namespaces(self) -> Iterator[str]:
if self._rd is None:
raise ConnectionError("Not connected")
namespaces = []
for key in self._rd.scan_iter(f'{self.prefix}:*'):
@ -326,7 +341,10 @@ class RedisCache(Cache):
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
date = datetime.now(tz = timezone.utc).timestamp()
if self._rd is None:
raise ConnectionError("Not connected")
date = Date.new_utc().timestamp()
value = serialize_value(value, value_type)
self._rd.set(
@ -338,11 +356,17 @@ class RedisCache(Cache):
def delete(self, namespace: str, key: str) -> None:
if self._rd is None:
raise ConnectionError("Not connected")
self._rd.delete(self.get_key_name(namespace, key))
def delete_old(self, days: int = 14) -> None:
limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
if self._rd is None:
raise ConnectionError("Not connected")
limit = Date.new_utc() - timedelta(days = days)
for full_key in self._rd.scan_iter(f'{self.prefix}:*'):
_, namespace, key = full_key.split(':', 2)
@ -353,14 +377,17 @@ class RedisCache(Cache):
def clear(self) -> None:
if self._rd is None:
raise ConnectionError("Not connected")
self._rd.delete(f"{self.prefix}:*")
def setup(self) -> None:
if self._rd:
if self._rd is not None:
return
options = {
options: RedisConnectType = {
'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True,
'username': self.app.config.rd_user,
@ -369,18 +396,22 @@ class RedisCache(Cache):
}
if os.path.exists(self.app.config.rd_host):
options['unix_socket_path'] = self.app.config.rd_host
self._rd = Redis(
unix_socket_path = self.app.config.rd_host,
**options
)
return
else:
options['host'] = self.app.config.rd_host
options['port'] = self.app.config.rd_port
self._rd = Redis(**options) # type: ignore
self._rd = Redis(
host = self.app.config.rd_host,
port = self.app.config.rd_port,
**options
)
def close(self) -> None:
if not self._rd:
return
self._rd.close() # type: ignore
self._rd = None # type: ignore
self._rd.close() # type: ignore[no-untyped-call]
self._rd = None

View file

@ -13,12 +13,8 @@ from typing import TYPE_CHECKING, Any
from .misc import IS_DOCKER
if TYPE_CHECKING:
try:
from typing import Self
except ImportError:
from typing_extensions import Self
if platform.system() == 'Windows':
import multiprocessing
@ -84,7 +80,7 @@ class Config:
def DEFAULT(cls: type[Self], key: str) -> str | int | None:
for field in fields(cls):
if field.name == key:
return field.default # type: ignore
return field.default # type: ignore[return-value]
raise KeyError(key)

View file

@ -8,12 +8,8 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
try:
from typing import Self
except ImportError:
from typing_extensions import Self
class LoggingMethod(Protocol):
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...

View file

@ -9,24 +9,14 @@ import socket
from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence
from datetime import datetime
from importlib.resources import files as pkgfiles
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles # type: ignore
if TYPE_CHECKING:
from .application import Application
try:
from typing import Self
except ImportError:
from typing_extensions import Self
from .application import Application
T = TypeVar('T')

View file

@ -20,6 +20,9 @@ if TYPE_CHECKING:
class Template(Environment):
_render_markdown: Callable[[str], str]
def __init__(self, app: Application):
Environment.__init__(self,
autoescape = True,
@ -56,7 +59,7 @@ class Template(Environment):
def render_markdown(self, text: str) -> str:
return self._render_markdown(text) # type: ignore
return self._render_markdown(text)
class MarkdownExtension(Extension):

View file

@ -13,7 +13,8 @@ from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app
ALLOWED_HEADERS = {
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
ALLOWED_HEADERS: set[str] = {
'accept',
'authorization',
'content-type'

View file

@ -18,18 +18,12 @@ from ..http_client import HttpClient
from ..misc import Response, get_app
if TYPE_CHECKING:
from typing import Self
from ..application import Application
from ..template import Template
try:
from typing import Self
except ImportError:
from typing_extensions import Self
HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = []

View file

@ -1,14 +1,17 @@
from __future__ import annotations
import asyncio
import traceback
import typing
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value
from multiprocessing.queues import Queue as QueueType
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.synchronize import Event as EventType
from pathlib import Path
from queue import Empty, Queue as QueueType
from queue import Empty
from urllib.parse import urlparse
from . import application, logger as logging
@ -16,9 +19,6 @@ from .database.schema import Instance
from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, get_app
if typing.TYPE_CHECKING:
from .multiprocessing.synchronize import Syncronized
@dataclass
class QueueItem:
@ -40,13 +40,13 @@ class PushWorker(Process):
client: HttpClient
def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None:
def __init__(self, queue: QueueType[QueueItem], log_level: Synchronized[int]) -> None:
Process.__init__(self)
self.queue: QueueType[QueueItem] = queue
self.shutdown: EventType = Event()
self.path: Path = get_app().config.path
self.log_level: "Syncronized[str]" = log_level
self.log_level: Synchronized[int] = log_level
self._log_level_changed: EventType = Event()
@ -113,8 +113,8 @@ class PushWorker(Process):
class PushWorkers(list[PushWorker]):
def __init__(self, count: int) -> None:
self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment]
self._log_level: "Syncronized[str]" = Value("i", logging.get_level())
self.queue: QueueType[QueueItem] = Queue()
self._log_level: Synchronized[int] = Value("i", logging.get_level())
self._count: int = count