Compare commits

..

11 commits

Author SHA1 Message Date
Izalia Mae b8e0641733 Merge branch 'cache' into 'master'
caching

See merge request pleroma/relay!54
2024-02-05 18:34:43 +00:00
Izalia Mae bec5d5f207 use gunicorn to start the server 2024-02-05 13:15:08 -05:00
Izalia Mae 02ac1fa53b make sure db connection for request is open 2024-02-04 05:17:51 -05:00
Izalia Mae 2fcaea85ae create a new database connection for each request 2024-02-04 04:53:39 -05:00
Izalia Mae e6f30ddf64 update tinysql to 0.2.4 2024-02-04 04:41:04 -05:00
Izalia Mae 64690a5c05 create new Database object for SqlCache 2024-02-04 04:40:51 -05:00
Izalia Mae 46413be2af make sure Item.updated is a datetime object if it isn't one already 2024-02-03 05:40:57 -05:00
Izalia Mae 3d81e5ef68 pass instance row to HttpClient.post 2024-02-01 21:40:27 -05:00
Izalia Mae 1668d96485 add setup questions for redis 2024-02-01 21:37:46 -05:00
Izalia Mae 4c4dd3566b cache fixes
* make sure Item.updated is a datetime object
* remove id column when creating Item objects in SqlCache
2024-02-01 11:43:17 -05:00
Izalia Mae 2d641ea183 add database and redis caching 2024-01-31 21:23:45 -05:00
13 changed files with 851 additions and 292 deletions

View file

@ -35,6 +35,13 @@ SQL database backend to use. Valid values are `sqlite` or `postgres`.
database_type: sqlite database_type: sqlite
### Cache type
Cache backend to use. Valid values are `database` or `redis`
cache_type: database
### Sqlite File Path ### Sqlite File Path
Path to the sqlite database file. If the path is not absolute, it is relative to the config file. Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
@ -47,7 +54,7 @@ directory.
In order to use the Postgresql backend, the user and database need to be created first. In order to use the Postgresql backend, the user and database need to be created first.
sudo -u postgres psql -c "CREATE USER activityrelay" sudo -u postgres psql -c "CREATE USER activityrelay WITH PASSWORD SomeSecurePassword"
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay" sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
@ -84,3 +91,47 @@ User to use when logging into the server.
Password for the specified user. Password for the specified user.
pass: null pass: null
## Redis
### Host
Hostname, IP address, or unix socket the server is hosted on.
host: localhost
### Port
Port number the server is listening on.
port: 6379
### Username
User to use when logging into the server.
user: null
### Password
Password for the specified user.
pass: null
### Database Number
Number of the database to use.
database: 0
### Prefix
Text to prefix every key with. It cannot contain a `:` character.
prefix: activityrelay

View file

@ -7,12 +7,15 @@ listen: 0.0.0.0
# [integer] Port the relay will listen on # [integer] Port the relay will listen on
port: 8080 port: 8080
# [integer] Number of push workers to start (will get removed in a future update) # [integer] Number of push workers to start
workers: 8 workers: 8
# [string] Database backend to use. Valid values: sqlite, postgres # [string] Database backend to use. Valid values: sqlite, postgres
database_type: sqlite database_type: sqlite
# [string] Cache backend to use. Valid values: database, redis
cache_type: database
# [string] Path to the sqlite database file if the sqlite backend is in use # [string] Path to the sqlite database file if the sqlite backend is in use
sqlite_path: relay.sqlite3 sqlite_path: relay.sqlite3
@ -33,3 +36,24 @@ postgres:
# [string] name of the database to use # [string] name of the database to use
name: activityrelay name: activityrelay
# settings for the redis caching backend
redis:
# [string] hostname or unix socket to connect to
host: localhost
# [integer] port of the server
port: 6379
# [string] username to use when logging into the server
user: null
# [string] password for the server
pass: null
# [integer] database number to use
database: 0
# [string] prefix for keys
prefix: activityrelay

View file

