ensure the relay can run on python >= 3.10

This commit is contained in:
Izalia Mae 2024-07-03 00:59:59 -04:00
parent e8b3a210a9
commit b22b5bbefa
10 changed files with 119 additions and 99 deletions

47
dev.py
View file

@ -1,25 +1,38 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import click
import platform import platform
import shutil import shutil
import subprocess import subprocess
import sys import sys
import time import time
import tomllib
from datetime import datetime, timedelta from datetime import datetime, timedelta
from importlib.util import find_spec
from pathlib import Path from pathlib import Path
from relay import __version__, logger as logging
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Sequence from typing import Any, Sequence
try: try:
from watchdog.observers import Observer import tomllib
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
except ImportError: except ImportError:
class PatternMatchingEventHandler: # type: ignore if find_spec("toml") is None:
pass subprocess.run([sys.executable, "-m", "pip", "install", "toml"])
import toml as tomllib # type: ignore[no-redef]
if None in [find_spec("click"), find_spec("watchdog")]:
CMD = [sys.executable, "-m", "pip", "install", "click >= 8.1.0", "watchdog >= 4.0.0"]
PROC = subprocess.run(CMD, check = False)
if PROC.returncode != 0:
sys.exit()
print("Successfully installed dependencies")
import click
from watchdog.observers import Observer
from watchdog.events import FileSystemEvent, PatternMatchingEventHandler
REPO = Path(__file__).parent REPO = Path(__file__).parent
@ -37,12 +50,10 @@ def cli() -> None:
@cli.command('install') @cli.command('install')
@click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies') @click.option('--no-dev', '-d', is_flag = True, help = 'Do not install development dependencies')
def cli_install(no_dev: bool) -> None: def cli_install(no_dev: bool) -> None:
with open('pyproject.toml', 'rb') as fd: with open('pyproject.toml', 'r', encoding = 'utf-8') as fd:
data = tomllib.load(fd) data = tomllib.loads(fd.read())
deps = data['project']['dependencies'] deps = data['project']['dependencies']
if not no_dev:
deps.extend(data['project']['optional-dependencies']['dev']) deps.extend(data['project']['optional-dependencies']['dev'])
subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False) subprocess.run([sys.executable, '-m', 'pip', 'install', '-U', *deps], check = False)
@ -60,7 +71,7 @@ def cli_lint(path: Path, watch: bool) -> None:
return return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)] flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)] mypy = [sys.executable, '-m', 'mypy', '--python-version', '3.12', 'dev.py', str(path)]
click.echo('----- flake8 -----') click.echo('----- flake8 -----')
subprocess.run(flake8) subprocess.run(flake8)
@ -89,6 +100,8 @@ def cli_clean() -> None:
@cli.command('build') @cli.command('build')
def cli_build() -> None: def cli_build() -> None:
from relay import __version__
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386' arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [ cmd = [
@ -171,7 +184,7 @@ class WatchHandler(PatternMatchingEventHandler):
if proc.poll() is not None: if proc.poll() is not None:
continue continue
logging.info(f'Terminating process {proc.pid}') print(f'Terminating process {proc.pid}')
proc.terminate() proc.terminate()
sec = 0.0 sec = 0.0
@ -180,11 +193,11 @@ class WatchHandler(PatternMatchingEventHandler):
sec += 0.1 sec += 0.1
if sec >= 5: if sec >= 5:
logging.error('Failed to terminate. Killing process...') print('Failed to terminate. Killing process...')
proc.kill() proc.kill()
break break
logging.info('Process terminated') print('Process terminated')
def run_procs(self, restart: bool = False) -> None: def run_procs(self, restart: bool = False) -> None:
@ -200,13 +213,13 @@ class WatchHandler(PatternMatchingEventHandler):
self.procs = [] self.procs = []
for cmd in self.commands: for cmd in self.commands:
logging.info('Running command: %s', ' '.join(cmd)) print('Running command:', ' '.join(cmd))
subprocess.run(cmd) subprocess.run(cmd)
else: else:
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands) self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
pids = (str(proc.pid) for proc in self.procs) pids = (str(proc.pid) for proc in self.procs)
logging.info('Started processes with PIDs: %s', ', '.join(pids)) print('Started processes with PIDs:', ', '.join(pids))
def on_any_event(self, event: FileSystemEvent) -> None: def on_any_event(self, event: FileSystemEvent) -> None:

View file

