from __future__ import annotations

import json
import os
import typing

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone
from redis import Redis

from .database import get_database
from .misc import Message, boolean

if typing.TYPE_CHECKING:
    from typing import Any
    from collections.abc import Callable, Iterator
    from .application import Application


# todo: implement more caching backends


BACKENDS: dict[str, Cache] = {}
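# Maps a value type name to a (serialize, deserialize) pair used by
# serialize_value() and deserialize_value() below.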
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
    'str': (str, str),
    'int': (str, int),
    'bool': (str, boolean),
    'json': (json.dumps, json.loads),
    'message': (lambda x: x.to_json(), Message.parse)
}


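# Cache backends register themselves in BACKENDS via @register_cache and are
# selected by name through the application config. A rough usage sketch
# (assuming `app` is an initialized Application whose `ca_type` names one of
# the registered backends; namespace and key names are only illustrative):
#
#     cache = get_cache(app)
#     cache.setup()
#     cache.set('example', 'count', 5, 'int')
#     item = cache.get('example', 'count')
#     cache.close()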
def get_cache(app: Application) -> Cache:
    return BACKENDS[app.config.ca_type](app)


def register_cache(backend: type[Cache]) -> type[Cache]:
    BACKENDS[backend.name] = backend
    return backend


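# Values are converted to and from their string representation through the
# CONVERTERS table; plain strings are passed through unchanged.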
def serialize_value(value: Any, value_type: str = 'str') -> str:
    if isinstance(value, str):
        return value

    return CONVERTERS[value_type][0](value)


def deserialize_value(value: str, value_type: str = 'str') -> Any:
    return CONVERTERS[value_type][1](value)


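# A single cache entry. `updated` may arrive as an ISO string or a unix
# timestamp and is normalized to an aware datetime.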
@dataclass
class Item:
    namespace: str
    key: str
    value: Any
    value_type: str
    updated: datetime


    def __post_init__(self):
        if isinstance(self.updated, str):
            self.updated = datetime.fromisoformat(self.updated)


    @classmethod
    def from_data(cls: type[Item], *args) -> Item:
        data = cls(*args)
        data.value = deserialize_value(data.value, data.value_type)

        if not isinstance(data.updated, datetime):
            data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc)

        return data


    def older_than(self, hours: int) -> bool:
        delta = datetime.now(tz = timezone.utc) - self.updated
        return delta.total_seconds() > hours * 3600


    def to_dict(self) -> dict[str, Any]:
        return asdict(self)


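# Base class for cache backends. Implementations are added to BACKENDS with
# the @register_cache decorator and looked up by their `name` attribute.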
class Cache(ABC):
    name: str = 'null'


    def __init__(self, app: Application):
        self.app = app


    @abstractmethod
    def get(self, namespace: str, key: str) -> Item:
        ...


    @abstractmethod
    def get_keys(self, namespace: str) -> Iterator[str]:
        ...


    @abstractmethod
    def get_namespaces(self) -> Iterator[str]:
        ...


    @abstractmethod
    def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
        ...


    @abstractmethod
    def delete(self, namespace: str, key: str) -> None:
        ...


    @abstractmethod
    def delete_old(self, days: int = 14) -> None:
        ...


    @abstractmethod
    def clear(self) -> None:
        ...


    @abstractmethod
    def setup(self) -> None:
        ...


    @abstractmethod
    def close(self) -> None:
        ...


    def set_item(self, item: Item) -> Item:
        return self.set(
            item.namespace,
            item.key,
            item.value,
            item.value_type
        )


    def delete_item(self, item: Item) -> None:
        self.delete(item.namespace, item.key)


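# Cache backend that stores items in the relay database. The named queries
# used here ('get-cache-item', 'set-cache-item', ...) are expected to be
# provided by the database layer returned from get_database().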
@register_cache
class SqlCache(Cache):
    name: str = 'database'


    def __init__(self, app: Application):
        Cache.__init__(self, app)
        self._db = None


    def get(self, namespace: str, key: str) -> Item:
        params = {
            'namespace': namespace,
            'key': key
        }

        with self._db.session(False) as conn:
            with conn.run('get-cache-item', params) as cur:
                if not (row := cur.one()):
                    raise KeyError(f'{namespace}:{key}')

                row.pop('id', None)
                return Item.from_data(*tuple(row.values()))


    def get_keys(self, namespace: str) -> Iterator[str]:
        with self._db.session(False) as conn:
            for row in conn.run('get-cache-keys', {'namespace': namespace}):
                yield row['key']


    def get_namespaces(self) -> Iterator[str]:
        with self._db.session(False) as conn:
            for row in conn.run('get-cache-namespaces', None):
                yield row['namespace']


    def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
        params = {
            'namespace': namespace,
            'key': key,
            'value': serialize_value(value, value_type),
            'type': value_type,
            'date': datetime.now(tz = timezone.utc)
        }

        with self._db.session(True) as conn:
            with conn.run('set-cache-item', params) as cur:
                row = cur.one()
                row.pop('id', None)
                return Item.from_data(*tuple(row.values()))


    def delete(self, namespace: str, key: str) -> None:
        params = {
            'namespace': namespace,
            'key': key
        }

        with self._db.session(True) as conn:
            with conn.run('del-cache-item', params):
                pass


    def delete_old(self, days: int = 14) -> None:
        limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
        params = {"limit": limit.timestamp()}

        with self._db.session(True) as conn:
            with conn.execute("DELETE FROM cache WHERE updated < :limit", params):
                pass


    def clear(self) -> None:
        with self._db.session(True) as conn:
            with conn.execute("DELETE FROM cache"):
                pass


    def setup(self) -> None:
        if self._db and self._db.connected:
            return

        self._db = get_database(self.app.config)
        self._db.connect()

        with self._db.session(True) as conn:
            with conn.run(f'create-cache-table-{self._db.backend_type.value}', None):
                pass


    def close(self) -> None:
        if not self._db:
            return

        self._db.disconnect()
        self._db = None


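# Cache backend that stores items in Redis. Keys are named
# '{prefix}:{namespace}:{key}' and values are stored as
# '{value_type}:{timestamp}:{serialized value}'.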
@register_cache
class RedisCache(Cache):
    name: str = 'redis'


    def __init__(self, app: Application):
        Cache.__init__(self, app)
        self._rd = None


    @property
    def prefix(self) -> str:
        return self.app.config.rd_prefix


    def get_key_name(self, namespace: str, key: str) -> str:
        return f'{self.prefix}:{namespace}:{key}'


    def get(self, namespace: str, key: str) -> Item:
        key_name = self.get_key_name(namespace, key)

        if not (raw_value := self._rd.get(key_name)):
            raise KeyError(f'{namespace}:{key}')

        value_type, updated, value = raw_value.split(':', 2)
        return Item.from_data(
            namespace,
            key,
            value,
            value_type,
            datetime.fromtimestamp(float(updated), tz = timezone.utc)
        )


    def get_keys(self, namespace: str) -> Iterator[str]:
        for key in self._rd.scan_iter(self.get_key_name(namespace, '*')):
            *_, key_name = key.split(':', 2)
            yield key_name


    def get_namespaces(self) -> Iterator[str]:
        namespaces = []

        for key in self._rd.scan_iter(f'{self.prefix}:*'):
            _, namespace, _ = key.split(':', 2)

            if namespace not in namespaces:
                namespaces.append(namespace)
                yield namespace


    def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> None:
        date = datetime.now(tz = timezone.utc).timestamp()
        value = serialize_value(value, value_type)

        self._rd.set(
            self.get_key_name(namespace, key),
            f'{value_type}:{date}:{value}'
        )


    def delete(self, namespace: str, key: str) -> None:
        self._rd.delete(self.get_key_name(namespace, key))


    def delete_old(self, days: int = 14) -> None:
        limit = datetime.now(tz = timezone.utc) - timedelta(days = days)

        for full_key in self._rd.scan_iter(f'{self.prefix}:*'):
            _, namespace, key = full_key.split(':', 2)
            item = self.get(namespace, key)

            if item.updated < limit:
                self.delete_item(item)


    def clear(self) -> None:
        # redis DEL does not expand glob patterns, so remove matching keys one by one
        for key in self._rd.scan_iter(f'{self.prefix}:*'):
            self._rd.delete(key)


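    # If rd_host points at an existing path on disk it is treated as a unix
    # socket, otherwise as a TCP host/port pair.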
    def setup(self) -> None:
        if self._rd:
            return

        options = {
            'client_name': f'ActivityRelay_{self.app.config.domain}',
            'decode_responses': True,
            'username': self.app.config.rd_user,
            'password': self.app.config.rd_pass,
            'db': self.app.config.rd_database
        }

        if os.path.exists(self.app.config.rd_host):
            options['unix_socket_path'] = self.app.config.rd_host

        else:
            options['host'] = self.app.config.rd_host
            options['port'] = self.app.config.rd_port

        self._rd = Redis(**options)


    def close(self) -> None:
        if not self._rd:
            return

        self._rd.close()
        self._rd = None