@ -1,17 +1,20 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import queue import os
import signal import signal
import threading import subprocess
import traceback import sys
import time
import typing import typing
from aiohttp import web from aiohttp import web
from aputils.signer import Signer from aputils.signer import Signer
from datetime import datetime, timedelta from datetime import datetime, timedelta
from gunicorn.app.wsgiapp import WSGIApplication
from . import logger as logging from . import logger as logging
from .cache import get_cache
from .config import Config from .config import Config
from .database import get_database from .database import get_database
from .http_client import HttpClient from .http_client import HttpClient
@ -19,8 +22,10 @@ from .misc import check_open_port
from .views import VIEWS from .views import VIEWS
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Database from collections.abc import Awaitable
from tinysql import Database, Row
from typing import Any from typing import Any
from .cache import Cache
from .misc import Message from .misc import Message
@ -29,25 +34,34 @@ if typing.TYPE_CHECKING:
class Application(web.Application): class Application(web.Application):
DEFAULT: Application = None DEFAULT: Application = None
def __init__(self, cfgpath: str): def __init__(self, cfgpath: str, gunicorn: bool = False):
web.Application.__init__(self) web.Application.__init__(self)
Application.DEFAULT = self Application.DEFAULT = self
self['proc'] = None
self['signer'] = None self['signer'] = None
self['start_time'] = None
self['config'] = Config(cfgpath, load = True) self['config'] = Config(cfgpath, load = True)
self['database'] = get_database(self.config) self['database'] = get_database(self.config)
self['client'] = HttpClient() self['client'] = HttpClient()
self['cache'] = get_cache(self)
self['workers'] = [] if not gunicorn:
self['last_worker'] = 0 return
self['start_time'] = None
self['running'] = False self.on_response_prepare.append(handle_access_log)
for path, view in VIEWS: for path, view in VIEWS:
self.router.add_view(path, view) self.router.add_view(path, view)
@property
def cache(self) -> Cache:
return self['cache']
@property @property
def client(self) -> HttpClient: def client(self) -> HttpClient:
return self['client'] return self['client']
@ -87,18 +101,17 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message) -> None: def push_message(self, inbox: str, message: Message, instance: Row) -> None:
if self.config.workers <= 0: asyncio.ensure_future(self.client.post(inbox, message, instance))
asyncio.ensure_future(self.client.post(inbox, message))
return
worker = self['workers'][self['last_worker']]
worker.queue.put((inbox, message))
self['last_worker'] += 1 def run(self, dev: bool = False) -> None:
self.start(dev)
if self['last_worker'] >= len(self['workers']): while self['proc'] and self['proc'].poll() is None:
self['last_worker'] = 0 time.sleep(0.1)
self.stop()
def set_signal_handler(self, startup: bool) -> None: def set_signal_handler(self, startup: bool) -> None:
@ -111,91 +124,101 @@ class Application(web.Application):
pass pass
def run(self) -> None:
if not check_open_port(self.config.listen, self.config.port): def start(self, dev: bool = False) -> None:
logging.error('A server is already running on port %i', self.config.port) if self['proc']:
return return
for view in VIEWS: if not check_open_port(self.config.listen, self.config.port):
self.router.add_view(*view) logging.error('Server already running on %s:%s', self.config.listen, self.config.port)
return
logging.info( cmd = [
'Starting webserver at %s (%s:%i)', sys.executable, '-m', 'gunicorn',
self.config.domain, 'relay.application:main_gunicorn',
self.config.listen, '--bind', f'{self.config.listen}:{self.config.port}',
self.config.port '--worker-class', 'aiohttp.GunicornWebWorker',
) '--workers', str(self.config.workers),
'--env', f'CONFIG_FILE={self.config.path}'
]
asyncio.run(self.handle_run()) if dev:
cmd.append('--reload')
def stop(self, *_: Any) -> None:
self['running'] = False
async def handle_run(self) -> None:
self['running'] = True
self.set_signal_handler(True) self.set_signal_handler(True)
self['proc'] = subprocess.Popen(cmd) # pylint: disable=consider-using-with
if self.config.workers > 0:
for _ in range(self.config.workers):
worker = PushWorker(self)
worker.start()
self['workers'].append(worker) def stop(self, *_) -> None:
if not self['proc']:
return
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') self['proc'].terminate()
await runner.setup() time_wait = 0.0
site = web.TCPSite( while self['proc'].poll() is None:
runner, time.sleep(0.1)
host = self.config.listen, time_wait += 0.1
port = self.config.port,
reuse_address = True if time_wait >= 5.0:
self['proc'].kill()
break
self.set_signal_handler(False)
self['proc'] = None
# not used, but keeping just in case
class GunicornRunner(WSGIApplication):
def __init__(self, app: Application):
self.app = app
self.app_uri = 'relay.application:main_gunicorn'
self.options = {
'bind': f'{app.config.listen}:{app.config.port}',
'worker_class': 'aiohttp.GunicornWebWorker',
'workers': app.config.workers,
'raw_env': f'CONFIG_FILE={app.config.path}'
}
WSGIApplication.__init__(self)
def load_config(self):
for key, value in self.options.items():
self.cfg.set(key, value)
def run(self):
logging.info('Starting webserver for %s', self.app.config.domain)
WSGIApplication.run(self)
async def handle_access_log(request: web.Request, response: web.Response) -> None:
address = request.headers.get(
'X-Forwarded-For',
request.headers.get(
'X-Real-Ip',
request.remote
)
) )
await site.start() logging.info(
self['start_time'] = datetime.now() '%s "%s %s" %i %i "%s"',
address,
while self['running']: request.method,
await asyncio.sleep(0.25) request.path,
response.status,
await site.stop() len(response.body),
await self.client.close() request.headers.get('User-Agent', 'n/a')
)
self['start_time'] = None
self['running'] = False
self['workers'].clear()
class PushWorker(threading.Thread): async def main_gunicorn():
def __init__(self, app: Application):
threading.Thread.__init__(self)
self.app = app
self.queue = queue.Queue()
self.client = None
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
self.client = HttpClient()
while self.app['running']:
try: try:
inbox, message = self.queue.get(block=True, timeout=0.25) app = Application(os.environ['CONFIG_FILE'], gunicorn = True)
self.queue.task_done()
logging.verbose('New push from Thread-%i', threading.get_ident())
await self.client.post(inbox, message)
except queue.Empty: except KeyError:
pass logging.error('Failed to set "CONFIG_FILE" environment. Trying to run without gunicorn?')
raise
## make sure an exception doesn't bring down the worker return app
except Exception:
traceback.print_exc()
await self.client.close()

288
relay/cache.py Normal file
View file

@ -0,0 +1,288 @@
from __future__ import annotations
import json
import os
import typing
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from redis import Redis
from .database import get_database
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from typing import Any
from collections.abc import Callable, Iterator
from tinysql import Database
from .application import Application
# todo: implement more caching backends

# registry of available cache backends, keyed by each Cache subclass's `name`
BACKENDS: dict[str, Cache] = {}

