add database and redis caching

Izalia Mae 2024-01-31 21:23:45 -05:00
parent f2baf7f9f9
commit 2d641ea183
11 changed files with 549 additions and 63 deletions

View file

@@ -35,6 +35,13 @@ SQL database backend to use. Valid values are `sqlite` or `postgres`.
database_type: sqlite
### Cache type
Cache backend to use. Valid values are `database` or `redis`.
cache_type: database
### Sqlite File Path
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
@@ -47,7 +54,7 @@ directory.
In order to use the Postgresql backend, the user and database need to be created first.
sudo -u postgres psql -c "CREATE USER activityrelay"
sudo -u postgres psql -c "CREATE USER activityrelay WITH PASSWORD SomeSecurePassword"
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
@@ -84,3 +91,47 @@ User to use when logging into the server.
Password for the specified user.
pass: null
## Redis
### Host
Hostname, IP address, or unix socket the server is hosted on.
host: localhost
### Port
Port number the server is listening on.
port: 6379
### Username
User to use when logging into the server.
user: null
### Password
Password for the specified user.
pass: null
### Database Number
Number of the database to use.
database: 0
### Prefix
Text to prefix every key with. It cannot contain a `:` character.
prefix: activityrelay
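
Putting these settings together, the Redis part of `relay.yaml` might look like the sketch below (all values shown are the defaults; `host` can also be a unix socket path):

cache_type: redis

redis:
  host: localhost
  port: 6379
  user: null
  pass: null
  database: 0
  prefix: activityrelay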

View file

@@ -7,12 +7,15 @@ listen: 0.0.0.0
# [integer] Port the relay will listen on
port: 8080
# [integer] Number of push workers to start (will get removed in a future update)
# [integer] Number of push workers to start
workers: 8
# [string] Database backend to use. Valid values: sqlite, postgres
database_type: sqlite
# [string] Cache backend to use. Valid values: database, redis
cache_type: database
# [string] Path to the sqlite database file if the sqlite backend is in use
sqlite_path: relay.sqlite3
@@ -33,3 +36,24 @@ postgres:
# [string] name of the database to use
name: activityrelay
# settings for the redis caching backend
redis:
# [string] hostname or unix socket to connect to
host: localhost
# [integer] port of the server
port: 6379
# [string] username to use when logging into the server
user: null
# [string] password for the server
pass: null
# [integer] database number to use
database: 0
# [string] prefix for keys
prefix: activityrelay

View file

@@ -12,6 +12,7 @@ from aputils.signer import Signer
from datetime import datetime, timedelta
from . import logger as logging
from .cache import get_cache
from .config import Config
from .database import get_database
from .http_client import HttpClient
@@ -19,8 +20,9 @@ from .misc import check_open_port
from .views import VIEWS
if typing.TYPE_CHECKING:
from tinysql import Database
from tinysql import Database, Row
from typing import Any
from .cache import Cache
from .misc import Message
@@ -38,6 +40,7 @@ class Application(web.Application):
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
self['cache'] = get_cache(self)
self['workers'] = []
self['last_worker'] = 0
@@ -48,6 +51,11 @@
self.router.add_view(path, view)
@property
def cache(self) -> Cache:
return self['cache']
@property
def client(self) -> HttpClient:
return self['client']
@@ -87,13 +95,13 @@
return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message) -> None:
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
if self.config.workers <= 0:
asyncio.ensure_future(self.client.post(inbox, message))
asyncio.ensure_future(self.client.post(inbox, message, instance))
return
worker = self['workers'][self['last_worker']]
worker.queue.put((inbox, message))
worker.queue.put((inbox, message, instance))
self['last_worker'] += 1
@@ -186,10 +194,10 @@ class PushWorker(threading.Thread):
while self.app['running']:
try:
inbox, message = self.queue.get(block=True, timeout=0.25)
inbox, message, instance = self.queue.get(block=True, timeout=0.25)
self.queue.task_done()
logging.verbose('New push from Thread-%i', threading.get_ident())
await self.client.post(inbox, message)
await self.client.post(inbox, message, instance)
except queue.Empty:
pass

relay/cache.py (new file, 280 additions)
View file

@@ -0,0 +1,280 @@
from __future__ import annotations
import json
import os
import typing
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from redis import Redis
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from typing import Any
from collections.abc import Callable, Iterator
from tinysql import Database
from .application import Application
# todo: implement more caching backends
BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'json': (json.dumps, json.loads),
'message': (lambda x: x.to_json(), Message.parse)
}
def get_cache(app: Application) -> Cache:
return BACKENDS[app.config.ca_type](app)
def register_cache(backend: type[Cache]) -> type[Cache]:
BACKENDS[backend.name] = backend
return backend
def serialize_value(value: Any, value_type: str = 'str') -> str:
if isinstance(value, str):
return value
return CONVERTERS[value_type][0](value)
def deserialize_value(value: str, value_type: str = 'str') -> Any:
return CONVERTERS[value_type][1](value)
@dataclass
class Item:
namespace: str
key: str
value: Any
value_type: str
updated: datetime
def __post_init__(self):
if isinstance(self.updated, str):
self.updated = datetime.fromisoformat(self.updated)
@classmethod
def from_data(cls: type[Item], *args) -> Item:
data = cls(*args)
data.value = deserialize_value(data.value, data.value_type)
return data
def older_than(self, hours: int) -> bool:
delta = datetime.now(tz = timezone.utc) - self.updated
return (delta.total_seconds()) > hours * 3600
def to_dict(self) -> dict[str, Any]:
return asdict(self)
class Cache(ABC):
name: str = 'null'
def __init__(self, app: Application):
self.app = app
self.setup()
@abstractmethod
def get(self, namespace: str, key: str) -> Item:
...
@abstractmethod
def get_keys(self, namespace: str) -> Iterator[str]:
...
@abstractmethod
def get_namespaces(self) -> Iterator[str]:
...
@abstractmethod
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
...
@abstractmethod
def delete(self, namespace: str, key: str) -> None:
...
@abstractmethod
def setup(self) -> None:
...
def set_item(self, item: Item) -> Item:
return self.set(
item.namespace,
item.key,
item.value,
item.value_type
)
def delete_item(self, item: Item) -> None:
self.delete(item.namespace, item.key)
@register_cache
class SqlCache(Cache):
name: str = 'database'
@property
def _db(self) -> Database:
return self.app.database
def get(self, namespace: str, key: str) -> Item:
params = {
'namespace': namespace,
'key': key
}
with self._db.connection() as conn:
with conn.exec_statement('get-cache-item', params) as cur:
if not (row := cur.one()):
raise KeyError(f'{namespace}:{key}')
return Item.from_data(*tuple(row.values()))
def get_keys(self, namespace: str) -> Iterator[str]:
with self._db.connection() as conn:
for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}):
yield row['key']
def get_namespaces(self) -> Iterator[str]:
with self._db.connection() as conn:
for row in conn.exec_statement('get-cache-namespaces', None):
yield row['namespace']
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
params = {
'namespace': namespace,
'key': key,
'value': serialize_value(value, value_type),
'type': value_type,
'date': datetime.now(tz = timezone.utc)
}
with self._db.connection() as conn:
for row in conn.exec_statement('set-cache-item', params):
return Item.from_data(*tuple(row.values()))
def delete(self, namespace: str, key: str) -> None:
params = {
'namespace': namespace,
'key': key
}
with self._db.connection() as conn:
with conn.exec_statement('del-cache-item', params):
pass
def setup(self) -> None:
with self._db.connection() as conn:
with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
pass
@register_cache
class RedisCache(Cache):
name: str = 'redis'
_rd: Redis
@property
def prefix(self) -> str:
return self.app.config.rd_prefix
def get_key_name(self, namespace: str, key: str) -> str:
return f'{self.prefix}:{namespace}:{key}'
def get(self, namespace: str, key: str) -> Item:
key_name = self.get_key_name(namespace, key)
if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}')
value_type, updated, value = raw_value.split(':', 2)
return Item.from_data(
namespace,
key,
value,
value_type,
datetime.fromtimestamp(float(updated), tz = timezone.utc)
)
def get_keys(self, namespace: str) -> Iterator[str]:
for key in self._rd.keys(self.get_key_name(namespace, '*')):
*_, key_name = key.split(':', 2)
yield key_name
def get_namespaces(self) -> Iterator[str]:
namespaces = []
for key in self._rd.keys(f'{self.prefix}:*'):
_, namespace, _ = key.split(':', 2)
if namespace not in namespaces:
namespaces.append(namespace)
yield namespace
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> None:
date = datetime.now(tz = timezone.utc).timestamp()
value = serialize_value(value, value_type)
self._rd.set(
self.get_key_name(namespace, key),
f'{value_type}:{date}:{value}'
)
def delete(self, namespace: str, key: str) -> None:
self._rd.delete(self.get_key_name(namespace, key))
def setup(self) -> None:
options = {
'client_name': f'ActivityRelay_{self.app.config.domain}',
'decode_responses': True,
'username': self.app.config.rd_user,
'password': self.app.config.rd_pass,
'db': self.app.config.rd_database
}
if os.path.exists(self.app.config.rd_host):
options['unix_socket_path'] = self.app.config.rd_host
else:
options['host'] = self.app.config.rd_host
options['port'] = self.app.config.rd_port
self._rd = Redis(**options)
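
To illustrate the new cache API end to end, here is a minimal usage sketch. It assumes `app` is an initialized Application (which exposes the backend selected by `cache_type` as `app.cache`); the `example` namespace and values are hypothetical:

# get_cache() picks the backend registered for config.ca_type
cache = app.cache

# values are serialized through the CONVERTERS table and stored with
# their type name, so they round-trip on read
cache.set('example', 'visits', 5, 'int')

item = cache.get('example', 'visits')  # raises KeyError if missing
assert item.value == 5

# Item records its update time; the http client uses older_than()
# to expire cached responses after 48 hours
if item.older_than(48):
    cache.delete('example', 'visits')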

View file

@@ -19,12 +19,21 @@ DEFAULTS: dict[str, Any] = {
'domain': 'relay.example.com',
'workers': len(os.sched_getaffinity(0)),
'db_type': 'sqlite',
'ca_type': 'database',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay'
'pg_name': 'activityrelay',
'rd_host': 'localhost',
'rd_port': 6379,
'rd_user': None,
'rd_pass': None,
'rd_database': 0,
'rd_prefix': 'activityrelay'
}
if IS_DOCKER:
@@ -40,13 +49,22 @@ class Config:
self.domain = None
self.workers = None
self.db_type = None
self.ca_type = None
self.sq_path = None
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
self.rd_host = None
self.rd_port = None
self.rd_user = None
self.rd_pass = None
self.rd_database = None
self.rd_prefix = None
if load:
try:
self.load()
@@ -92,6 +110,7 @@ class Config:
with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {})
rdcfg = config.get('redis', {})
if not config:
raise ValueError('Config is empty')
@@ -106,18 +125,25 @@
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self.set('workers', config.get('workers', DEFAULTS['workers']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
for key in DEFAULTS:
if not key.startswith('pg'):
continue
if key.startswith('pg'):
try:
self.set(key, pgcfg[key[3:]])
try:
self.set(key, pgcfg[key[3:]])
except KeyError:
continue
except KeyError:
continue
elif key.startswith('rd'):
try:
self.set(key, rdcfg[key[3:]])
except KeyError:
continue
def reset(self) -> None:
@@ -132,7 +158,9 @@
'listen': self.listen,
'port': self.port,
'domain': self.domain,
'workers': self.workers,
'database_type': self.db_type,
'cache_type': self.ca_type,
'sqlite_path': self.sq_path,
'postgres': {
'host': self.pg_host,
@@ -140,6 +168,14 @@
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
},
'redis': {
'host': self.rd_host,
'port': self.rd_port,
'user': self.rd_user,
'pass': self.rd_pass,
'database': self.rd_database,
'prefix': self.rd_prefix
}
}

View file

@@ -77,3 +77,64 @@ RETURNING *
-- name: del-domain-whitelist
DELETE FROM whitelist
WHERE domain = :domain
-- cache functions --
-- name: create-cache-table-sqlite
CREATE TABLE IF NOT EXISTS cache (
id INTEGER PRIMARY KEY UNIQUE,
namespace TEXT NOT NULL,
key TEXT NOT NULL,
"value" TEXT,
type TEXT DEFAULT 'str',
updated TIMESTAMP NOT NULL,
UNIQUE(namespace, key)
)
-- name: create-cache-table-postgres
CREATE TABLE IF NOT EXISTS cache (
id SERIAL PRIMARY KEY,
namespace TEXT NOT NULL,
key TEXT NOT NULL,
"value" TEXT,
type TEXT DEFAULT 'str',
updated TIMESTAMP NOT NULL,
UNIQUE(namespace, key)
)
-- name: get-cache-item
SELECT * FROM cache
WHERE namespace = :namespace and key = :key
-- name: get-cache-keys
SELECT key FROM cache
WHERE namespace = :namespace
-- name: get-cache-namespaces
SELECT DISTINCT namespace FROM cache
-- name: set-cache-item
INSERT INTO cache (namespace, key, value, type, updated)
VALUES (:namespace, :key, :value, :type, :date)
ON CONFLICT (namespace, key) DO
UPDATE SET value = :value, type = :type, updated = :date
RETURNING *
-- name: del-cache-item
DELETE FROM cache
WHERE namespace = :namespace and key = :key
-- name: del-cache-namespace
DELETE FROM cache
WHERE namespace = :namespace
-- name: del-cache-all
DELETE FROM cache
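
These named statements are executed through tinysql's prepared-statement interface, as `SqlCache` does above. A short sketch of the upsert behaviour of `set-cache-item`, assuming `db` is the Database returned by `get_database` and using illustrative parameter values:

from datetime import datetime, timezone

params = {
    'namespace': 'request',
    'key': 'https://example.com/actor',
    'value': '{"type": "Person"}',
    'type': 'str',
    'date': datetime.now(tz = timezone.utc)
}

with db.connection() as conn:
    # the first run inserts; a second run with the same namespace/key
    # hits UNIQUE(namespace, key) and updates the row in place, and
    # RETURNING * yields the resulting row either way
    for row in conn.exec_statement('set-cache-item', params):
        print(row['key'], row['updated'])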

View file

@@ -1,5 +1,6 @@
from __future__ import annotations
import json
import traceback
import typing
@@ -7,7 +8,6 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils.objects import Nodeinfo, WellKnownNodeinfo
from cachetools import LRUCache
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
@@ -16,7 +16,11 @@ from . import logger as logging
from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from aputils import Signer
from tinysql import Row
from typing import Any
from .application import Application
from .cache import Cache
HEADERS = {
@@ -26,8 +30,7 @@
class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
self.cache = LRUCache(cache_size)
def __init__(self, limit: int = 100, timeout: int = 10):
self.limit = limit
self.timeout = timeout
self._conn = None
@@ -43,6 +46,21 @@
await self.close()
@property
def app(self) -> Application:
return get_app()
@property
def cache(self) -> Cache:
return self.app.cache
@property
def signer(self) -> Signer:
return self.app.signer
async def open(self) -> None:
if self._session:
return
@@ -74,8 +92,8 @@
async def get(self, # pylint: disable=too-many-branches
url: str,
sign_headers: bool = False,
loads: callable | None = None,
force: bool = False) -> Message | dict | None:
loads: callable = json.loads,
force: bool = False) -> dict | None:
await self.open()
@@ -85,13 +103,20 @@
except ValueError:
pass
if not force and url in self.cache:
return self.cache[url]
if not force:
try:
item = self.cache.get('request', url)
if not item.older_than(48):
return loads(item.value)
except KeyError:
logging.verbose('Failed to fetch cached data for url: %s', url)
headers = {}
if sign_headers:
get_app().signer.sign_headers('GET', url, algorithm = 'original')
headers = self.signer.sign_headers('GET', url, algorithm = 'original')
try:
logging.debug('Fetching resource: %s', url)
@@ -101,32 +126,22 @@
if resp.status == 202:
return None
if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(await resp.read())
return None
data = await resp.read()
if loads:
message = await resp.json(loads=loads)
if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(await resp.read())
return None
elif resp.content_type == MIMETYPES['activity']:
message = await resp.json(loads = Message.parse)
message = loads(data)
self.cache.set('request', url, data.decode('utf-8'), 'str')
logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
elif resp.content_type == MIMETYPES['json']:
message = await resp.json()
else:
logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type)
logging.debug('Response: %s', await resp.read())
return None
logging.debug('%s >> resp %s', url, message.to_json(4))
self.cache[url] = message
return message
return message
except JSONDecodeError:
logging.verbose('Failed to parse JSON')
return None
except ClientSSLError:
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
@@ -140,13 +155,9 @@
return None
async def post(self, url: str, message: Message) -> None:
async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
await self.open()
# todo: cache inboxes to avoid opening a db connection
with get_app().database.connection() as conn:
instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}:
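
From a caller's perspective, the cached fetch path now behaves as sketched below (mirroring the ActorView change further down; the URL is hypothetical and `client` is the shared HttpClient from `app.client`):

# the first call hits the network, stores the raw body under the
# 'request' cache namespace, and returns loads(body)
actor = await client.get(
    'https://example.com/actor',
    sign_headers = True,
    loads = Message.parse
)

# repeat calls within 48 hours are served from the cache;
# force = True bypasses the cache and refetches
actor = await client.get('https://example.com/actor', force = True)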

View file

@@ -17,6 +17,7 @@ if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator
from typing import Any
from .application import Application
from .cache import Cache
from .config import Config
from .database import Database
from .http_client import HttpClient
@@ -268,6 +269,11 @@ class View(AbstractView):
return self.request.app
@property
def cache(self) -> Cache:
return self.app.cache
@property
def client(self) -> HttpClient:
return self.app.client

View file

@@ -3,8 +3,6 @@ from __future__ import annotations
import tinysql
import typing
from cachetools import LRUCache
from . import logger as logging
from .misc import Message
@@ -12,9 +10,6 @@ if typing.TYPE_CHECKING:
from .views import ActorView
cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0
@@ -29,31 +24,39 @@ def person_check(actor: str, software: str) -> bool:
async def handle_relay(view: ActorView) -> None:
if view.message.object_id in cache:
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already relayed %s', view.message.object_id)
return
except KeyError:
pass
message = Message.new_announce(view.config.domain, view.message.object_id)
cache[view.message.object_id] = message.id
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> relay: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
view.app.push_message(inbox, message, view.instance)
async def handle_forward(view: ActorView) -> None:
if view.message.id in cache:
logging.verbose('already forwarded %s', view.message.id)
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already forwarded %s', view.message.object_id)
return
except KeyError:
pass
message = Message.new_announce(view.config.domain, view.message)
cache[view.message.id] = message.id
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
logging.debug('>> forward: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message)
view.app.push_message(inbox, message, view.instance)
async def handle_follow(view: ActorView) -> None:

View file

@@ -19,6 +19,7 @@ if typing.TYPE_CHECKING:
from aiohttp.web import Request
from aputils.signer import Signer
from collections.abc import Callable
from tinysql import Row
VIEWS = []
@@ -98,7 +99,7 @@ class ActorView(View):
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: dict[str, str] = None
self.instance: Row = None
self.signer: Signer = None
@@ -168,7 +169,11 @@
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
self.actor = await self.client.get(
self.signature.keyid,
sign_headers = True,
loads = Message.parse
)
if not self.actor:
# ld signatures aren't handled atm, so just ignore it

View file

@@ -1,8 +1,9 @@
aiohttp>=3.9.1
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0
click>=8.1.2
hiredis==2.3.2
pyyaml>=6.0
redis==5.0.1
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.3.tar.gz
importlib_resources==6.1.1;python_version<'3.9'