@ -9,30 +9,27 @@ license = {text = "AGPLv3"}
classifiers = [ classifiers = [
"Environment :: Console", "Environment :: Console",
"License :: OSI Approved :: GNU Affero General Public License v3", "License :: OSI Approved :: GNU Affero General Public License v3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12"
] ]
dependencies = [ dependencies = [
"activitypub-utils >= 0.3.1, < 0.4.0", "activitypub-utils >= 0.3.1.post1, < 0.4.0",
"aiohttp >= 3.9.5", "aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.1.4, < 0.2.0", "barkshark-lib >= 0.1.5rc1, < 0.2.0",
"barkshark-sql >= 0.2.0-rc1, < 0.3.0", "barkshark-sql >= 0.2.0rc2, < 0.3.0",
"click == 8.1.2", "click == 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",
"idna == 3.4", "idna == 3.4",
"jinja2-haml == 0.3.5", "jinja2-haml == 0.3.5",
"markdown == 3.6", "markdown == 3.6",
"platformdirs == 4.2.2", "platformdirs == 4.2.2",
"pyyaml == 6.0", "pyyaml == 6.0.1",
"redis == 5.0.5", "redis == 5.0.7"
"importlib-resources == 6.4.0; python_version < '3.9'"
] ]
requires-python = ">=3.8" requires-python = ">=3.10"
dynamic = ["version"] dynamic = ["version"]
[project.readme] [project.readme]
@ -49,11 +46,10 @@ activityrelay = "relay.manage:main"
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"flake8 == 7.0.0", "flake8 == 7.1.0",
"mypy == 1.10.0", "mypy == 1.10.1",
"pyinstaller == 6.8.0", "pyinstaller == 6.8.0",
"watchdog == 4.0.1", "watchdog == 4.0.1"
"typing-extensions == 4.12.2; python_version < '3.11.0'"
] ]
[tool.setuptools] [tool.setuptools]

View file