# (serializer, deserializer) pairs used to convert values to and from the
# string form the cache backends store
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
	'str': (str, str),
	'int': (str, int),
	'bool': (str, boolean),
	'json': (json.dumps, json.loads),
	'message': (lambda x: x.to_json(), Message.parse)
}
def get_cache(app: Application) -> Cache:
	'Instantiate the cache backend selected by the application config'
	backend = BACKENDS[app.config.ca_type]
	return backend(app)
def register_cache(backend: type[Cache]) -> type[Cache]:
	'Class decorator that adds a Cache subclass to the backend registry'
	BACKENDS[backend.name] = backend
	return backend
def serialize_value(value: Any, value_type: str = 'str') -> str:
	'Convert a value to its cacheable string form; strings pass through unchanged'
	if isinstance(value, str):
		return value

	serializer, _ = CONVERTERS[value_type]
	return serializer(value)
def deserialize_value(value: str, value_type: str = 'str') -> Any:
	'Convert a cached string back into its original type'
	_, deserializer = CONVERTERS[value_type]
	return deserializer(value)
@dataclass
class Item:
	'A single cache entry and its metadata'

	namespace: str   # grouping for related keys
	key: str
	value: Any
	value_type: str  # name of the converter used to (de)serialize `value`
	updated: datetime  # timestamp of the last write


	def __post_init__(self):
		# allow construction from an ISO 8601 string
		if isinstance(self.updated, str):
			self.updated = datetime.fromisoformat(self.updated)


	@classmethod
	def from_data(cls: type[Item], *args) -> Item:
		'Build an Item from raw row values, deserializing the stored value'
		item = cls(*args)
		item.value = deserialize_value(item.value, item.value_type)

		if not isinstance(item.updated, datetime):
			# numeric timestamps are treated as UTC
			item.updated = datetime.fromtimestamp(item.updated, tz = timezone.utc)

		return item


	def older_than(self, hours: int) -> bool:
		'Return True if the entry was last updated more than `hours` hours ago'
		age = datetime.now(tz = timezone.utc) - self.updated
		return age.total_seconds() > hours * 3600


	def to_dict(self) -> dict[str, Any]:
		'Return the entry as a plain dict'
		return asdict(self)
class Cache(ABC):
	'Base class for all cache backends'

	# backend identifier used for registration and config lookup
	name: str = 'null'


	def __init__(self, app: Application):
		self.app = app
		self.setup()


	@abstractmethod
	def get(self, namespace: str, key: str) -> Item:
		'Fetch an item, raising KeyError if it does not exist'


	@abstractmethod
	def get_keys(self, namespace: str) -> Iterator[str]:
		'Iterate over every key in a namespace'


	@abstractmethod
	def get_namespaces(self) -> Iterator[str]:
		'Iterate over every namespace in the cache'


	@abstractmethod
	def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
		'Store a value and return the resulting item'
		# default was 'key', which is not a registered converter name


	@abstractmethod
	def delete(self, namespace: str, key: str) -> None:
		'Remove an item from the cache'


	@abstractmethod
	def setup(self) -> None:
		'Initialize the backend (create tables, open connections, etc)'


	def set_item(self, item: Item) -> Item:
		'Store an existing Item object'
		# fix: was `item.type`, which raises AttributeError since the Item
		# dataclass field is named `value_type`
		return self.set(
			item.namespace,
			item.key,
			item.value,
			item.value_type
		)


	def delete_item(self, item: Item) -> None:
		'Remove the cache entry matching the given item'
		self.delete(item.namespace, item.key)
@register_cache
class SqlCache(Cache):
	'Cache backend that stores items in the relay SQL database'

	name: str = 'database'


	def __init__(self, app: Application):
		# use a dedicated Database object so cache queries do not interfere
		# with request-handling connections
		self._db = get_database(app.config)
		Cache.__init__(self, app)


	def get(self, namespace: str, key: str) -> Item:
		'Fetch an item, raising KeyError if the row does not exist'
		params = {
			'namespace': namespace,
			'key': key
		}

		with self._db.connection() as conn:
			with conn.exec_statement('get-cache-item', params) as cur:
				if not (row := cur.one()):
					raise KeyError(f'{namespace}:{key}')

				# drop the surrogate id column so the remaining values line
				# up with the Item dataclass fields
				row.pop('id', None)
				return Item.from_data(*tuple(row.values()))


	def get_keys(self, namespace: str) -> Iterator[str]:
		'Iterate over every key in a namespace'
		with self._db.connection() as conn:
			for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}):
				yield row['key']


	def get_namespaces(self) -> Iterator[str]:
		'Iterate over every namespace in the cache table'
		with self._db.connection() as conn:
			for row in conn.exec_statement('get-cache-namespaces', None):
				yield row['namespace']


	def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
		'Insert or update a cache row and return the stored item'
		params = {
			'namespace': namespace,
			'key': key,
			'value': serialize_value(value, value_type),
			'type': value_type,
			'date': datetime.now(tz = timezone.utc)
		}

		with self._db.connection() as conn:
			# fix: cursor was bound `as conn`, shadowing the connection
			with conn.exec_statement('set-cache-item', params) as cur:
				row = cur.one()
				row.pop('id', None)
				return Item.from_data(*tuple(row.values()))


	def delete(self, namespace: str, key: str) -> None:
		'Remove a single cache row'
		params = {
			'namespace': namespace,
			'key': key
		}

		with self._db.connection() as conn:
			with conn.exec_statement('del-cache-item', params):
				pass


	def setup(self) -> None:
		'Create the cache table if it does not exist'
		# the create statement is per-backend (sqlite/postgres)
		with self._db.connection() as conn:
			with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
				pass
