mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-12-23 11:31:07 +00:00
add database and redis caching
This commit is contained in:
parent
f2baf7f9f9
commit
2d641ea183
|
@ -35,6 +35,13 @@ SQL database backend to use. Valid values are `sqlite` or `postgres`.
|
|||
database_type: sqlite
|
||||
|
||||
|
||||
### Cache type
|
||||
|
||||
Cache backend to use. Valid values are `database` or `redis`
|
||||
|
||||
cache_type: database
|
||||
|
||||
|
||||
### Sqlite File Path
|
||||
|
||||
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
|
||||
|
@ -47,7 +54,7 @@ directory.
|
|||
|
||||
In order to use the Postgresql backend, the user and database need to be created first.
|
||||
|
||||
sudo -u postgres psql -c "CREATE USER activityrelay"
|
||||
sudo -u postgres psql -c "CREATE USER activityrelay WITH PASSWORD SomeSecurePassword"
|
||||
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
|
||||
|
||||
|
||||
|
@ -84,3 +91,47 @@ User to use when logging into the server.
|
|||
Password for the specified user.
|
||||
|
||||
pass: null
|
||||
|
||||
|
||||
## Redis
|
||||
|
||||
### Host
|
||||
|
||||
Hostname, IP address, or unix socket the server is hosted on.
|
||||
|
||||
host: /var/run/postgresql
|
||||
|
||||
|
||||
### Port
|
||||
|
||||
Port number the server is listening on.
|
||||
|
||||
port: 5432
|
||||
|
||||
|
||||
### Username
|
||||
|
||||
User to use when logging into the server.
|
||||
|
||||
user: null
|
||||
|
||||
|
||||
### Password
|
||||
|
||||
Password for the specified user.
|
||||
|
||||
pass: null
|
||||
|
||||
|
||||
### Database Number
|
||||
|
||||
Number of the database to use.
|
||||
|
||||
database: 0
|
||||
|
||||
|
||||
### Prefix
|
||||
|
||||
Text to prefix every key with. It cannot contain a `:` character.
|
||||
|
||||
prefix: activityrelay
|
||||
|
|
|
@ -7,12 +7,15 @@ listen: 0.0.0.0
|
|||
# [integer] Port the relay will listen on
|
||||
port: 8080
|
||||
|
||||
# [integer] Number of push workers to start (will get removed in a future update)
|
||||
# [integer] Number of push workers to start
|
||||
workers: 8
|
||||
|
||||
# [string] Database backend to use. Valid values: sqlite, postgres
|
||||
database_type: sqlite
|
||||
|
||||
# [string] Cache backend to use. Valid values: database, redis
|
||||
cache_type: database
|
||||
|
||||
# [string] Path to the sqlite database file if the sqlite backend is in use
|
||||
sqlite_path: relay.sqlite3
|
||||
|
||||
|
@ -33,3 +36,24 @@ postgres:
|
|||
|
||||
# [string] name of the database to use
|
||||
name: activityrelay
|
||||
|
||||
# settings for the redis caching backend
|
||||
redis:
|
||||
|
||||
# [string] hostname or unix socket to connect to
|
||||
host: localhost
|
||||
|
||||
# [integer] port of the server
|
||||
port: 6379
|
||||
|
||||
# [string] username to use when logging into the server
|
||||
user: null
|
||||
|
||||
# [string] password for the server
|
||||
pass: null
|
||||
|
||||
# [integer] database number to use
|
||||
database: 0
|
||||
|
||||
# [string] prefix for keys
|
||||
prefix: activityrelay
|
||||
|
|
|
@ -12,6 +12,7 @@ from aputils.signer import Signer
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from . import logger as logging
|
||||
from .cache import get_cache
|
||||
from .config import Config
|
||||
from .database import get_database
|
||||
from .http_client import HttpClient
|
||||
|
@ -19,8 +20,9 @@ from .misc import check_open_port
|
|||
from .views import VIEWS
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from tinysql import Database
|
||||
from tinysql import Database, Row
|
||||
from typing import Any
|
||||
from .cache import Cache
|
||||
from .misc import Message
|
||||
|
||||
|
||||
|
@ -38,6 +40,7 @@ class Application(web.Application):
|
|||
self['config'] = Config(cfgpath, load = True)
|
||||
self['database'] = get_database(self.config)
|
||||
self['client'] = HttpClient()
|
||||
self['cache'] = get_cache(self)
|
||||
|
||||
self['workers'] = []
|
||||
self['last_worker'] = 0
|
||||
|
@ -48,6 +51,11 @@ class Application(web.Application):
|
|||
self.router.add_view(path, view)
|
||||
|
||||
|
||||
@property
|
||||
def cache(self) -> Cache:
|
||||
return self['cache']
|
||||
|
||||
|
||||
@property
|
||||
def client(self) -> HttpClient:
|
||||
return self['client']
|
||||
|
@ -87,13 +95,13 @@ class Application(web.Application):
|
|||
return timedelta(seconds=uptime.seconds)
|
||||
|
||||
|
||||
def push_message(self, inbox: str, message: Message) -> None:
|
||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
||||
if self.config.workers <= 0:
|
||||
asyncio.ensure_future(self.client.post(inbox, message))
|
||||
asyncio.ensure_future(self.client.post(inbox, message, instance))
|
||||
return
|
||||
|
||||
worker = self['workers'][self['last_worker']]
|
||||
worker.queue.put((inbox, message))
|
||||
worker.queue.put((inbox, message, instance))
|
||||
|
||||
self['last_worker'] += 1
|
||||
|
||||
|
@ -186,10 +194,10 @@ class PushWorker(threading.Thread):
|
|||
|
||||
while self.app['running']:
|
||||
try:
|
||||
inbox, message = self.queue.get(block=True, timeout=0.25)
|
||||
inbox, message, instance = self.queue.get(block=True, timeout=0.25)
|
||||
self.queue.task_done()
|
||||
logging.verbose('New push from Thread-%i', threading.get_ident())
|
||||
await self.client.post(inbox, message)
|
||||
await self.client.post(inbox, message, instance)
|
||||
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
|
280
relay/cache.py
Normal file
280
relay/cache.py
Normal file
|
@ -0,0 +1,280 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timezone
|
||||
from redis import Redis
|
||||
|
||||
from .misc import Message, boolean
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Iterator
|
||||
from tinysql import Database
|
||||
from .application import Application
|
||||
|
||||
|
||||
# todo: implement more caching backends
|
||||
|
||||
|
||||
BACKENDS: dict[str, Cache] = {}
|
||||
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
|
||||
'str': (str, str),
|
||||
'int': (str, int),
|
||||
'bool': (str, boolean),
|
||||
'json': (json.dumps, json.loads),
|
||||
'message': (lambda x: x.to_json(), Message.parse)
|
||||
}
|
||||
|
||||
|
||||
def get_cache(app: Application) -> Cache:
|
||||
return BACKENDS[app.config.ca_type](app)
|
||||
|
||||
|
||||
def register_cache(backend: type[Cache]) -> type[Cache]:
|
||||
BACKENDS[backend.name] = backend
|
||||
return backend
|
||||
|
||||
|
||||
def serialize_value(value: Any, value_type: str = 'str') -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
return CONVERTERS[value_type][0](value)
|
||||
|
||||
|
||||
def deserialize_value(value: str, value_type: str = 'str') -> Any:
|
||||
return CONVERTERS[value_type][1](value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Item:
|
||||
namespace: str
|
||||
key: str
|
||||
value: Any
|
||||
value_type: str
|
||||
updated: datetime
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.updated, str):
|
||||
self.updated = datetime.fromisoformat(self.updated)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_data(cls: type[Item], *args) -> Item:
|
||||
data = cls(*args)
|
||||
data.value = deserialize_value(data.value, data.value_type)
|
||||
return data
|
||||
|
||||
|
||||
def older_than(self, hours: int) -> bool:
|
||||
delta = datetime.now(tz = timezone.utc) - self.updated
|
||||
return (delta.total_seconds()) > hours * 3600
|
||||
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
class Cache(ABC):
|
||||
name: str = 'null'
|
||||
|
||||
|
||||
def __init__(self, app: Application):
|
||||
self.app = app
|
||||
self.setup()
|
||||
|
||||
@abstractmethod
|
||||
def get(self, namespace: str, key: str) -> Item:
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_keys(self, namespace: str) -> Iterator[str]:
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_namespaces(self) -> Iterator[str]:
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, namespace: str, key: str) -> None:
|
||||
...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def setup(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
def set_item(self, item: Item) -> Item:
|
||||
return self.set(
|
||||
item.namespace,
|
||||
item.key,
|
||||
item.value,
|
||||
item.type
|
||||
)
|
||||
|
||||
|
||||
def delete_item(self, item: Item) -> None:
|
||||
self.delete(item.namespace, item.key)
|
||||
|
||||
|
||||
@register_cache
|
||||
class SqlCache(Cache):
|
||||
name: str = 'database'
|
||||
|
||||
|
||||
@property
|
||||
def _db(self) -> Database:
|
||||
return self.app.database
|
||||
|
||||
|
||||
def get(self, namespace: str, key: str) -> Item:
|
||||
params = {
|
||||
'namespace': namespace,
|
||||
'key': key
|
||||
}
|
||||
|
||||
with self._db.connection() as conn:
|
||||
with conn.exec_statement('get-cache-item', params) as cur:
|
||||
if not (row := cur.one()):
|
||||
raise KeyError(f'{namespace}:{key}')
|
||||
|
||||
return Item.from_data(*tuple(row.values()))
|
||||
|
||||
|
||||
def get_keys(self, namespace: str) -> Iterator[str]:
|
||||
with self._db.connection() as conn:
|
||||
for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}):
|
||||
yield row['key']
|
||||
|
||||
|
||||
def get_namespaces(self) -> Iterator[str]:
|
||||
with self._db.connection() as conn:
|
||||
for row in conn.exec_statement('get-cache-namespaces', None):
|
||||
yield row['namespace']
|
||||
|
||||
|
||||
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
|
||||
params = {
|
||||
'namespace': namespace,
|
||||
'key': key,
|
||||
'value': serialize_value(value, value_type),
|
||||
'type': value_type,
|
||||
'date': datetime.now(tz = timezone.utc)
|
||||
}
|
||||
|
||||
with self._db.connection() as conn:
|
||||
for row in conn.exec_statement('set-cache-item', params):
|
||||
return Item.from_data(*tuple(row.values()))
|
||||
|
||||
|
||||
def delete(self, namespace: str, key: str) -> None:
|
||||
params = {
|
||||
'namespace': namespace,
|
||||
'key': key
|
||||
}
|
||||
|
||||
with self._db.connection() as conn:
|
||||
with conn.exec_statement('del-cache-item', params):
|
||||
pass
|
||||
|
||||
|
||||
def setup(self) -> None:
|
||||
with self._db.connection() as conn:
|
||||
with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
|
||||
pass
|
||||
|
||||
|
||||
@register_cache
|
||||
class RedisCache(Cache):
|
||||
name: str = 'redis'
|
||||
_rd: Redis
|
||||
|
||||
|
||||
@property
|
||||
def prefix(self) -> str:
|
||||
return self.app.config.rd_prefix
|
||||
|
||||
|
||||
def get_key_name(self, namespace: str, key: str) -> str:
|
||||
return f'{self.prefix}:{namespace}:{key}'
|
||||
|
||||
|
||||
def get(self, namespace: str, key: str) -> Item:
|
||||
key_name = self.get_key_name(namespace, key)
|
||||
|
||||
if not (raw_value := self._rd.get(key_name)):
|
||||
raise KeyError(f'{namespace}:{key}')
|
||||
|
||||
value_type, updated, value = raw_value.split(':', 2)
|
||||
return Item.from_data(
|
||||
namespace,
|
||||
key,
|
||||
value,
|
||||
value_type,
|
||||
datetime.fromtimestamp(float(updated), tz = timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
def get_keys(self, namespace: str) -> Iterator[str]:
|
||||
for key in self._rd.keys(self.get_key_name(namespace, '*')):
|
||||
*_, key_name = key.split(':', 2)
|
||||
yield key_name
|
||||
|
||||
|
||||
def get_namespaces(self) -> Iterator[str]:
|
||||
namespaces = []
|
||||
|
||||
for key in self._rd.keys(f'{self.prefix}:*'):
|
||||
_, namespace, _ = key.split(':', 2)
|
||||
|
||||
if namespace not in namespaces:
|
||||
namespaces.append(namespace)
|
||||
yield namespace
|
||||
|
||||
|
||||
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> None:
|
||||
date = datetime.now(tz = timezone.utc).timestamp()
|
||||
value = serialize_value(value, value_type)
|
||||
|
||||
self._rd.set(
|
||||
self.get_key_name(namespace, key),
|
||||
f'{value_type}:{date}:{value}'
|
||||
)
|
||||
|
||||
|
||||
def delete(self, namespace: str, key: str) -> None:
|
||||
self._rd.delete(self.get_key_name(namespace, key))
|
||||
|
||||
|
||||
def setup(self) -> None:
|
||||
options = {
|
||||
'client_name': f'ActivityRelay_{self.app.config.domain}',
|
||||
'decode_responses': True,
|
||||
'username': self.app.config.rd_user,
|
||||
'password': self.app.config.rd_pass,
|
||||
'db': self.app.config.rd_database
|
||||
}
|
||||
|
||||
if os.path.exists(self.app.config.rd_host):
|
||||
options['unix_socket_path'] = self.app.config.rd_host
|
||||
|
||||
else:
|
||||
options['host'] = self.app.config.rd_host
|
||||
options['port'] = self.app.config.rd_port
|
||||
|
||||
self._rd = Redis(**options)
|
|
@ -19,12 +19,21 @@ DEFAULTS: dict[str, Any] = {
|
|||
'domain': 'relay.example.com',
|
||||
'workers': len(os.sched_getaffinity(0)),
|
||||
'db_type': 'sqlite',
|
||||
'ca_type': 'database',
|
||||
'sq_path': 'relay.sqlite3',
|
||||
|
||||
'pg_host': '/var/run/postgresql',
|
||||
'pg_port': 5432,
|
||||
'pg_user': getpass.getuser(),
|
||||
'pg_pass': None,
|
||||
'pg_name': 'activityrelay'
|
||||
'pg_name': 'activityrelay',
|
||||
|
||||
'rd_host': 'localhost',
|
||||
'rd_port': 6379,
|
||||
'rd_user': None,
|
||||
'rd_pass': None,
|
||||
'rd_database': 0,
|
||||
'rd_prefix': 'activityrelay'
|
||||
}
|
||||
|
||||
if IS_DOCKER:
|
||||
|
@ -40,13 +49,22 @@ class Config:
|
|||
self.domain = None
|
||||
self.workers = None
|
||||
self.db_type = None
|
||||
self.ca_type = None
|
||||
self.sq_path = None
|
||||
|
||||
self.pg_host = None
|
||||
self.pg_port = None
|
||||
self.pg_user = None
|
||||
self.pg_pass = None
|
||||
self.pg_name = None
|
||||
|
||||
self.rd_host = None
|
||||
self.rd_port = None
|
||||
self.rd_user = None
|
||||
self.rd_pass = None
|
||||
self.rd_database = None
|
||||
self.rd_prefix = None
|
||||
|
||||
if load:
|
||||
try:
|
||||
self.load()
|
||||
|
@ -92,6 +110,7 @@ class Config:
|
|||
with self.path.open('r', encoding = 'UTF-8') as fd:
|
||||
config = yaml.load(fd, **options)
|
||||
pgcfg = config.get('postgresql', {})
|
||||
rdcfg = config.get('redis', {})
|
||||
|
||||
if not config:
|
||||
raise ValueError('Config is empty')
|
||||
|
@ -106,18 +125,25 @@ class Config:
|
|||
self.set('port', config.get('port', DEFAULTS['port']))
|
||||
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
|
||||
|
||||
self.set('workers', config.get('workers', DEFAULTS['workers']))
|
||||
self.set('domain', config.get('domain', DEFAULTS['domain']))
|
||||
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
|
||||
self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
|
||||
|
||||
for key in DEFAULTS:
|
||||
if not key.startswith('pg'):
|
||||
continue
|
||||
if key.startswith('pg'):
|
||||
try:
|
||||
self.set(key, pgcfg[key[3:]])
|
||||
|
||||
try:
|
||||
self.set(key, pgcfg[key[3:]])
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
except KeyError:
|
||||
continue
|
||||
elif key.startswith('rd'):
|
||||
try:
|
||||
self.set(key, rdcfg[key[3:]])
|
||||
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
|
@ -132,7 +158,9 @@ class Config:
|
|||
'listen': self.listen,
|
||||
'port': self.port,
|
||||
'domain': self.domain,
|
||||
'workers': self.workers,
|
||||
'database_type': self.db_type,
|
||||
'cache_type': self.ca_type,
|
||||
'sqlite_path': self.sq_path,
|
||||
'postgres': {
|
||||
'host': self.pg_host,
|
||||
|
@ -140,6 +168,14 @@ class Config:
|
|||
'user': self.pg_user,
|
||||
'pass': self.pg_pass,
|
||||
'name': self.pg_name
|
||||
},
|
||||
'redis': {
|
||||
'host': self.rd_host,
|
||||
'port': self.rd_port,
|
||||
'user': self.rd_user,
|
||||
'pass': self.rd_pass,
|
||||
'database': self.rd_database,
|
||||
'refix': self.rd_prefix
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -77,3 +77,64 @@ RETURNING *
|
|||
-- name: del-domain-whitelist
|
||||
DELETE FROM whitelist
|
||||
WHERE domain = :domain
|
||||
|
||||
|
||||
-- cache functions --
|
||||
|
||||
-- name: create-cache-table-sqlite
|
||||
CREATE TABLE IF NOT EXISTS cache (
|
||||
id INTEGER PRIMARY KEY UNIQUE,
|
||||
namespace TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
"value" TEXT,
|
||||
type TEXT DEFAULT 'str',
|
||||
updated TIMESTAMP NOT NULL,
|
||||
UNIQUE(namespace, key)
|
||||
)
|
||||
|
||||
-- name: create-cache-table-postgres
|
||||
CREATE TABLE IF NOT EXISTS cache (
|
||||
id SERIAL PRIMARY KEY,
|
||||
namespace TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
"value" TEXT,
|
||||
type TEXT DEFAULT 'str',
|
||||
updated TIMESTAMP NOT NULL,
|
||||
UNIQUE(namespace, key)
|
||||
)
|
||||
|
||||
|
||||
-- name: get-cache-item
|
||||
SELECT * FROM cache
|
||||
WHERE namespace = :namespace and key = :key
|
||||
|
||||
|
||||
-- name: get-cache-keys
|
||||
SELECT key FROM cache
|
||||
WHERE namespace = :namespace
|
||||
|
||||
|
||||
-- name: get-cache-namespaces
|
||||
SELECT DISTINCT namespace FROM cache
|
||||
|
||||
|
||||
-- name: set-cache-item
|
||||
INSERT INTO cache (namespace, key, value, type, updated)
|
||||
VALUES (:namespace, :key, :value, :type, :date)
|
||||
ON CONFLICT (namespace, key) DO
|
||||
UPDATE SET value = :value, type = :type, updated = :date
|
||||
RETURNING *
|
||||
|
||||
|
||||
-- name: del-cache-item
|
||||
DELETE FROM cache
|
||||
WHERE namespace = :namespace and key = :key
|
||||
|
||||
|
||||
-- name: del-cache-namespace
|
||||
DELETE FROM cache
|
||||
WHERE namespace = :namespace
|
||||
|
||||
|
||||
-- name: del-cache-all
|
||||
DELETE FROM cache
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
|
@ -7,7 +8,6 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
|||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||
from aputils.objects import Nodeinfo, WellKnownNodeinfo
|
||||
from cachetools import LRUCache
|
||||
from json.decoder import JSONDecodeError
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
@ -16,7 +16,11 @@ from . import logger as logging
|
|||
from .misc import MIMETYPES, Message, get_app
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from aputils import Signer
|
||||
from tinysql import Row
|
||||
from typing import Any
|
||||
from .application import Application
|
||||
from .cache import Cache
|
||||
|
||||
|
||||
HEADERS = {
|
||||
|
@ -26,8 +30,7 @@ HEADERS = {
|
|||
|
||||
|
||||
class HttpClient:
|
||||
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
|
||||
self.cache = LRUCache(cache_size)
|
||||
def __init__(self, limit: int = 100, timeout: int = 10):
|
||||
self.limit = limit
|
||||
self.timeout = timeout
|
||||
self._conn = None
|
||||
|
@ -43,6 +46,21 @@ class HttpClient:
|
|||
await self.close()
|
||||
|
||||
|
||||
@property
|
||||
def app(self) -> Application:
|
||||
return get_app()
|
||||
|
||||
|
||||
@property
|
||||
def cache(self) -> Cache:
|
||||
return self.app.cache
|
||||
|
||||
|
||||
@property
|
||||
def signer(self) -> Signer:
|
||||
return self.app.signer
|
||||
|
||||
|
||||
async def open(self) -> None:
|
||||
if self._session:
|
||||
return
|
||||
|
@ -74,8 +92,8 @@ class HttpClient:
|
|||
async def get(self, # pylint: disable=too-many-branches
|
||||
url: str,
|
||||
sign_headers: bool = False,
|
||||
loads: callable | None = None,
|
||||
force: bool = False) -> Message | dict | None:
|
||||
loads: callable = json.loads,
|
||||
force: bool = False) -> dict | None:
|
||||
|
||||
await self.open()
|
||||
|
||||
|
@ -85,13 +103,20 @@ class HttpClient:
|
|||
except ValueError:
|
||||
pass
|
||||
|
||||
if not force and url in self.cache:
|
||||
return self.cache[url]
|
||||
if not force:
|
||||
try:
|
||||
item = self.cache.get('request', url)
|
||||
|
||||
if not item.older_than(48):
|
||||
return loads(item.value)
|
||||
|
||||
except KeyError:
|
||||
logging.verbose('Failed to fetch cached data for url: %s', url)
|
||||
|
||||
headers = {}
|
||||
|
||||
if sign_headers:
|
||||
get_app().signer.sign_headers('GET', url, algorithm = 'original')
|
||||
self.signer.sign_headers('GET', url, algorithm = 'original')
|
||||
|
||||
try:
|
||||
logging.debug('Fetching resource: %s', url)
|
||||
|
@ -101,32 +126,22 @@ class HttpClient:
|
|||
if resp.status == 202:
|
||||
return None
|
||||
|
||||
if resp.status != 200:
|
||||
logging.verbose('Received error when requesting %s: %i', url, resp.status)
|
||||
logging.debug(await resp.read())
|
||||
return None
|
||||
data = await resp.read()
|
||||
|
||||
if loads:
|
||||
message = await resp.json(loads=loads)
|
||||
if resp.status != 200:
|
||||
logging.verbose('Received error when requesting %s: %i', url, resp.status)
|
||||
logging.debug(await resp.read())
|
||||
return None
|
||||
|
||||
elif resp.content_type == MIMETYPES['activity']:
|
||||
message = await resp.json(loads = Message.parse)
|
||||
message = loads(data)
|
||||
self.cache.set('request', url, data.decode('utf-8'), 'str')
|
||||
logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
|
||||
|
||||
elif resp.content_type == MIMETYPES['json']:
|
||||
message = await resp.json()
|
||||
|
||||
else:
|
||||
logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type)
|
||||
logging.debug('Response: %s', await resp.read())
|
||||
return None
|
||||
|
||||
logging.debug('%s >> resp %s', url, message.to_json(4))
|
||||
|
||||
self.cache[url] = message
|
||||
return message
|
||||
return message
|
||||
|
||||
except JSONDecodeError:
|
||||
logging.verbose('Failed to parse JSON')
|
||||
return None
|
||||
|
||||
except ClientSSLError:
|
||||
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
|
||||
|
@ -140,13 +155,9 @@ class HttpClient:
|
|||
return None
|
||||
|
||||
|
||||
async def post(self, url: str, message: Message) -> None:
|
||||
async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
|
||||
await self.open()
|
||||
|
||||
# todo: cache inboxes to avoid opening a db connection
|
||||
with get_app().database.connection() as conn:
|
||||
instance = conn.get_inbox(url)
|
||||
|
||||
## Using the old algo by default is probably a better idea right now
|
||||
# pylint: disable=consider-ternary-expression
|
||||
if instance and instance['software'] in {'mastodon'}:
|
||||
|
|
|
@ -17,6 +17,7 @@ if typing.TYPE_CHECKING:
|
|||
from collections.abc import Coroutine, Generator
|
||||
from typing import Any
|
||||
from .application import Application
|
||||
from .cache import Cache
|
||||
from .config import Config
|
||||
from .database import Database
|
||||
from .http_client import HttpClient
|
||||
|
@ -268,6 +269,11 @@ class View(AbstractView):
|
|||
return self.request.app
|
||||
|
||||
|
||||
@property
|
||||
def cache(self) -> Cache:
|
||||
return self.app.cache
|
||||
|
||||
|
||||
@property
|
||||
def client(self) -> HttpClient:
|
||||
return self.app.client
|
||||
|
|
|
@ -3,8 +3,6 @@ from __future__ import annotations
|
|||
import tinysql
|
||||
import typing
|
||||
|
||||
from cachetools import LRUCache
|
||||
|
||||
from . import logger as logging
|
||||
from .misc import Message
|
||||
|
||||
|
@ -12,9 +10,6 @@ if typing.TYPE_CHECKING:
|
|||
from .views import ActorView
|
||||
|
||||
|
||||
cache = LRUCache(1024)
|
||||
|
||||
|
||||
def person_check(actor: str, software: str) -> bool:
|
||||
# pleroma and akkoma may use Person for the actor type for some reason
|
||||
# akkoma changed this in 3.6.0
|
||||
|
@ -29,31 +24,39 @@ def person_check(actor: str, software: str) -> bool:
|
|||
|
||||
|
||||
async def handle_relay(view: ActorView) -> None:
|
||||
if view.message.object_id in cache:
|
||||
try:
|
||||
view.cache.get('handle-relay', view.message.object_id)
|
||||
logging.verbose('already relayed %s', view.message.object_id)
|
||||
return
|
||||
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
message = Message.new_announce(view.config.domain, view.message.object_id)
|
||||
cache[view.message.object_id] = message.id
|
||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||
logging.debug('>> relay: %s', message)
|
||||
|
||||
with view.database.connection() as conn:
|
||||
for inbox in conn.distill_inboxes(view.message):
|
||||
view.app.push_message(inbox, message)
|
||||
view.app.push_message(inbox, message, view.instance)
|
||||
|
||||
|
||||
async def handle_forward(view: ActorView) -> None:
|
||||
if view.message.id in cache:
|
||||
logging.verbose('already forwarded %s', view.message.id)
|
||||
try:
|
||||
view.cache.get('handle-relay', view.message.object_id)
|
||||
logging.verbose('already forwarded %s', view.message.object_id)
|
||||
return
|
||||
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
message = Message.new_announce(view.config.domain, view.message)
|
||||
cache[view.message.id] = message.id
|
||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||
logging.debug('>> forward: %s', message)
|
||||
|
||||
with view.database.connection() as conn:
|
||||
for inbox in conn.distill_inboxes(view.message):
|
||||
view.app.push_message(inbox, message)
|
||||
view.app.push_message(inbox, message, view.instance)
|
||||
|
||||
|
||||
async def handle_follow(view: ActorView) -> None:
|
||||
|
|
|
@ -19,6 +19,7 @@ if typing.TYPE_CHECKING:
|
|||
from aiohttp.web import Request
|
||||
from aputils.signer import Signer
|
||||
from collections.abc import Callable
|
||||
from tinysql import Row
|
||||
|
||||
|
||||
VIEWS = []
|
||||
|
@ -98,7 +99,7 @@ class ActorView(View):
|
|||
self.signature: Signature = None
|
||||
self.message: Message = None
|
||||
self.actor: Message = None
|
||||
self.instance: dict[str, str] = None
|
||||
self.instance: Row = None
|
||||
self.signer: Signer = None
|
||||
|
||||
|
||||
|
@ -168,7 +169,11 @@ class ActorView(View):
|
|||
logging.verbose('actor not in message')
|
||||
return Response.new_error(400, 'no actor in message', 'json')
|
||||
|
||||
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
|
||||
self.actor = await self.client.get(
|
||||
self.signature.keyid,
|
||||
sign_headers = True,
|
||||
loads = Message.parse
|
||||
)
|
||||
|
||||
if not self.actor:
|
||||
# ld signatures aren't handled atm, so just ignore it
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
aiohttp>=3.9.1
|
||||
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
|
||||
cachetools>=5.2.0
|
||||
click>=8.1.2
|
||||
hiredis==2.3.2
|
||||
pyyaml>=6.0
|
||||
redis==5.0.1
|
||||
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.3.tar.gz
|
||||
|
||||
importlib_resources==6.1.1;python_version<'3.9'
|
||||
|
|
Loading…
Reference in a new issue