@ -4,12 +4,13 @@ import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from blib import Date
from bsql import Database, Row from bsql import Database, Row
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone from datetime import timedelta
from redis import Redis from redis import Redis
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, TypedDict
from .database import Connection, get_database from .database import Connection, get_database
from .misc import Message, boolean from .misc import Message, boolean
@ -31,6 +32,14 @@ CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
} }
class RedisConnectType(TypedDict):
client_name: str
decode_responses: bool
username: str | None
password: str | None
db: int
def get_cache(app: Application) -> Cache: def get_cache(app: Application) -> Cache:
return BACKENDS[app.config.ca_type](app) return BACKENDS[app.config.ca_type](app)
@ -57,12 +66,11 @@ class Item:
key: str key: str
value: Any value: Any
value_type: str value_type: str
updated: datetime updated: Date
def __post_init__(self) -> None: def __post_init__(self) -> None:
if isinstance(self.updated, str): # type: ignore[unreachable] self.updated = Date.parse(self.updated)
self.updated = datetime.fromisoformat(self.updated) # type: ignore[unreachable]
@classmethod @classmethod
@ -70,14 +78,11 @@ class Item:
data = cls(*args) data = cls(*args)
data.value = deserialize_value(data.value, data.value_type) data.value = deserialize_value(data.value, data.value_type)
if not isinstance(data.updated, datetime):
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore
return data return data
def older_than(self, hours: int) -> bool: def older_than(self, hours: int) -> bool:
delta = datetime.now(tz = timezone.utc) - self.updated delta = Date.new_utc() - self.updated
return (delta.total_seconds()) > hours * 3600 return (delta.total_seconds()) > hours * 3600
@ -206,7 +211,7 @@ class SqlCache(Cache):
'key': key, 'key': key,
'value': serialize_value(value, value_type), 'value': serialize_value(value, value_type),
'type': value_type, 'type': value_type,
'date': datetime.now(tz = timezone.utc) 'date': Date.new_utc()
} }
with self._db.session(True) as conn: with self._db.session(True) as conn:
@ -236,7 +241,7 @@ class SqlCache(Cache):
if self._db is None: if self._db is None:
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
limit = datetime.now(tz = timezone.utc) - timedelta(days = days) limit = Date.new_utc() - timedelta(days = days)
params = {"limit": limit.timestamp()} params = {"limit": limit.timestamp()}
with self._db.session(True) as conn: with self._db.session(True) as conn:
@ -280,7 +285,7 @@ class RedisCache(Cache):
def __init__(self, app: Application): def __init__(self, app: Application):
Cache.__init__(self, app) Cache.__init__(self, app)
self._rd: Redis = None # type: ignore self._rd: Redis | None = None
@property @property
@ -293,28 +298,38 @@ class RedisCache(Cache):
def get(self, namespace: str, key: str) -> Item: def get(self, namespace: str, key: str) -> Item:
if self._rd is None:
raise ConnectionError("Not connected")
key_name = self.get_key_name(namespace, key) key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)): if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}') raise KeyError(f'{namespace}:{key}')
value_type, updated, value = raw_value.split(':', 2) # type: ignore value_type, updated, value = raw_value.split(':', 2) # type: ignore[union-attr]
return Item.from_data( return Item.from_data(
namespace, namespace,
key, key,
value, value,
value_type, value_type,
datetime.fromtimestamp(float(updated), tz = timezone.utc) Date.parse(float(updated))
) )
def get_keys(self, namespace: str) -> Iterator[str]: def get_keys(self, namespace: str) -> Iterator[str]:
if self._rd is None:
raise ConnectionError("Not connected")
for key in self._rd.scan_iter(self.get_key_name(namespace, '*')): for key in self._rd.scan_iter(self.get_key_name(namespace, '*')):
*_, key_name = key.split(':', 2) *_, key_name = key.split(':', 2)
yield key_name yield key_name
def get_namespaces(self) -> Iterator[str]: def get_namespaces(self) -> Iterator[str]:
if self._rd is None:
raise ConnectionError("Not connected")
namespaces = [] namespaces = []
for key in self._rd.scan_iter(f'{self.prefix}:*'): for key in self._rd.scan_iter(f'{self.prefix}:*'):
@ -326,7 +341,10 @@ class RedisCache(Cache):
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
date = datetime.now(tz = timezone.utc).timestamp() if self._rd is None:
raise ConnectionError("Not connected")
date = Date.new_utc().timestamp()
value = serialize_value(value, value_type) value = serialize_value(value, value_type)
self._rd.set( self._rd.set(
@ -338,11 +356,17 @@ class RedisCache(Cache):
def delete(self, namespace: str, key: str) -> None: def delete(self, namespace: str, key: str) -> None:
if self._rd is None:
raise ConnectionError("Not connected")
self._rd.delete(self.get_key_name(namespace, key)) self._rd.delete(self.get_key_name(namespace, key))
def delete_old(self, days: int = 14) -> None: def delete_old(self, days: int = 14) -> None:
limit = datetime.now(tz = timezone.utc) - timedelta(days = days) if self._rd is None:
raise ConnectionError("Not connected")
limit = Date.new_utc() - timedelta(days = days)
for full_key in self._rd.scan_iter(f'{self.prefix}:*'): for full_key in self._rd.scan_iter(f'{self.prefix}:*'):
_, namespace, key = full_key.split(':', 2) _, namespace, key = full_key.split(':', 2)
@ -353,14 +377,17 @@ class RedisCache(Cache):
def clear(self) -> None: def clear(self) -> None:
if self._rd is None:
raise ConnectionError("Not connected")
self._rd.delete(f"{self.prefix}:*") self._rd.delete(f"{self.prefix}:*")
def setup(self) -> None: def setup(self) -> None:
if self._rd: if self._rd is not None:
return return
options = { options: RedisConnectType = {
'client_name': f'ActivityRelay_{self.app.config.domain}', 'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True, 'decode_responses': True,
'username': self.app.config.rd_user, 'username': self.app.config.rd_user,
@ -369,18 +396,22 @@ class RedisCache(Cache):
} }
if os.path.exists(self.app.config.rd_host): if os.path.exists(self.app.config.rd_host):
options['unix_socket_path'] = self.app.config.rd_host self._rd = Redis(
unix_socket_path = self.app.config.rd_host,
**options
)
return
else: self._rd = Redis(
options['host'] = self.app.config.rd_host host = self.app.config.rd_host,
options['port'] = self.app.config.rd_port port = self.app.config.rd_port,
**options
self._rd = Redis(**options) # type: ignore )
def close(self) -> None: def close(self) -> None:
if not self._rd: if not self._rd:
return return
self._rd.close() # type: ignore self._rd.close() # type: ignore[no-untyped-call]
self._rd = None # type: ignore self._rd = None

View file

@ -13,12 +13,8 @@ from typing import TYPE_CHECKING, Any
from .misc import IS_DOCKER from .misc import IS_DOCKER
if TYPE_CHECKING: if TYPE_CHECKING:
try:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
if platform.system() == 'Windows': if platform.system() == 'Windows':
import multiprocessing import multiprocessing
@ -84,7 +80,7 @@ class Config:
def DEFAULT(cls: type[Self], key: str) -> str | int | None: def DEFAULT(cls: type[Self], key: str) -> str | int | None:
for field in fields(cls): for field in fields(cls):
if field.name == key: if field.name == key:
return field.default # type: ignore return field.default # type: ignore[return-value]
raise KeyError(key) raise KeyError(key)

View file

@ -8,12 +8,8 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING: if TYPE_CHECKING:
try:
from typing import Self from typing import Self
except ImportError:
from typing_extensions import Self
class LoggingMethod(Protocol): class LoggingMethod(Protocol):
def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ... def __call__(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...

View file

@ -9,24 +9,14 @@ import socket
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from importlib.resources import files as pkgfiles
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4 from uuid import uuid4
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles # type: ignore
if TYPE_CHECKING: if TYPE_CHECKING:
from .application import Application
try:
from typing import Self from typing import Self
from .application import Application
except ImportError:
from typing_extensions import Self
T = TypeVar('T') T = TypeVar('T')

View file

@ -20,6 +20,9 @@ if TYPE_CHECKING:
class Template(Environment): class Template(Environment):
_render_markdown: Callable[[str], str]
def __init__(self, app: Application): def __init__(self, app: Application):
Environment.__init__(self, Environment.__init__(self,
autoescape = True, autoescape = True,
@ -56,7 +59,7 @@ class Template(Environment):
def render_markdown(self, text: str) -> str: def render_markdown(self, text: str) -> str:
return self._render_markdown(text) # type: ignore return self._render_markdown(text)
class MarkdownExtension(Extension): class MarkdownExtension(Extension):

View file

@ -13,7 +13,8 @@ from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app from ..misc import Message, Response, boolean, get_app
ALLOWED_HEADERS = { DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
ALLOWED_HEADERS: set[str] = {
'accept', 'accept',
'authorization', 'authorization',
'content-type' 'content-type'

View file

@ -18,18 +18,12 @@ from ..http_client import HttpClient
from ..misc import Response, get_app from ..misc import Response, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self
from ..application import Application from ..application import Application
from ..template import Template from ..template import Template
try:
from typing import Self
except ImportError:
from typing_extensions import Self
HandlerCallback = Callable[[Request], Awaitable[Response]] HandlerCallback = Callable[[Request], Awaitable[Response]]
VIEWS: list[tuple[str, type[View]]] = [] VIEWS: list[tuple[str, type[View]]] = []

View file

@ -1,14 +1,17 @@
from __future__ import annotations
import asyncio import asyncio
import traceback import traceback
import typing
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value from multiprocessing import Event, Process, Queue, Value
from multiprocessing.queues import Queue as QueueType
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.synchronize import Event as EventType from multiprocessing.synchronize import Event as EventType
from pathlib import Path from pathlib import Path
from queue import Empty, Queue as QueueType from queue import Empty
from urllib.parse import urlparse from urllib.parse import urlparse
from . import application, logger as logging from . import application, logger as logging
@ -16,9 +19,6 @@ from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, get_app from .misc import IS_WINDOWS, Message, get_app
if typing.TYPE_CHECKING:
from .multiprocessing.synchronize import Syncronized
@dataclass @dataclass
class QueueItem: class QueueItem:
@ -40,13 +40,13 @@ class PushWorker(Process):
client: HttpClient client: HttpClient
def __init__(self, queue: QueueType[QueueItem], log_level: "Syncronized[str]") -> None: def __init__(self, queue: QueueType[QueueItem], log_level: Synchronized[int]) -> None:
Process.__init__(self) Process.__init__(self)
self.queue: QueueType[QueueItem] = queue self.queue: QueueType[QueueItem] = queue
self.shutdown: EventType = Event() self.shutdown: EventType = Event()
self.path: Path = get_app().config.path self.path: Path = get_app().config.path
self.log_level: "Syncronized[str]" = log_level self.log_level: Synchronized[int] = log_level
self._log_level_changed: EventType = Event() self._log_level_changed: EventType = Event()
@ -113,8 +113,8 @@ class PushWorker(Process):
class PushWorkers(list[PushWorker]): class PushWorkers(list[PushWorker]):
def __init__(self, count: int) -> None: def __init__(self, count: int) -> None:
self.queue: QueueType[QueueItem] = Queue() # type: ignore[assignment] self.queue: QueueType[QueueItem] = Queue()
self._log_level: "Syncronized[str]" = Value("i", logging.get_level()) self._log_level: Synchronized[int] = Value("i", logging.get_level())
self._count: int = count self._count: int = count