@register_cache
class RedisCache(Cache):
	'Cache backend that stores items in a Redis database'

	name: str = 'redis'
	# client is created in setup()
	_rd: Redis


	@property
	def prefix(self) -> str:
		'Text prepended to every redis key'
		return self.app.config.rd_prefix


	def get_key_name(self, namespace: str, key: str) -> str:
		'Build the full redis key for a namespace/key pair'
		return f'{self.prefix}:{namespace}:{key}'


	def get(self, namespace: str, key: str) -> Item:
		'Fetch an item, raising KeyError if the redis key does not exist'
		key_name = self.get_key_name(namespace, key)

		if not (raw_value := self._rd.get(key_name)):
			raise KeyError(f'{namespace}:{key}')

		# values are stored as "<type>:<timestamp>:<value>"
		value_type, updated, value = raw_value.split(':', 2)
		return Item.from_data(
			namespace,
			key,
			value,
			value_type,
			datetime.fromtimestamp(float(updated), tz = timezone.utc)
		)


	def get_keys(self, namespace: str) -> Iterator[str]:
		'Iterate over every key in a namespace'
		for key in self._rd.keys(self.get_key_name(namespace, '*')):
			*_, key_name = key.split(':', 2)
			yield key_name


	def get_namespaces(self) -> Iterator[str]:
		'Iterate over every distinct namespace under the prefix'
		namespaces = []

		for key in self._rd.keys(f'{self.prefix}:*'):
			_, namespace, _ = key.split(':', 2)

			if namespace not in namespaces:
				namespaces.append(namespace)
				yield namespace


	def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
		'Store a value and return the stored item'
		# fix: default was 'key', which is not a registered converter and
		# would raise KeyError in serialize_value for any non-string value
		date = datetime.now(tz = timezone.utc)
		value = serialize_value(value, value_type)

		self._rd.set(
			self.get_key_name(namespace, key),
			f'{value_type}:{date.timestamp()}:{value}'
		)

		# return the entry to satisfy the Cache.set / set_item contract
		return Item(namespace, key, value, value_type, date)


	def delete(self, namespace: str, key: str) -> None:
		'Remove a single redis key'
		self._rd.delete(self.get_key_name(namespace, key))


	def setup(self) -> None:
		'Open the connection to the redis server'
		options = {
			'client_name': f'ActivityRelay_{self.app.config.domain}',
			'decode_responses': True,
			'username': self.app.config.rd_user,
			'password': self.app.config.rd_pass,
			'db': self.app.config.rd_database
		}

		# a host that exists on the filesystem is treated as a unix socket
		if os.path.exists(self.app.config.rd_host):
			options['unix_socket_path'] = self.app.config.rd_host

		else:
			options['host'] = self.app.config.rd_host
			options['port'] = self.app.config.rd_port

		self._rd = Redis(**options)

View file

@ -19,12 +19,21 @@ DEFAULTS: dict[str, Any] = {
'domain': 'relay.example.com', 'domain': 'relay.example.com',
'workers': len(os.sched_getaffinity(0)), 'workers': len(os.sched_getaffinity(0)),
'db_type': 'sqlite', 'db_type': 'sqlite',
'ca_type': 'database',
'sq_path': 'relay.sqlite3', 'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql', 'pg_host': '/var/run/postgresql',
'pg_port': 5432, 'pg_port': 5432,
'pg_user': getpass.getuser(), 'pg_user': getpass.getuser(),
'pg_pass': None, 'pg_pass': None,
'pg_name': 'activityrelay' 'pg_name': 'activityrelay',
'rd_host': 'localhost',
'rd_port': 6379,
'rd_user': None,
'rd_pass': None,
'rd_database': 0,
'rd_prefix': 'activityrelay'
} }
if IS_DOCKER: if IS_DOCKER:
@ -40,13 +49,22 @@ class Config:
self.domain = None self.domain = None
self.workers = None self.workers = None
self.db_type = None self.db_type = None
self.ca_type = None
self.sq_path = None self.sq_path = None
self.pg_host = None self.pg_host = None
self.pg_port = None self.pg_port = None
self.pg_user = None self.pg_user = None
self.pg_pass = None self.pg_pass = None
self.pg_name = None self.pg_name = None
self.rd_host = None
self.rd_port = None
self.rd_user = None
self.rd_pass = None
self.rd_database = None
self.rd_prefix = None
if load: if load:
try: try:
self.load() self.load()
@ -92,6 +110,7 @@ class Config:
with self.path.open('r', encoding = 'UTF-8') as fd: with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options) config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {}) pgcfg = config.get('postgresql', {})
rdcfg = config.get('redis', {})
if not config: if not config:
raise ValueError('Config is empty') raise ValueError('Config is empty')
@ -106,19 +125,26 @@ class Config:
self.set('port', config.get('port', DEFAULTS['port'])) self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path'])) self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self.set('workers', config.get('workers', DEFAULTS['workers']))
self.set('domain', config.get('domain', DEFAULTS['domain'])) self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type'])) self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
for key in DEFAULTS: for key in DEFAULTS:
if not key.startswith('pg'): if key.startswith('pg'):
continue
try: try:
self.set(key, pgcfg[key[3:]]) self.set(key, pgcfg[key[3:]])
except KeyError: except KeyError:
continue continue
elif key.startswith('rd'):
try:
self.set(key, rdcfg[key[3:]])
except KeyError:
continue
def reset(self) -> None: def reset(self) -> None:
for key, value in DEFAULTS.items(): for key, value in DEFAULTS.items():
@ -132,7 +158,9 @@ class Config:
'listen': self.listen, 'listen': self.listen,
'port': self.port, 'port': self.port,
'domain': self.domain, 'domain': self.domain,
'workers': self.workers,
'database_type': self.db_type, 'database_type': self.db_type,
'cache_type': self.ca_type,
'sqlite_path': self.sq_path, 'sqlite_path': self.sq_path,
'postgres': { 'postgres': {
'host': self.pg_host, 'host': self.pg_host,
@ -140,6 +168,14 @@ class Config:
'user': self.pg_user, 'user': self.pg_user,
'pass': self.pg_pass, 'pass': self.pg_pass,
'name': self.pg_name 'name': self.pg_name
},
'redis': {
'host': self.rd_host,
'port': self.rd_port,
'user': self.rd_user,
'pass': self.rd_pass,
'database': self.rd_database,
'prefix': self.rd_prefix
} }
} }

