diff --git a/docs/configuration.md b/docs/configuration.md
index bef0799..36f0463 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -35,6 +35,13 @@ SQL database backend to use. Valid values are `sqlite` or `postgres`.
database_type: sqlite
+### Cache type
+
+Cache backend to use. Valid values are `database` or `redis`
+
+ cache_type: database
+
+
### Sqlite File Path
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
@@ -47,7 +54,7 @@ directory.
In order to use the Postgresql backend, the user and database need to be created first.
- sudo -u postgres psql -c "CREATE USER activityrelay"
+ sudo -u postgres psql -c "CREATE USER activityrelay WITH PASSWORD SomeSecurePassword"
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
@@ -84,3 +91,47 @@ User to use when logging into the server.
Password for the specified user.
pass: null
+
+
+## Redis
+
+### Host
+
+Hostname, IP address, or unix socket the server is hosted on.
+
+ host: /var/run/postgresql
+
+
+### Port
+
+Port number the server is listening on.
+
+ port: 5432
+
+
+### Username
+
+User to use when logging into the server.
+
+ user: null
+
+
+### Password
+
+Password for the specified user.
+
+ pass: null
+
+
+### Database Number
+
+Number of the database to use.
+
+ database: 0
+
+
+### Prefix
+
+Text to prefix every key with. It cannot contain a `:` character.
+
+ prefix: activityrelay
diff --git a/relay.yaml.example b/relay.yaml.example
index 90b9e8f..823feaa 100644
--- a/relay.yaml.example
+++ b/relay.yaml.example
@@ -7,12 +7,15 @@ listen: 0.0.0.0
# [integer] Port the relay will listen on
port: 8080
-# [integer] Number of push workers to start (will get removed in a future update)
+# [integer] Number of push workers to start
workers: 8
# [string] Database backend to use. Valid values: sqlite, postgres
database_type: sqlite
+# [string] Cache backend to use. Valid values: database, redis
+cache_type: database
+
# [string] Path to the sqlite database file if the sqlite backend is in use
sqlite_path: relay.sqlite3
@@ -33,3 +36,24 @@ postgres:
# [string] name of the database to use
name: activityrelay
+
+# settings for the redis caching backend
+redis:
+
+ # [string] hostname or unix socket to connect to
+ host: localhost
+
+ # [integer] port of the server
+ port: 6379
+
+ # [string] username to use when logging into the server
+ user: null
+
+ # [string] password for the server
+ pass: null
+
+ # [integer] database number to use
+ database: 0
+
+ # [string] prefix for keys
+ prefix: activityrelay
diff --git a/relay/application.py b/relay/application.py
index 9440098..92bccb5 100644
--- a/relay/application.py
+++ b/relay/application.py
@@ -12,6 +12,7 @@ from aputils.signer import Signer
from datetime import datetime, timedelta
from . import logger as logging
+from .cache import get_cache
from .config import Config
from .database import get_database
from .http_client import HttpClient
@@ -19,8 +20,9 @@ from .misc import check_open_port
from .views import VIEWS
if typing.TYPE_CHECKING:
- from tinysql import Database
+ from tinysql import Database, Row
from typing import Any
+ from .cache import Cache
from .misc import Message
@@ -38,6 +40,7 @@ class Application(web.Application):
self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config)
self['client'] = HttpClient()
+ self['cache'] = get_cache(self)
self['workers'] = []
self['last_worker'] = 0
@@ -48,6 +51,11 @@ class Application(web.Application):
self.router.add_view(path, view)
+ @property
+ def cache(self) -> Cache:
+ return self['cache']
+
+
@property
def client(self) -> HttpClient:
return self['client']
@@ -87,13 +95,13 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds)
- def push_message(self, inbox: str, message: Message) -> None:
+ def push_message(self, inbox: str, message: Message, instance: Row) -> None:
if self.config.workers <= 0:
- asyncio.ensure_future(self.client.post(inbox, message))
+ asyncio.ensure_future(self.client.post(inbox, message, instance))
return
worker = self['workers'][self['last_worker']]
- worker.queue.put((inbox, message))
+ worker.queue.put((inbox, message, instance))
self['last_worker'] += 1
@@ -186,10 +194,10 @@ class PushWorker(threading.Thread):
while self.app['running']:
try:
- inbox, message = self.queue.get(block=True, timeout=0.25)
+ inbox, message, instance = self.queue.get(block=True, timeout=0.25)
self.queue.task_done()
logging.verbose('New push from Thread-%i', threading.get_ident())
- await self.client.post(inbox, message)
+ await self.client.post(inbox, message, instance)
except queue.Empty:
pass
diff --git a/relay/cache.py b/relay/cache.py
new file mode 100644
index 0000000..9c28108
--- /dev/null
+++ b/relay/cache.py
@@ -0,0 +1,288 @@
+from __future__ import annotations
+
+import json
+import os
+import typing
+
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass
+from datetime import datetime, timezone
+from redis import Redis
+
+from .database import get_database
+from .misc import Message, boolean
+
+if typing.TYPE_CHECKING:
+ from typing import Any
+ from collections.abc import Callable, Iterator
+ from tinysql import Database
+ from .application import Application
+
+
+# todo: implement more caching backends
+
+
+BACKENDS: dict[str, Cache] = {}
+CONVERTERS: dict[str, tuple[Callable, Callable]] = {
+ 'str': (str, str),
+ 'int': (str, int),
+ 'bool': (str, boolean),
+ 'json': (json.dumps, json.loads),
+ 'message': (lambda x: x.to_json(), Message.parse)
+}
+
+
+def get_cache(app: Application) -> Cache:
+ return BACKENDS[app.config.ca_type](app)
+
+
+def register_cache(backend: type[Cache]) -> type[Cache]:
+ BACKENDS[backend.name] = backend
+ return backend
+
+
+def serialize_value(value: Any, value_type: str = 'str') -> str:
+ if isinstance(value, str):
+ return value
+
+ return CONVERTERS[value_type][0](value)
+
+
+def deserialize_value(value: str, value_type: str = 'str') -> Any:
+ return CONVERTERS[value_type][1](value)
+
+
+@dataclass
+class Item:
+ namespace: str
+ key: str
+ value: Any
+ value_type: str
+ updated: datetime
+
+
+ def __post_init__(self):
+ if isinstance(self.updated, str):
+ self.updated = datetime.fromisoformat(self.updated)
+
+
+ @classmethod
+ def from_data(cls: type[Item], *args) -> Item:
+ data = cls(*args)
+ data.value = deserialize_value(data.value, data.value_type)
+
+ if not isinstance(data.updated, datetime):
+ data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc)
+
+ return data
+
+
+ def older_than(self, hours: int) -> bool:
+ delta = datetime.now(tz = timezone.utc) - self.updated
+ return (delta.total_seconds()) > hours * 3600
+
+
+ def to_dict(self) -> dict[str, Any]:
+ return asdict(self)
+
+
+class Cache(ABC):
+ name: str = 'null'
+
+
+ def __init__(self, app: Application):
+ self.app = app
+ self.setup()
+
+ @abstractmethod
+ def get(self, namespace: str, key: str) -> Item:
+ ...
+
+
+ @abstractmethod
+ def get_keys(self, namespace: str) -> Iterator[str]:
+ ...
+
+
+ @abstractmethod
+ def get_namespaces(self) -> Iterator[str]:
+ ...
+
+
+ @abstractmethod
+ def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
+ ...
+
+
+ @abstractmethod
+ def delete(self, namespace: str, key: str) -> None:
+ ...
+
+
+ @abstractmethod
+ def setup(self) -> None:
+ ...
+
+
+ def set_item(self, item: Item) -> Item:
+ return self.set(
+ item.namespace,
+ item.key,
+ item.value,
+ item.type
+ )
+
+
+ def delete_item(self, item: Item) -> None:
+ self.delete(item.namespace, item.key)
+
+
+@register_cache
+class SqlCache(Cache):
+ name: str = 'database'
+
+
+ def __init__(self, app: Application):
+ self._db = get_database(app.config)
+ Cache.__init__(self, app)
+
+
+ def get(self, namespace: str, key: str) -> Item:
+ params = {
+ 'namespace': namespace,
+ 'key': key
+ }
+
+ with self._db.connection() as conn:
+ with conn.exec_statement('get-cache-item', params) as cur:
+ if not (row := cur.one()):
+ raise KeyError(f'{namespace}:{key}')
+
+ row.pop('id', None)
+ return Item.from_data(*tuple(row.values()))
+
+
+ def get_keys(self, namespace: str) -> Iterator[str]:
+ with self._db.connection() as conn:
+ for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}):
+ yield row['key']
+
+
+ def get_namespaces(self) -> Iterator[str]:
+ with self._db.connection() as conn:
+ for row in conn.exec_statement('get-cache-namespaces', None):
+ yield row['namespace']
+
+
+ def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
+ params = {
+ 'namespace': namespace,
+ 'key': key,
+ 'value': serialize_value(value, value_type),
+ 'type': value_type,
+ 'date': datetime.now(tz = timezone.utc)
+ }
+
+ with self._db.connection() as conn:
+ with conn.exec_statement('set-cache-item', params) as conn:
+ row = conn.one()
+ row.pop('id', None)
+ return Item.from_data(*tuple(row.values()))
+
+
+ def delete(self, namespace: str, key: str) -> None:
+ params = {
+ 'namespace': namespace,
+ 'key': key
+ }
+
+ with self._db.connection() as conn:
+ with conn.exec_statement('del-cache-item', params):
+ pass
+
+
+ def setup(self) -> None:
+ with self._db.connection() as conn:
+ with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
+ pass
+
+
+@register_cache
+class RedisCache(Cache):
+ name: str = 'redis'
+ _rd: Redis
+
+
+ @property
+ def prefix(self) -> str:
+ return self.app.config.rd_prefix
+
+
+ def get_key_name(self, namespace: str, key: str) -> str:
+ return f'{self.prefix}:{namespace}:{key}'
+
+
+ def get(self, namespace: str, key: str) -> Item:
+ key_name = self.get_key_name(namespace, key)
+
+ if not (raw_value := self._rd.get(key_name)):
+ raise KeyError(f'{namespace}:{key}')
+
+ value_type, updated, value = raw_value.split(':', 2)
+ return Item.from_data(
+ namespace,
+ key,
+ value,
+ value_type,
+ datetime.fromtimestamp(float(updated), tz = timezone.utc)
+ )
+
+
+ def get_keys(self, namespace: str) -> Iterator[str]:
+ for key in self._rd.keys(self.get_key_name(namespace, '*')):
+ *_, key_name = key.split(':', 2)
+ yield key_name
+
+
+ def get_namespaces(self) -> Iterator[str]:
+ namespaces = []
+
+ for key in self._rd.keys(f'{self.prefix}:*'):
+ _, namespace, _ = key.split(':', 2)
+
+ if namespace not in namespaces:
+ namespaces.append(namespace)
+ yield namespace
+
+
+ def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> None:
+ date = datetime.now(tz = timezone.utc).timestamp()
+ value = serialize_value(value, value_type)
+
+ self._rd.set(
+ self.get_key_name(namespace, key),
+ f'{value_type}:{date}:{value}'
+ )
+
+
+ def delete(self, namespace: str, key: str) -> None:
+ self._rd.delete(self.get_key_name(namespace, key))
+
+
+ def setup(self) -> None:
+ options = {
+ 'client_name': f'ActivityRelay_{self.app.config.domain}',
+ 'decode_responses': True,
+ 'username': self.app.config.rd_user,
+ 'password': self.app.config.rd_pass,
+ 'db': self.app.config.rd_database
+ }
+
+ if os.path.exists(self.app.config.rd_host):
+ options['unix_socket_path'] = self.app.config.rd_host
+
+ else:
+ options['host'] = self.app.config.rd_host
+ options['port'] = self.app.config.rd_port
+
+ self._rd = Redis(**options)
diff --git a/relay/config.py b/relay/config.py
index 58c2eb8..eff6215 100644
--- a/relay/config.py
+++ b/relay/config.py
@@ -19,12 +19,21 @@ DEFAULTS: dict[str, Any] = {
'domain': 'relay.example.com',
'workers': len(os.sched_getaffinity(0)),
'db_type': 'sqlite',
+ 'ca_type': 'database',
'sq_path': 'relay.sqlite3',
+
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
- 'pg_name': 'activityrelay'
+ 'pg_name': 'activityrelay',
+
+ 'rd_host': 'localhost',
+ 'rd_port': 6379,
+ 'rd_user': None,
+ 'rd_pass': None,
+ 'rd_database': 0,
+ 'rd_prefix': 'activityrelay'
}
if IS_DOCKER:
@@ -40,13 +49,22 @@ class Config:
self.domain = None
self.workers = None
self.db_type = None
+ self.ca_type = None
self.sq_path = None
+
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
+ self.rd_host = None
+ self.rd_port = None
+ self.rd_user = None
+ self.rd_pass = None
+ self.rd_database = None
+ self.rd_prefix = None
+
if load:
try:
self.load()
@@ -92,6 +110,7 @@ class Config:
with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {})
+ rdcfg = config.get('redis', {})
if not config:
raise ValueError('Config is empty')
@@ -106,18 +125,25 @@ class Config:
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
+ self.set('workers', config.get('workers', DEFAULTS['workers']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
+ self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
for key in DEFAULTS:
- if not key.startswith('pg'):
- continue
+ if key.startswith('pg'):
+ try:
+ self.set(key, pgcfg[key[3:]])
- try:
- self.set(key, pgcfg[key[3:]])
+ except KeyError:
+ continue
- except KeyError:
- continue
+ elif key.startswith('rd'):
+ try:
+ self.set(key, rdcfg[key[3:]])
+
+ except KeyError:
+ continue
def reset(self) -> None:
@@ -132,7 +158,9 @@ class Config:
'listen': self.listen,
'port': self.port,
'domain': self.domain,
+ 'workers': self.workers,
'database_type': self.db_type,
+ 'cache_type': self.ca_type,
'sqlite_path': self.sq_path,
'postgres': {
'host': self.pg_host,
@@ -140,6 +168,14 @@ class Config:
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
+ },
+ 'redis': {
+ 'host': self.rd_host,
+ 'port': self.rd_port,
+ 'user': self.rd_user,
+ 'pass': self.rd_pass,
+ 'database': self.rd_database,
+ 'refix': self.rd_prefix
}
}
diff --git a/relay/data/statements.sql b/relay/data/statements.sql
index a262feb..9bcea41 100644
--- a/relay/data/statements.sql
+++ b/relay/data/statements.sql
@@ -77,3 +77,64 @@ RETURNING *
-- name: del-domain-whitelist
DELETE FROM whitelist
WHERE domain = :domain
+
+
+-- cache functions --
+
+-- name: create-cache-table-sqlite
+CREATE TABLE IF NOT EXISTS cache (
+ id INTEGER PRIMARY KEY UNIQUE,
+ namespace TEXT NOT NULL,
+ key TEXT NOT NULL,
+ "value" TEXT,
+ type TEXT DEFAULT 'str',
+ updated TIMESTAMP NOT NULL,
+ UNIQUE(namespace, key)
+)
+
+-- name: create-cache-table-postgres
+CREATE TABLE IF NOT EXISTS cache (
+ id SERIAL PRIMARY KEY,
+ namespace TEXT NOT NULL,
+ key TEXT NOT NULL,
+ "value" TEXT,
+ type TEXT DEFAULT 'str',
+ updated TIMESTAMP NOT NULL,
+ UNIQUE(namespace, key)
+)
+
+
+-- name: get-cache-item
+SELECT * FROM cache
+WHERE namespace = :namespace and key = :key
+
+
+-- name: get-cache-keys
+SELECT key FROM cache
+WHERE namespace = :namespace
+
+
+-- name: get-cache-namespaces
+SELECT DISTINCT namespace FROM cache
+
+
+-- name: set-cache-item
+INSERT INTO cache (namespace, key, value, type, updated)
+VALUES (:namespace, :key, :value, :type, :date)
+ON CONFLICT (namespace, key) DO
+UPDATE SET value = :value, type = :type, updated = :date
+RETURNING *
+
+
+-- name: del-cache-item
+DELETE FROM cache
+WHERE namespace = :namespace and key = :key
+
+
+-- name: del-cache-namespace
+DELETE FROM cache
+WHERE namespace = :namespace
+
+
+-- name: del-cache-all
+DELETE FROM cache
diff --git a/relay/http_client.py b/relay/http_client.py
index 5040059..6bed4ae 100644
--- a/relay/http_client.py
+++ b/relay/http_client.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import json
import traceback
import typing
@@ -7,7 +8,6 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils.objects import Nodeinfo, WellKnownNodeinfo
-from cachetools import LRUCache
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
@@ -16,7 +16,11 @@ from . import logger as logging
from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
+ from aputils import Signer
+ from tinysql import Row
from typing import Any
+ from .application import Application
+ from .cache import Cache
HEADERS = {
@@ -26,8 +30,7 @@ HEADERS = {
class HttpClient:
- def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
- self.cache = LRUCache(cache_size)
+ def __init__(self, limit: int = 100, timeout: int = 10):
self.limit = limit
self.timeout = timeout
self._conn = None
@@ -43,6 +46,21 @@ class HttpClient:
await self.close()
+ @property
+ def app(self) -> Application:
+ return get_app()
+
+
+ @property
+ def cache(self) -> Cache:
+ return self.app.cache
+
+
+ @property
+ def signer(self) -> Signer:
+ return self.app.signer
+
+
async def open(self) -> None:
if self._session:
return
@@ -74,8 +92,8 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches
url: str,
sign_headers: bool = False,
- loads: callable | None = None,
- force: bool = False) -> Message | dict | None:
+ loads: callable = json.loads,
+ force: bool = False) -> dict | None:
await self.open()
@@ -85,13 +103,20 @@ class HttpClient:
except ValueError:
pass
- if not force and url in self.cache:
- return self.cache[url]
+ if not force:
+ try:
+ item = self.cache.get('request', url)
+
+ if not item.older_than(48):
+ return loads(item.value)
+
+ except KeyError:
+ logging.verbose('Failed to fetch cached data for url: %s', url)
headers = {}
if sign_headers:
- get_app().signer.sign_headers('GET', url, algorithm = 'original')
+ self.signer.sign_headers('GET', url, algorithm = 'original')
try:
logging.debug('Fetching resource: %s', url)
@@ -101,32 +126,22 @@ class HttpClient:
if resp.status == 202:
return None
- if resp.status != 200:
- logging.verbose('Received error when requesting %s: %i', url, resp.status)
- logging.debug(await resp.read())
- return None
+ data = await resp.read()
- if loads:
- message = await resp.json(loads=loads)
+ if resp.status != 200:
+ logging.verbose('Received error when requesting %s: %i', url, resp.status)
+ logging.debug(await resp.read())
+ return None
- elif resp.content_type == MIMETYPES['activity']:
- message = await resp.json(loads = Message.parse)
+ message = loads(data)
+ self.cache.set('request', url, data.decode('utf-8'), 'str')
+ logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
- elif resp.content_type == MIMETYPES['json']:
- message = await resp.json()
-
- else:
- logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type)
- logging.debug('Response: %s', await resp.read())
- return None
-
- logging.debug('%s >> resp %s', url, message.to_json(4))
-
- self.cache[url] = message
- return message
+ return message
except JSONDecodeError:
logging.verbose('Failed to parse JSON')
+ return None
except ClientSSLError:
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
@@ -140,13 +155,9 @@ class HttpClient:
return None
- async def post(self, url: str, message: Message) -> None:
+ async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
await self.open()
- # todo: cache inboxes to avoid opening a db connection
- with get_app().database.connection() as conn:
- instance = conn.get_inbox(url)
-
## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}:
diff --git a/relay/manage.py b/relay/manage.py
index 69930dc..df5b4cb 100644
--- a/relay/manage.py
+++ b/relay/manage.py
@@ -51,6 +51,13 @@ SOFTWARE = (
)
+def check_alphanumeric(text: str) -> str:
+ if not text.isalnum():
+ raise click.BadParameter('String not alphanumeric')
+
+ return text
+
+
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config')
@click.version_option(version=__version__, prog_name='ActivityRelay')
@@ -113,8 +120,9 @@ def cli_setup(ctx: click.Context) -> None:
)
ctx.obj.config.pg_host = click.prompt(
- 'What IP address or hostname does the server listen on?',
- default = ctx.obj.config.pg_host
+ 'What IP address, hostname, or unix socket does the server listen on?',
+ default = ctx.obj.config.pg_host,
+ type = int
)
ctx.obj.config.pg_port = click.prompt(
@@ -135,6 +143,48 @@ def cli_setup(ctx: click.Context) -> None:
default = ctx.obj.config.pg_pass or ""
) or None
+ ctx.obj.config.ca_type = click.prompt(
+ 'Which caching backend?',
+ default = ctx.obj.config.ca_type,
+ type = click.Choice(['database', 'redis'], case_sensitive = False)
+ )
+
+ if ctx.obj.config.ca_type == 'redis':
+ ctx.obj.config.rd_host = click.prompt(
+ 'What IP address, hostname, or unix socket does the server listen on?',
+ default = ctx.obj.config.rd_host
+ )
+
+ ctx.obj.config.rd_port = click.prompt(
+ 'What port does the server listen on?',
+ default = ctx.obj.config.rd_port,
+ type = int
+ )
+
+ ctx.obj.config.rd_user = click.prompt(
+ 'Which user will authenticate with the server',
+ default = ctx.obj.config.rd_user
+ )
+
+ ctx.obj.config.rd_pass = click.prompt(
+ 'User password',
+ hide_input = True,
+ show_default = False,
+ default = ctx.obj.config.rd_pass or ""
+ ) or None
+
+ ctx.obj.config.rd_database = click.prompt(
+ 'Which database number to use?',
+ default = ctx.obj.config.rd_database,
+ type = int
+ )
+
+ ctx.obj.config.rd_prefix = click.prompt(
+ 'What text should each cache key be prefixed with?',
+ default = ctx.obj.config.rd_database,
+ type = check_alphanumeric
+ )
+
ctx.obj.config.save()
config = {
@@ -364,7 +414,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
actor = actor
)
- asyncio.run(http.post(inbox, message))
+ asyncio.run(http.post(inbox, message, None, inbox_data))
click.echo(f'Sent follow message to actor: {actor}')
@@ -405,7 +455,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
}
)
- asyncio.run(http.post(inbox, message))
+ asyncio.run(http.post(inbox, message, inbox_data))
click.echo(f'Sent unfollow message to: {actor}')
diff --git a/relay/misc.py b/relay/misc.py
index d7e96d8..2a72f22 100644
--- a/relay/misc.py
+++ b/relay/misc.py
@@ -15,8 +15,10 @@ from uuid import uuid4
if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator
- from typing import Any
+ from tinysql import Connection
+ from typing import Any, Awaitable
from .application import Application
+ from .cache import Cache
from .config import Config
from .database import Database
from .http_client import HttpClient
@@ -240,7 +242,15 @@ class View(AbstractView):
if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
- return handler(self.request, **self.request.match_info).__await__()
+ return self._run_handler(handler).__await__()
+
+
+ async def _run_handler(self, handler: Awaitable) -> Response:
+ with self.database.config.connection_class(self.database) as conn:
+ # todo: remove on next tinysql release
+ conn.open()
+
+ return await handler(self.request, conn, **self.request.match_info)
@cached_property
@@ -268,6 +278,11 @@ class View(AbstractView):
return self.request.app
+ @property
+ def cache(self) -> Cache:
+ return self.app.cache
+
+
@property
def client(self) -> HttpClient:
return self.app.client
diff --git a/relay/processors.py b/relay/processors.py
index 8cf7722..d9780d1 100644
--- a/relay/processors.py
+++ b/relay/processors.py
@@ -1,20 +1,15 @@
from __future__ import annotations
-import tinysql
import typing
-from cachetools import LRUCache
-
from . import logger as logging
+from .database.connection import Connection
from .misc import Message
if typing.TYPE_CHECKING:
from .views import ActorView
-cache = LRUCache(1024)
-
-
def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0
@@ -28,83 +23,87 @@ def person_check(actor: str, software: str) -> bool:
return False
-async def handle_relay(view: ActorView) -> None:
- if view.message.object_id in cache:
+async def handle_relay(view: ActorView, conn: Connection) -> None:
+ try:
+ view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already relayed %s', view.message.object_id)
return
+ except KeyError:
+ pass
+
message = Message.new_announce(view.config.domain, view.message.object_id)
- cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message)
- with view.database.connection() as conn:
- for inbox in conn.distill_inboxes(view.message):
- view.app.push_message(inbox, message)
+ for inbox in conn.distill_inboxes(view.message):
+ view.app.push_message(inbox, message, view.instance)
+
+ view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
-async def handle_forward(view: ActorView) -> None:
- if view.message.id in cache:
- logging.verbose('already forwarded %s', view.message.id)
+async def handle_forward(view: ActorView, conn: Connection) -> None:
+ try:
+ view.cache.get('handle-relay', view.message.object_id)
+ logging.verbose('already forwarded %s', view.message.object_id)
return
+ except KeyError:
+ pass
+
message = Message.new_announce(view.config.domain, view.message)
- cache[view.message.id] = message.id
logging.debug('>> forward: %s', message)
- with view.database.connection() as conn:
- for inbox in conn.distill_inboxes(view.message):
- view.app.push_message(inbox, message)
+ for inbox in conn.distill_inboxes(view.message):
+ view.app.push_message(inbox, message, view.instance)
+
+ view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
-async def handle_follow(view: ActorView) -> None:
+async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
- with view.database.connection() as conn:
- # reject if software used by actor is banned
- if conn.get_software_ban(software):
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_response(
- host = view.config.domain,
- actor = view.actor.id,
- followid = view.message.id,
- accept = False
- )
+ # reject if software used by actor is banned
+ if conn.get_software_ban(software):
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = False
)
+ )
- logging.verbose(
- 'Rejected follow from actor for using specific software: actor=%s, software=%s',
- view.actor.id,
- software
+ logging.verbose(
+ 'Rejected follow from actor for using specific software: actor=%s, software=%s',
+ view.actor.id,
+ software
+ )
+
+ return
+
+ ## reject if the actor is not an instance actor
+ if person_check(view.actor, software):
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = False
)
+ )
- return
+ logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
+ return
- ## reject if the actor is not an instance actor
- if person_check(view.actor, software):
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_response(
- host = view.config.domain,
- actor = view.actor.id,
- followid = view.message.id,
- accept = False
- )
- )
+ if conn.get_inbox(view.actor.shared_inbox):
+ view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
- logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
- return
-
- if conn.get_inbox(view.actor.shared_inbox):
- data = {'followid': view.message.id}
- statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
-
- with conn.query(statement):
- pass
-
- else:
- conn.put_inbox(
+ else:
+ with conn.transaction():
+ view.instance = conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
view.actor.id,
@@ -112,35 +111,37 @@ async def handle_follow(view: ActorView) -> None:
software
)
+ view.app.push_message(
+ view.actor.shared_inbox,
+ Message.new_response(
+ host = view.config.domain,
+ actor = view.actor.id,
+ followid = view.message.id,
+ accept = True
+ ),
+ view.instance
+ )
+
+ # Are Akkoma and Pleroma the only two that expect a follow back?
+ # Ignoring only Mastodon for now
+ if software != 'mastodon':
view.app.push_message(
view.actor.shared_inbox,
- Message.new_response(
+ Message.new_follow(
host = view.config.domain,
- actor = view.actor.id,
- followid = view.message.id,
- accept = True
- )
+ actor = view.actor.id
+ ),
+ view.instance
)
- # Are Akkoma and Pleroma the only two that expect a follow back?
- # Ignoring only Mastodon for now
- if software != 'mastodon':
- view.app.push_message(
- view.actor.shared_inbox,
- Message.new_follow(
- host = view.config.domain,
- actor = view.actor.id
- )
- )
-
-async def handle_undo(view: ActorView) -> None:
+async def handle_undo(view: ActorView, conn: Connection) -> None:
## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
- await handle_forward(view)
+ await handle_forward(view, conn)
return
- with view.database.connection() as conn:
+ with conn.transaction():
if not conn.del_inbox(view.actor.id):
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
@@ -154,7 +155,8 @@ async def handle_undo(view: ActorView) -> None:
host = view.config.domain,
actor = view.actor.id,
follow = view.message
- )
+ ),
+ view.instance
)
@@ -168,7 +170,7 @@ processors = {
}
-async def run_processor(view: ActorView) -> None:
+async def run_processor(view: ActorView, conn: Connection) -> None:
if view.message.type not in processors:
logging.verbose(
'Message type "%s" from actor cannot be handled: %s',
@@ -179,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
return
if view.instance:
- if not view.instance['software']:
- if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
- with view.database.connection() as conn:
+ with conn.transaction():
+ if not view.instance['software']:
+ if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
view.instance = conn.update_inbox(
view.instance['inbox'],
software = nodeinfo.sw_name
)
- if not view.instance['actor']:
- with view.database.connection() as conn:
+ if not view.instance['actor']:
view.instance = conn.update_inbox(
view.instance['inbox'],
actor = view.actor.id
)
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
- await processors[view.message.type](view)
+ await processors[view.message.type](view, conn)
diff --git a/relay/views.py b/relay/views.py
index 8b84c02..cb648a2 100644
--- a/relay/views.py
+++ b/relay/views.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import asyncio
import subprocess
import traceback
import typing
@@ -12,6 +11,7 @@ from pathlib import Path
from . import __version__
from . import logger as logging
+from .database.connection import Connection
from .misc import Message, Response, View
from .processors import run_processor
@@ -19,6 +19,7 @@ if typing.TYPE_CHECKING:
from aiohttp.web import Request
from aputils.signer import Signer
from collections.abc import Callable
+ from tinysql import Row
VIEWS = []
@@ -74,17 +75,16 @@ def register_route(*paths: str) -> Callable:
@register_route('/')
class HomeView(View):
- async def get(self, request: Request) -> Response:
- with self.database.connection() as conn:
- config = conn.get_config_all()
- inboxes = conn.execute('SELECT * FROM inboxes').all()
+ async def get(self, request: Request, conn: Connection) -> Response:
+ config = conn.get_config_all()
+ inboxes = conn.execute('SELECT * FROM inboxes').all()
- text = HOME_TEMPLATE.format(
- host = self.config.domain,
- note = config['note'],
- count = len(inboxes),
- targets = '
'.join(inbox['domain'] for inbox in inboxes)
- )
+ text = HOME_TEMPLATE.format(
+ host = self.config.domain,
+ note = config['note'],
+ count = len(inboxes),
+ targets = '
'.join(inbox['domain'] for inbox in inboxes)
+ )
return Response.new(text, ctype='html')
@@ -98,11 +98,11 @@ class ActorView(View):
self.signature: Signature = None
self.message: Message = None
self.actor: Message = None
- self.instance: dict[str, str] = None
+ self.instance: Row = None
self.signer: Signer = None
- async def get(self, request: Request) -> Response:
+ async def get(self, request: Request, conn: Connection) -> Response:
data = Message.new_actor(
host = self.config.domain,
pubkey = self.app.signer.pubkey
@@ -111,37 +111,36 @@ class ActorView(View):
return Response.new(data, ctype='activity')
- async def post(self, request: Request) -> Response:
+ async def post(self, request: Request, conn: Connection) -> Response:
if response := await self.get_post_data():
return response
- with self.database.connection() as conn:
- self.instance = conn.get_inbox(self.actor.shared_inbox)
- config = conn.get_config_all()
+ self.instance = conn.get_inbox(self.actor.shared_inbox)
+ config = conn.get_config_all()
- ## reject if the actor isn't whitelisted while the whiltelist is enabled
- if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
- logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
- return Response.new_error(403, 'access denied', 'json')
+ ## reject if the actor isn't whitelisted while the whiltelist is enabled
+ if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
+ logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
+ return Response.new_error(403, 'access denied', 'json')
- ## reject if actor is banned
- if conn.get_domain_ban(self.actor.domain):
- logging.verbose('Ignored request from banned actor: %s', self.actor.id)
- return Response.new_error(403, 'access denied', 'json')
+ ## reject if actor is banned
+ if conn.get_domain_ban(self.actor.domain):
+ logging.verbose('Ignored request from banned actor: %s', self.actor.id)
+ return Response.new_error(403, 'access denied', 'json')
- ## reject if activity type isn't 'Follow' and the actor isn't following
- if self.message.type != 'Follow' and not self.instance:
- logging.verbose(
- 'Rejected actor for trying to post while not following: %s',
- self.actor.id
- )
+ ## reject if activity type isn't 'Follow' and the actor isn't following
+ if self.message.type != 'Follow' and not self.instance:
+ logging.verbose(
+ 'Rejected actor for trying to post while not following: %s',
+ self.actor.id
+ )
- return Response.new_error(401, 'access denied', 'json')
+ return Response.new_error(401, 'access denied', 'json')
- logging.debug('>> payload %s', self.message.to_json(4))
+ logging.debug('>> payload %s', self.message.to_json(4))
- asyncio.ensure_future(run_processor(self))
- return Response.new(status = 202)
+ await run_processor(self, conn)
+ return Response.new(status = 202)
async def get_post_data(self) -> Response | None:
@@ -168,7 +167,11 @@ class ActorView(View):
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
- self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
+ self.actor = await self.client.get(
+ self.signature.keyid,
+ sign_headers = True,
+ loads = Message.parse
+ )
if not self.actor:
# ld signatures aren't handled atm, so just ignore it
@@ -227,7 +230,7 @@ class ActorView(View):
@register_route('/.well-known/webfinger')
class WebfingerView(View):
- async def get(self, request: Request) -> Response:
+ async def get(self, request: Request, conn: Connection) -> Response:
try:
subject = request.query['resource']
@@ -248,18 +251,18 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
- async def get(self, request: Request, niversion: str) -> Response:
- with self.database.connection() as conn:
- inboxes = conn.execute('SELECT * FROM inboxes').all()
+ # pylint: disable=no-self-use
+ async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
+ inboxes = conn.execute('SELECT * FROM inboxes').all()
- data = {
- 'name': 'activityrelay',
- 'version': VERSION,
- 'protocols': ['activitypub'],
- 'open_regs': not conn.get_config('whitelist-enabled'),
- 'users': 1,
- 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
- }
+ data = {
+ 'name': 'activityrelay',
+ 'version': VERSION,
+ 'protocols': ['activitypub'],
+ 'open_regs': not conn.get_config('whitelist-enabled'),
+ 'users': 1,
+ 'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
+ }
if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
@@ -269,6 +272,6 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View):
- async def get(self, request: Request) -> Response:
+ async def get(self, request: Request, conn: Connection) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json')
diff --git a/requirements.txt b/requirements.txt
index 3fb5e40..979a894 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,9 @@
aiohttp>=3.9.1
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
-cachetools>=5.2.0
click>=8.1.2
+hiredis==2.3.2
pyyaml>=6.0
-tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.3.tar.gz
+redis==5.0.1
+tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz
importlib_resources==6.1.1;python_version<'3.9'