View file

@ -77,3 +77,64 @@ RETURNING *
-- name: del-domain-whitelist -- name: del-domain-whitelist
DELETE FROM whitelist DELETE FROM whitelist
WHERE domain = :domain WHERE domain = :domain
-- cache functions --
-- name: create-cache-table-sqlite
CREATE TABLE IF NOT EXISTS cache (
id INTEGER PRIMARY KEY UNIQUE,
namespace TEXT NOT NULL,
key TEXT NOT NULL,
"value" TEXT,
type TEXT DEFAULT 'str',
updated TIMESTAMP NOT NULL,
UNIQUE(namespace, key)
)
-- name: create-cache-table-postgres
CREATE TABLE IF NOT EXISTS cache (
id SERIAL PRIMARY KEY,
namespace TEXT NOT NULL,
key TEXT NOT NULL,
"value" TEXT,
type TEXT DEFAULT 'str',
updated TIMESTAMP NOT NULL,
UNIQUE(namespace, key)
)
-- name: get-cache-item
SELECT * FROM cache
WHERE namespace = :namespace and key = :key
-- name: get-cache-keys
SELECT key FROM cache
WHERE namespace = :namespace
-- name: get-cache-namespaces
SELECT DISTINCT namespace FROM cache
-- name: set-cache-item
INSERT INTO cache (namespace, key, value, type, updated)
VALUES (:namespace, :key, :value, :type, :date)
ON CONFLICT (namespace, key) DO
UPDATE SET value = :value, type = :type, updated = :date
RETURNING *
-- name: del-cache-item
DELETE FROM cache
WHERE namespace = :namespace and key = :key
-- name: del-cache-namespace
DELETE FROM cache
WHERE namespace = :namespace
-- name: del-cache-all
DELETE FROM cache

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json
import traceback import traceback
import typing import typing
@ -7,7 +8,6 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils.objects import Nodeinfo, WellKnownNodeinfo from aputils.objects import Nodeinfo, WellKnownNodeinfo
from cachetools import LRUCache
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from urllib.parse import urlparse from urllib.parse import urlparse
@ -16,7 +16,11 @@ from . import logger as logging
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aputils import Signer
from tinysql import Row
from typing import Any from typing import Any
from .application import Application
from .cache import Cache
HEADERS = { HEADERS = {
@ -26,8 +30,7 @@ HEADERS = {
class HttpClient: class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024): def __init__(self, limit: int = 100, timeout: int = 10):
self.cache = LRUCache(cache_size)
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
self._conn = None self._conn = None
@ -43,6 +46,21 @@ class HttpClient:
await self.close() await self.close()
@property
def app(self) -> Application:
return get_app()
@property
def cache(self) -> Cache:
return self.app.cache
@property
def signer(self) -> Signer:
return self.app.signer
async def open(self) -> None: async def open(self) -> None:
if self._session: if self._session:
return return
@ -74,8 +92,8 @@ class HttpClient:
async def get(self, # pylint: disable=too-many-branches async def get(self, # pylint: disable=too-many-branches
url: str, url: str,
sign_headers: bool = False, sign_headers: bool = False,
loads: callable | None = None, loads: callable = json.loads,
force: bool = False) -> Message | dict | None: force: bool = False) -> dict | None:
await self.open() await self.open()
@ -85,13 +103,20 @@ class HttpClient:
except ValueError: except ValueError:
pass pass
if not force and url in self.cache: if not force:
return self.cache[url] try:
item = self.cache.get('request', url)
if not item.older_than(48):
return loads(item.value)
except KeyError:
logging.verbose('Failed to fetch cached data for url: %s', url)
headers = {} headers = {}
if sign_headers: if sign_headers:
get_app().signer.sign_headers('GET', url, algorithm = 'original') self.signer.sign_headers('GET', url, algorithm = 'original')
try: try:
logging.debug('Fetching resource: %s', url) logging.debug('Fetching resource: %s', url)
@ -101,32 +126,22 @@ class HttpClient:
if resp.status == 202: if resp.status == 202:
return None return None
data = await resp.read()
if resp.status != 200: if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status) logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(await resp.read()) logging.debug(await resp.read())
return None return None
if loads: message = loads(data)
message = await resp.json(loads=loads) self.cache.set('request', url, data.decode('utf-8'), 'str')
logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
elif resp.content_type == MIMETYPES['activity']:
message = await resp.json(loads = Message.parse)
elif resp.content_type == MIMETYPES['json']:
message = await resp.json()
else:
logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type)
logging.debug('Response: %s', await resp.read())
return None
logging.debug('%s >> resp %s', url, message.to_json(4))
self.cache[url] = message
return message return message
except JSONDecodeError: except JSONDecodeError:
logging.verbose('Failed to parse JSON') logging.verbose('Failed to parse JSON')
return None
except ClientSSLError: except ClientSSLError:
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc) logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
@ -140,13 +155,9 @@ class HttpClient:
return None return None
async def post(self, url: str, message: Message) -> None: async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
await self.open() await self.open()
# todo: cache inboxes to avoid opening a db connection
with get_app().database.connection() as conn:
instance = conn.get_inbox(url)
## Using the old algo by default is probably a better idea right now ## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression # pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}: if instance and instance['software'] in {'mastodon'}:

View file

@ -70,7 +70,6 @@ error: Callable = logging.error
critical: Callable = logging.critical critical: Callable = logging.critical
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper() env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
try: try:
@ -79,22 +78,15 @@ try:
except KeyError: except KeyError:
env_log_file = None env_log_file = None
try:
log_level = LogLevel[env_log_level]
except KeyError:
print('Invalid log level:', env_log_level)
log_level = LogLevel['INFO']
handlers = [logging.StreamHandler()] handlers = [logging.StreamHandler()]
if env_log_file: if env_log_file:
handlers.append(logging.FileHandler(env_log_file)) handlers.append(logging.FileHandler(env_log_file))
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
logging.basicConfig( logging.basicConfig(
level = log_level, level = LogLevel.INFO,
format = '[%(asctime)s] %(levelname)s: %(message)s', format = '[%(asctime)s] %(levelname)s: %(message)s',
datefmt = '%Y-%m-%d %H:%M:%S',
handlers = handlers handlers = handlers
) )

View file

@ -18,7 +18,7 @@ from .application import Application
from .compat import RelayConfig, RelayDatabase from .compat import RelayConfig, RelayDatabase
from .database import get_database from .database import get_database
from .database.connection import RELAY_SOFTWARE from .database.connection import RELAY_SOFTWARE
from .misc import IS_DOCKER, Message, check_open_port from .misc import IS_DOCKER, Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Row from tinysql import Row
@ -51,6 +51,13 @@ SOFTWARE = (
) )
def check_alphanumeric(text: str) -> str:
	'Return `text` unchanged, raising click.BadParameter if it is not alphanumeric'
	if not text.isalnum():
		raise click.BadParameter('String not alphanumeric')

	return text
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True) @click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config') @click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config')
@click.version_option(version=__version__, prog_name='ActivityRelay') @click.version_option(version=__version__, prog_name='ActivityRelay')
@ -63,6 +70,11 @@ def cli(ctx: click.Context, config: str) -> None:
cli_setup.callback() cli_setup.callback()
else: else:
click.echo(
'[DEPRECATED] Running the relay without the "run" command will be removed in the ' +
'future.'
)
cli_run.callback() cli_run.callback()
@ -113,8 +125,9 @@ def cli_setup(ctx: click.Context) -> None:
) )
ctx.obj.config.pg_host = click.prompt( ctx.obj.config.pg_host = click.prompt(
'What IP address or hostname does the server listen on?', 'What IP address, hostname, or unix socket does the server listen on?',
default = ctx.obj.config.pg_host default = ctx.obj.config.pg_host,
type = int
) )
ctx.obj.config.pg_port = click.prompt( ctx.obj.config.pg_port = click.prompt(
@ -135,6 +148,48 @@ def cli_setup(ctx: click.Context) -> None:
default = ctx.obj.config.pg_pass or "" default = ctx.obj.config.pg_pass or ""
) or None ) or None
ctx.obj.config.ca_type = click.prompt(
'Which caching backend?',
default = ctx.obj.config.ca_type,
type = click.Choice(['database', 'redis'], case_sensitive = False)
)
if ctx.obj.config.ca_type == 'redis':
ctx.obj.config.rd_host = click.prompt(
'What IP address, hostname, or unix socket does the server listen on?',
default = ctx.obj.config.rd_host
)
ctx.obj.config.rd_port = click.prompt(
'What port does the server listen on?',
default = ctx.obj.config.rd_port,
type = int
)
ctx.obj.config.rd_user = click.prompt(
'Which user will authenticate with the server',
default = ctx.obj.config.rd_user
)
ctx.obj.config.rd_pass = click.prompt(
'User password',
hide_input = True,
show_default = False,
default = ctx.obj.config.rd_pass or ""
) or None
ctx.obj.config.rd_database = click.prompt(
'Which database number to use?',
default = ctx.obj.config.rd_database,
type = int
)
ctx.obj.config.rd_prefix = click.prompt(
'What text should each cache key be prefixed with?',
default = ctx.obj.config.rd_prefix,
type = check_alphanumeric
)
ctx.obj.config.save() ctx.obj.config.save()
config = { config = {
@ -150,8 +205,9 @@ def cli_setup(ctx: click.Context) -> None:
@cli.command('run') @cli.command('run')
@click.option('--dev', '-d', is_flag = True, help = 'Enable worker reloading on code change')
@click.pass_context @click.pass_context
def cli_run(ctx: click.Context) -> None: def cli_run(ctx: click.Context, dev: bool = False) -> None:
'Run the relay' 'Run the relay'
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer: if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
@ -178,11 +234,7 @@ def cli_run(ctx: click.Context) -> None:
click.echo(pip_command) click.echo(pip_command)
return return
if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port): ctx.obj.run(dev)
click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
return
ctx.obj.run()
@cli.command('convert') @cli.command('convert')
@ -364,7 +416,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
actor = actor actor = actor
) )
asyncio.run(http.post(inbox, message)) asyncio.run(http.post(inbox, message, None, inbox_data))
click.echo(f'Sent follow message to actor: {actor}') click.echo(f'Sent follow message to actor: {actor}')
@ -405,7 +457,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
} }
) )
asyncio.run(http.post(inbox, message)) asyncio.run(http.post(inbox, message, inbox_data))
click.echo(f'Sent unfollow message to: {actor}') click.echo(f'Sent unfollow message to: {actor}')

View file

@ -14,9 +14,11 @@ from functools import cached_property
from uuid import uuid4 from uuid import uuid4
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Coroutine, Generator from collections.abc import Awaitable, Coroutine, Generator
from tinysql import Connection
from typing import Any from typing import Any
from .application import Application from .application import Application
from .cache import Cache
from .config import Config from .config import Config
from .database import Database from .database import Database
from .http_client import HttpClient from .http_client import HttpClient
@ -234,13 +236,21 @@ class Response(AiohttpResponse):
class View(AbstractView): class View(AbstractView):
def __await__(self) -> Generator[Response]: def __await__(self) -> Generator[Response]:
if (self.request.method) not in METHODS: if self.request.method not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
if not (handler := self.handlers.get(self.request.method)): if not (handler := self.handlers.get(self.request.method)):
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
return handler(self.request, **self.request.match_info).__await__() return self._run_handler(handler).__await__()
async def _run_handler(self, handler: Awaitable) -> Response:
with self.database.config.connection_class(self.database) as conn:
# todo: remove on next tinysql release
conn.open()
return await handler(self.request, conn, **self.request.match_info)
@cached_property @cached_property
@ -268,6 +278,11 @@ class View(AbstractView):
return self.request.app return self.request.app
@property
def cache(self) -> Cache:
return self.app.cache
@property @property
def client(self) -> HttpClient: def client(self) -> HttpClient:
return self.app.client return self.app.client

View file

@ -1,20 +1,15 @@
from __future__ import annotations from __future__ import annotations
import tinysql
import typing import typing
from cachetools import LRUCache
from . import logger as logging from . import logger as logging
from .database.connection import Connection
from .misc import Message from .misc import Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .views import ActorView from .views import ActorView
cache = LRUCache(1024)
def person_check(actor: str, software: str) -> bool: def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason # pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0 # akkoma changed this in 3.6.0
@ -28,39 +23,46 @@ def person_check(actor: str, software: str) -> bool:
return False return False
async def handle_relay(view: ActorView) -> None: async def handle_relay(view: ActorView, conn: Connection) -> None:
if view.message.object_id in cache: try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already relayed %s', view.message.object_id) logging.verbose('already relayed %s', view.message.object_id)
return return
except KeyError:
pass
message = Message.new_announce(view.config.domain, view.message.object_id) message = Message.new_announce(view.config.domain, view.message.object_id)
cache[view.message.object_id] = message.id
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message): for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message) view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_forward(view: ActorView) -> None: async def handle_forward(view: ActorView, conn: Connection) -> None:
if view.message.id in cache: try:
logging.verbose('already forwarded %s', view.message.id) view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already forwarded %s', view.message.object_id)
return return
except KeyError:
pass
message = Message.new_announce(view.config.domain, view.message) message = Message.new_announce(view.config.domain, view.message)
cache[view.message.id] = message.id
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
with view.database.connection() as conn:
for inbox in conn.distill_inboxes(view.message): for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message) view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
async def handle_follow(view: ActorView) -> None: async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain) nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None software = nodeinfo.sw_name if nodeinfo else None
with view.database.connection() as conn:
# reject if software used by actor is banned # reject if software used by actor is banned
if conn.get_software_ban(software): if conn.get_software_ban(software):
view.app.push_message( view.app.push_message(
@ -97,14 +99,11 @@ async def handle_follow(view: ActorView) -> None:
return return
if conn.get_inbox(view.actor.shared_inbox): if conn.get_inbox(view.actor.shared_inbox):
data = {'followid': view.message.id} view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
with conn.query(statement):
pass
else: else:
conn.put_inbox( with conn.transaction():
view.instance = conn.put_inbox(
view.actor.domain, view.actor.domain,
view.actor.shared_inbox, view.actor.shared_inbox,
view.actor.id, view.actor.id,
@ -119,7 +118,8 @@ async def handle_follow(view: ActorView) -> None:
actor = view.actor.id, actor = view.actor.id,
followid = view.message.id, followid = view.message.id,
accept = True accept = True
) ),
view.instance
) )
# Are Akkoma and Pleroma the only two that expect a follow back? # Are Akkoma and Pleroma the only two that expect a follow back?
@ -130,17 +130,18 @@ async def handle_follow(view: ActorView) -> None:
Message.new_follow( Message.new_follow(
host = view.config.domain, host = view.config.domain,
actor = view.actor.id actor = view.actor.id
) ),
view.instance
) )
async def handle_undo(view: ActorView) -> None: async def handle_undo(view: ActorView, conn: Connection) -> None:
## If the object is not a Follow, forward it ## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow': if view.message.object['type'] != 'Follow':
await handle_forward(view) await handle_forward(view, conn)
return return
with view.database.connection() as conn: with conn.transaction():
if not conn.del_inbox(view.actor.id): if not conn.del_inbox(view.actor.id):
logging.verbose( logging.verbose(
'Failed to delete "%s" with follow ID "%s"', 'Failed to delete "%s" with follow ID "%s"',
@ -154,7 +155,8 @@ async def handle_undo(view: ActorView) -> None:
host = view.config.domain, host = view.config.domain,
actor = view.actor.id, actor = view.actor.id,
follow = view.message follow = view.message
) ),
view.instance
) )
@ -168,7 +170,7 @@ processors = {
} }
async def run_processor(view: ActorView) -> None: async def run_processor(view: ActorView, conn: Connection) -> None:
if view.message.type not in processors: if view.message.type not in processors:
logging.verbose( logging.verbose(
'Message type "%s" from actor cannot be handled: %s', 'Message type "%s" from actor cannot be handled: %s',
@ -179,20 +181,19 @@ async def run_processor(view: ActorView) -> None:
return return
if view.instance: if view.instance:
with conn.transaction():
if not view.instance['software']: if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with view.database.connection() as conn:
view.instance = conn.update_inbox( view.instance = conn.update_inbox(
view.instance['inbox'], view.instance['inbox'],
software = nodeinfo.sw_name software = nodeinfo.sw_name
) )
if not view.instance['actor']: if not view.instance['actor']:
with view.database.connection() as conn:
view.instance = conn.update_inbox( view.instance = conn.update_inbox(
view.instance['inbox'], view.instance['inbox'],
actor = view.actor.id actor = view.actor.id
) )
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id) logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
await processors[view.message.type](view) await processors[view.message.type](view, conn)

View file

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import subprocess import subprocess
import traceback import traceback
import typing import typing
@ -12,6 +11,7 @@ from pathlib import Path
from . import __version__ from . import __version__
from . import logger as logging from . import logger as logging
from .database.connection import Connection
from .misc import Message, Response, View from .misc import Message, Response, View
from .processors import run_processor from .processors import run_processor
@ -19,6 +19,7 @@ if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from aputils.signer import Signer from aputils.signer import Signer
from collections.abc import Callable from collections.abc import Callable
from tinysql import Row
VIEWS = [] VIEWS = []
@ -74,8 +75,7 @@ def register_route(*paths: str) -> Callable:
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
with self.database.connection() as conn:
config = conn.get_config_all() config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all() inboxes = conn.execute('SELECT * FROM inboxes').all()
@ -98,11 +98,11 @@ class ActorView(View):
self.signature: Signature = None self.signature: Signature = None
self.message: Message = None self.message: Message = None
self.actor: Message = None self.actor: Message = None
self.instance: dict[str, str] = None self.instance: Row = None
self.signer: Signer = None self.signer: Signer = None
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
data = Message.new_actor( data = Message.new_actor(
host = self.config.domain, host = self.config.domain,
pubkey = self.app.signer.pubkey pubkey = self.app.signer.pubkey
@ -111,11 +111,10 @@ class ActorView(View):
return Response.new(data, ctype='activity') return Response.new(data, ctype='activity')
async def post(self, request: Request) -> Response: async def post(self, request: Request, conn: Connection) -> Response:
if response := await self.get_post_data(): if response := await self.get_post_data():
return response return response
with self.database.connection() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all() config = conn.get_config_all()
@ -140,7 +139,7 @@ class ActorView(View):
logging.debug('>> payload %s', self.message.to_json(4)) logging.debug('>> payload %s', self.message.to_json(4))
asyncio.ensure_future(run_processor(self)) await run_processor(self, conn)
return Response.new(status = 202) return Response.new(status = 202)
@ -168,7 +167,11 @@ class ActorView(View):
logging.verbose('actor not in message') logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json') return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(self.signature.keyid, sign_headers = True) self.actor = await self.client.get(
self.signature.keyid,
sign_headers = True,
loads = Message.parse
)
if not self.actor: if not self.actor:
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
@ -227,7 +230,7 @@ class ActorView(View):
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
class WebfingerView(View): class WebfingerView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
try: try:
subject = request.query['resource'] subject = request.query['resource']
@ -248,8 +251,8 @@ class WebfingerView(View):
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View): class NodeinfoView(View):
async def get(self, request: Request, niversion: str) -> Response: # pylint: disable=no-self-use
with self.database.connection() as conn: async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
inboxes = conn.execute('SELECT * FROM inboxes').all() inboxes = conn.execute('SELECT * FROM inboxes').all()
data = { data = {
@ -269,6 +272,6 @@ class NodeinfoView(View):
@register_route('/.well-known/nodeinfo') @register_route('/.well-known/nodeinfo')
class WellknownNodeinfoView(View): class WellknownNodeinfoView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request, conn: Connection) -> Response:
data = WellKnownNodeinfo.new_template(self.config.domain) data = WellKnownNodeinfo.new_template(self.config.domain)
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')

View file

@ -1,8 +1,10 @@
aiohttp>=3.9.1 aiohttp>=3.9.1
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
cachetools>=5.2.0
click>=8.1.2 click>=8.1.2
gunicorn==21.1.0
hiredis==2.3.2
pyyaml>=6.0 pyyaml>=6.0
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.3.tar.gz redis==5.0.1
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz
importlib_resources==6.1.1;python_version<'3.9' importlib_resources==6.1.1;python_version<'3.9'