mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-23 23:17:58 +00:00
Compare commits
No commits in common. "b8e06417334c2172de4b87ee80bd2ef9092f75d9" and "f2baf7f9f927e4e4def2879fc5348d8c305cdb11" have entirely different histories.
b8e0641733
...
f2baf7f9f9
|
@ -35,13 +35,6 @@ SQL database backend to use. Valid values are `sqlite` or `postgres`.
|
||||||
database_type: sqlite
|
database_type: sqlite
|
||||||
|
|
||||||
|
|
||||||
### Cache type
|
|
||||||
|
|
||||||
Cache backend to use. Valid values are `database` or `redis`
|
|
||||||
|
|
||||||
cache_type: database
|
|
||||||
|
|
||||||
|
|
||||||
### Sqlite File Path
|
### Sqlite File Path
|
||||||
|
|
||||||
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
|
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
|
||||||
|
@ -54,7 +47,7 @@ directory.
|
||||||
|
|
||||||
In order to use the Postgresql backend, the user and database need to be created first.
|
In order to use the Postgresql backend, the user and database need to be created first.
|
||||||
|
|
||||||
sudo -u postgres psql -c "CREATE USER activityrelay WITH PASSWORD SomeSecurePassword"
|
sudo -u postgres psql -c "CREATE USER activityrelay"
|
||||||
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
|
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,47 +84,3 @@ User to use when logging into the server.
|
||||||
Password for the specified user.
|
Password for the specified user.
|
||||||
|
|
||||||
pass: null
|
pass: null
|
||||||
|
|
||||||
|
|
||||||
## Redis
|
|
||||||
|
|
||||||
### Host
|
|
||||||
|
|
||||||
Hostname, IP address, or unix socket the server is hosted on.
|
|
||||||
|
|
||||||
host: /var/run/postgresql
|
|
||||||
|
|
||||||
|
|
||||||
### Port
|
|
||||||
|
|
||||||
Port number the server is listening on.
|
|
||||||
|
|
||||||
port: 5432
|
|
||||||
|
|
||||||
|
|
||||||
### Username
|
|
||||||
|
|
||||||
User to use when logging into the server.
|
|
||||||
|
|
||||||
user: null
|
|
||||||
|
|
||||||
|
|
||||||
### Password
|
|
||||||
|
|
||||||
Password for the specified user.
|
|
||||||
|
|
||||||
pass: null
|
|
||||||
|
|
||||||
|
|
||||||
### Database Number
|
|
||||||
|
|
||||||
Number of the database to use.
|
|
||||||
|
|
||||||
database: 0
|
|
||||||
|
|
||||||
|
|
||||||
### Prefix
|
|
||||||
|
|
||||||
Text to prefix every key with. It cannot contain a `:` character.
|
|
||||||
|
|
||||||
prefix: activityrelay
|
|
||||||
|
|
|
@ -7,15 +7,12 @@ listen: 0.0.0.0
|
||||||
# [integer] Port the relay will listen on
|
# [integer] Port the relay will listen on
|
||||||
port: 8080
|
port: 8080
|
||||||
|
|
||||||
# [integer] Number of push workers to start
|
# [integer] Number of push workers to start (will get removed in a future update)
|
||||||
workers: 8
|
workers: 8
|
||||||
|
|
||||||
# [string] Database backend to use. Valid values: sqlite, postgres
|
# [string] Database backend to use. Valid values: sqlite, postgres
|
||||||
database_type: sqlite
|
database_type: sqlite
|
||||||
|
|
||||||
# [string] Cache backend to use. Valid values: database, redis
|
|
||||||
cache_type: database
|
|
||||||
|
|
||||||
# [string] Path to the sqlite database file if the sqlite backend is in use
|
# [string] Path to the sqlite database file if the sqlite backend is in use
|
||||||
sqlite_path: relay.sqlite3
|
sqlite_path: relay.sqlite3
|
||||||
|
|
||||||
|
@ -36,24 +33,3 @@ postgres:
|
||||||
|
|
||||||
# [string] name of the database to use
|
# [string] name of the database to use
|
||||||
name: activityrelay
|
name: activityrelay
|
||||||
|
|
||||||
# settings for the redis caching backend
|
|
||||||
redis:
|
|
||||||
|
|
||||||
# [string] hostname or unix socket to connect to
|
|
||||||
host: localhost
|
|
||||||
|
|
||||||
# [integer] port of the server
|
|
||||||
port: 6379
|
|
||||||
|
|
||||||
# [string] username to use when logging into the server
|
|
||||||
user: null
|
|
||||||
|
|
||||||
# [string] password for the server
|
|
||||||
pass: null
|
|
||||||
|
|
||||||
# [integer] database number to use
|
|
||||||
database: 0
|
|
||||||
|
|
||||||
# [string] prefix for keys
|
|
||||||
prefix: activityrelay
|
|
||||||
|
|
|
@ -1,20 +1,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import queue
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import threading
|
||||||
import sys
|
import traceback
|
||||||
import time
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aputils.signer import Signer
|
from aputils.signer import Signer
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from gunicorn.app.wsgiapp import WSGIApplication
|
|
||||||
|
|
||||||
from . import logger as logging
|
from . import logger as logging
|
||||||
from .cache import get_cache
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .database import get_database
|
from .database import get_database
|
||||||
from .http_client import HttpClient
|
from .http_client import HttpClient
|
||||||
|
@ -22,10 +19,8 @@ from .misc import check_open_port
|
||||||
from .views import VIEWS
|
from .views import VIEWS
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from collections.abc import Awaitable
|
from tinysql import Database
|
||||||
from tinysql import Database, Row
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from .cache import Cache
|
|
||||||
from .misc import Message
|
from .misc import Message
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,34 +29,25 @@ if typing.TYPE_CHECKING:
|
||||||
class Application(web.Application):
|
class Application(web.Application):
|
||||||
DEFAULT: Application = None
|
DEFAULT: Application = None
|
||||||
|
|
||||||
def __init__(self, cfgpath: str, gunicorn: bool = False):
|
def __init__(self, cfgpath: str):
|
||||||
web.Application.__init__(self)
|
web.Application.__init__(self)
|
||||||
|
|
||||||
Application.DEFAULT = self
|
Application.DEFAULT = self
|
||||||
|
|
||||||
self['proc'] = None
|
|
||||||
self['signer'] = None
|
self['signer'] = None
|
||||||
self['start_time'] = None
|
|
||||||
|
|
||||||
self['config'] = Config(cfgpath, load = True)
|
self['config'] = Config(cfgpath, load = True)
|
||||||
self['database'] = get_database(self.config)
|
self['database'] = get_database(self.config)
|
||||||
self['client'] = HttpClient()
|
self['client'] = HttpClient()
|
||||||
self['cache'] = get_cache(self)
|
|
||||||
|
|
||||||
if not gunicorn:
|
self['workers'] = []
|
||||||
return
|
self['last_worker'] = 0
|
||||||
|
self['start_time'] = None
|
||||||
self.on_response_prepare.append(handle_access_log)
|
self['running'] = False
|
||||||
|
|
||||||
for path, view in VIEWS:
|
for path, view in VIEWS:
|
||||||
self.router.add_view(path, view)
|
self.router.add_view(path, view)
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache(self) -> Cache:
|
|
||||||
return self['cache']
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> HttpClient:
|
def client(self) -> HttpClient:
|
||||||
return self['client']
|
return self['client']
|
||||||
|
@ -101,17 +87,18 @@ class Application(web.Application):
|
||||||
return timedelta(seconds=uptime.seconds)
|
return timedelta(seconds=uptime.seconds)
|
||||||
|
|
||||||
|
|
||||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
def push_message(self, inbox: str, message: Message) -> None:
|
||||||
asyncio.ensure_future(self.client.post(inbox, message, instance))
|
if self.config.workers <= 0:
|
||||||
|
asyncio.ensure_future(self.client.post(inbox, message))
|
||||||
|
return
|
||||||
|
|
||||||
|
worker = self['workers'][self['last_worker']]
|
||||||
|
worker.queue.put((inbox, message))
|
||||||
|
|
||||||
def run(self, dev: bool = False) -> None:
|
self['last_worker'] += 1
|
||||||
self.start(dev)
|
|
||||||
|
|
||||||
while self['proc'] and self['proc'].poll() is None:
|
if self['last_worker'] >= len(self['workers']):
|
||||||
time.sleep(0.1)
|
self['last_worker'] = 0
|
||||||
|
|
||||||
self.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def set_signal_handler(self, startup: bool) -> None:
|
def set_signal_handler(self, startup: bool) -> None:
|
||||||
|
@ -124,101 +111,91 @@ class Application(web.Application):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
def start(self, dev: bool = False) -> None:
|
|
||||||
if self['proc']:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not check_open_port(self.config.listen, self.config.port):
|
if not check_open_port(self.config.listen, self.config.port):
|
||||||
logging.error('Server already running on %s:%s', self.config.listen, self.config.port)
|
logging.error('A server is already running on port %i', self.config.port)
|
||||||
return
|
return
|
||||||
|
|
||||||
cmd = [
|
for view in VIEWS:
|
||||||
sys.executable, '-m', 'gunicorn',
|
self.router.add_view(*view)
|
||||||
'relay.application:main_gunicorn',
|
|
||||||
'--bind', f'{self.config.listen}:{self.config.port}',
|
|
||||||
'--worker-class', 'aiohttp.GunicornWebWorker',
|
|
||||||
'--workers', str(self.config.workers),
|
|
||||||
'--env', f'CONFIG_FILE={self.config.path}'
|
|
||||||
]
|
|
||||||
|
|
||||||
if dev:
|
logging.info(
|
||||||
cmd.append('--reload')
|
'Starting webserver at %s (%s:%i)',
|
||||||
|
self.config.domain,
|
||||||
|
self.config.listen,
|
||||||
|
self.config.port
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(self.handle_run())
|
||||||
|
|
||||||
|
|
||||||
|
def stop(self, *_: Any) -> None:
|
||||||
|
self['running'] = False
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_run(self) -> None:
|
||||||
|
self['running'] = True
|
||||||
|
|
||||||
self.set_signal_handler(True)
|
self.set_signal_handler(True)
|
||||||
self['proc'] = subprocess.Popen(cmd) # pylint: disable=consider-using-with
|
|
||||||
|
|
||||||
|
if self.config.workers > 0:
|
||||||
|
for _ in range(self.config.workers):
|
||||||
|
worker = PushWorker(self)
|
||||||
|
worker.start()
|
||||||
|
|
||||||
def stop(self, *_) -> None:
|
self['workers'].append(worker)
|
||||||
if not self['proc']:
|
|
||||||
return
|
|
||||||
|
|
||||||
self['proc'].terminate()
|
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
|
||||||
time_wait = 0.0
|
await runner.setup()
|
||||||
|
|
||||||
while self['proc'].poll() is None:
|
site = web.TCPSite(
|
||||||
time.sleep(0.1)
|
runner,
|
||||||
time_wait += 0.1
|
host = self.config.listen,
|
||||||
|
port = self.config.port,
|
||||||
if time_wait >= 5.0:
|
reuse_address = True
|
||||||
self['proc'].kill()
|
|
||||||
break
|
|
||||||
|
|
||||||
self.set_signal_handler(False)
|
|
||||||
self['proc'] = None
|
|
||||||
|
|
||||||
|
|
||||||
# not used, but keeping just in case
|
|
||||||
class GunicornRunner(WSGIApplication):
|
|
||||||
def __init__(self, app: Application):
|
|
||||||
self.app = app
|
|
||||||
self.app_uri = 'relay.application:main_gunicorn'
|
|
||||||
self.options = {
|
|
||||||
'bind': f'{app.config.listen}:{app.config.port}',
|
|
||||||
'worker_class': 'aiohttp.GunicornWebWorker',
|
|
||||||
'workers': app.config.workers,
|
|
||||||
'raw_env': f'CONFIG_FILE={app.config.path}'
|
|
||||||
}
|
|
||||||
|
|
||||||
WSGIApplication.__init__(self)
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(self):
|
|
||||||
for key, value in self.options.items():
|
|
||||||
self.cfg.set(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
logging.info('Starting webserver for %s', self.app.config.domain)
|
|
||||||
WSGIApplication.run(self)
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_access_log(request: web.Request, response: web.Response) -> None:
|
|
||||||
address = request.headers.get(
|
|
||||||
'X-Forwarded-For',
|
|
||||||
request.headers.get(
|
|
||||||
'X-Real-Ip',
|
|
||||||
request.remote
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
await site.start()
|
||||||
'%s "%s %s" %i %i "%s"',
|
self['start_time'] = datetime.now()
|
||||||
address,
|
|
||||||
request.method,
|
while self['running']:
|
||||||
request.path,
|
await asyncio.sleep(0.25)
|
||||||
response.status,
|
|
||||||
len(response.body),
|
await site.stop()
|
||||||
request.headers.get('User-Agent', 'n/a')
|
await self.client.close()
|
||||||
)
|
|
||||||
|
self['start_time'] = None
|
||||||
|
self['running'] = False
|
||||||
|
self['workers'].clear()
|
||||||
|
|
||||||
|
|
||||||
async def main_gunicorn():
|
class PushWorker(threading.Thread):
|
||||||
try:
|
def __init__(self, app: Application):
|
||||||
app = Application(os.environ['CONFIG_FILE'], gunicorn = True)
|
threading.Thread.__init__(self)
|
||||||
|
self.app = app
|
||||||
|
self.queue = queue.Queue()
|
||||||
|
self.client = None
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
logging.error('Failed to set "CONFIG_FILE" environment. Trying to run without gunicorn?')
|
|
||||||
raise
|
|
||||||
|
|
||||||
return app
|
def run(self) -> None:
|
||||||
|
asyncio.run(self.handle_queue())
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_queue(self) -> None:
|
||||||
|
self.client = HttpClient()
|
||||||
|
|
||||||
|
while self.app['running']:
|
||||||
|
try:
|
||||||
|
inbox, message = self.queue.get(block=True, timeout=0.25)
|
||||||
|
self.queue.task_done()
|
||||||
|
logging.verbose('New push from Thread-%i', threading.get_ident())
|
||||||
|
await self.client.post(inbox, message)
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
## make sure an exception doesn't bring down the worker
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
await self.client.close()
|
||||||
|
|
288
relay/cache.py
288
relay/cache.py
|
@ -1,288 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from redis import Redis
|
|
||||||
|
|
||||||
from .database import get_database
|
|
||||||
from .misc import Message, boolean
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from typing import Any
|
|
||||||
from collections.abc import Callable, Iterator
|
|
||||||
from tinysql import Database
|
|
||||||
from .application import Application
|
|
||||||
|
|
||||||
|
|
||||||
# todo: implement more caching backends
|
|
||||||
|
|
||||||
|
|
||||||
BACKENDS: dict[str, Cache] = {}
|
|
||||||
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
|
|
||||||
'str': (str, str),
|
|
||||||
'int': (str, int),
|
|
||||||
'bool': (str, boolean),
|
|
||||||
'json': (json.dumps, json.loads),
|
|
||||||
'message': (lambda x: x.to_json(), Message.parse)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_cache(app: Application) -> Cache:
|
|
||||||
return BACKENDS[app.config.ca_type](app)
|
|
||||||
|
|
||||||
|
|
||||||
def register_cache(backend: type[Cache]) -> type[Cache]:
|
|
||||||
BACKENDS[backend.name] = backend
|
|
||||||
return backend
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_value(value: Any, value_type: str = 'str') -> str:
|
|
||||||
if isinstance(value, str):
|
|
||||||
return value
|
|
||||||
|
|
||||||
return CONVERTERS[value_type][0](value)
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_value(value: str, value_type: str = 'str') -> Any:
|
|
||||||
return CONVERTERS[value_type][1](value)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Item:
|
|
||||||
namespace: str
|
|
||||||
key: str
|
|
||||||
value: Any
|
|
||||||
value_type: str
|
|
||||||
updated: datetime
|
|
||||||
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if isinstance(self.updated, str):
|
|
||||||
self.updated = datetime.fromisoformat(self.updated)
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_data(cls: type[Item], *args) -> Item:
|
|
||||||
data = cls(*args)
|
|
||||||
data.value = deserialize_value(data.value, data.value_type)
|
|
||||||
|
|
||||||
if not isinstance(data.updated, datetime):
|
|
||||||
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def older_than(self, hours: int) -> bool:
|
|
||||||
delta = datetime.now(tz = timezone.utc) - self.updated
|
|
||||||
return (delta.total_seconds()) > hours * 3600
|
|
||||||
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
|
||||||
return asdict(self)
|
|
||||||
|
|
||||||
|
|
||||||
class Cache(ABC):
|
|
||||||
name: str = 'null'
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, app: Application):
|
|
||||||
self.app = app
|
|
||||||
self.setup()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, namespace: str, key: str) -> Item:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_keys(self, namespace: str) -> Iterator[str]:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_namespaces(self) -> Iterator[str]:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, namespace: str, key: str) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def setup(self) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def set_item(self, item: Item) -> Item:
|
|
||||||
return self.set(
|
|
||||||
item.namespace,
|
|
||||||
item.key,
|
|
||||||
item.value,
|
|
||||||
item.type
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_item(self, item: Item) -> None:
|
|
||||||
self.delete(item.namespace, item.key)
|
|
||||||
|
|
||||||
|
|
||||||
@register_cache
|
|
||||||
class SqlCache(Cache):
|
|
||||||
name: str = 'database'
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, app: Application):
|
|
||||||
self._db = get_database(app.config)
|
|
||||||
Cache.__init__(self, app)
|
|
||||||
|
|
||||||
|
|
||||||
def get(self, namespace: str, key: str) -> Item:
|
|
||||||
params = {
|
|
||||||
'namespace': namespace,
|
|
||||||
'key': key
|
|
||||||
}
|
|
||||||
|
|
||||||
with self._db.connection() as conn:
|
|
||||||
with conn.exec_statement('get-cache-item', params) as cur:
|
|
||||||
if not (row := cur.one()):
|
|
||||||
raise KeyError(f'{namespace}:{key}')
|
|
||||||
|
|
||||||
row.pop('id', None)
|
|
||||||
return Item.from_data(*tuple(row.values()))
|
|
||||||
|
|
||||||
|
|
||||||
def get_keys(self, namespace: str) -> Iterator[str]:
|
|
||||||
with self._db.connection() as conn:
|
|
||||||
for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}):
|
|
||||||
yield row['key']
|
|
||||||
|
|
||||||
|
|
||||||
def get_namespaces(self) -> Iterator[str]:
|
|
||||||
with self._db.connection() as conn:
|
|
||||||
for row in conn.exec_statement('get-cache-namespaces', None):
|
|
||||||
yield row['namespace']
|
|
||||||
|
|
||||||
|
|
||||||
def set(self, namespace: str, key: str, value: Any, value_type: str = 'str') -> Item:
|
|
||||||
params = {
|
|
||||||
'namespace': namespace,
|
|
||||||
'key': key,
|
|
||||||
'value': serialize_value(value, value_type),
|
|
||||||
'type': value_type,
|
|
||||||
'date': datetime.now(tz = timezone.utc)
|
|
||||||
}
|
|
||||||
|
|
||||||
with self._db.connection() as conn:
|
|
||||||
with conn.exec_statement('set-cache-item', params) as conn:
|
|
||||||
row = conn.one()
|
|
||||||
row.pop('id', None)
|
|
||||||
return Item.from_data(*tuple(row.values()))
|
|
||||||
|
|
||||||
|
|
||||||
def delete(self, namespace: str, key: str) -> None:
|
|
||||||
params = {
|
|
||||||
'namespace': namespace,
|
|
||||||
'key': key
|
|
||||||
}
|
|
||||||
|
|
||||||
with self._db.connection() as conn:
|
|
||||||
with conn.exec_statement('del-cache-item', params):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def setup(self) -> None:
|
|
||||||
with self._db.connection() as conn:
|
|
||||||
with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@register_cache
|
|
||||||
class RedisCache(Cache):
|
|
||||||
name: str = 'redis'
|
|
||||||
_rd: Redis
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def prefix(self) -> str:
|
|
||||||
return self.app.config.rd_prefix
|
|
||||||
|
|
||||||
|
|
||||||
def get_key_name(self, namespace: str, key: str) -> str:
|
|
||||||
return f'{self.prefix}:{namespace}:{key}'
|
|
||||||
|
|
||||||
|
|
||||||
def get(self, namespace: str, key: str) -> Item:
|
|
||||||
key_name = self.get_key_name(namespace, key)
|
|
||||||
|
|
||||||
if not (raw_value := self._rd.get(key_name)):
|
|
||||||
raise KeyError(f'{namespace}:{key}')
|
|
||||||
|
|
||||||
value_type, updated, value = raw_value.split(':', 2)
|
|
||||||
return Item.from_data(
|
|
||||||
namespace,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
value_type,
|
|
||||||
datetime.fromtimestamp(float(updated), tz = timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_keys(self, namespace: str) -> Iterator[str]:
|
|
||||||
for key in self._rd.keys(self.get_key_name(namespace, '*')):
|
|
||||||
*_, key_name = key.split(':', 2)
|
|
||||||
yield key_name
|
|
||||||
|
|
||||||
|
|
||||||
def get_namespaces(self) -> Iterator[str]:
|
|
||||||
namespaces = []
|
|
||||||
|
|
||||||
for key in self._rd.keys(f'{self.prefix}:*'):
|
|
||||||
_, namespace, _ = key.split(':', 2)
|
|
||||||
|
|
||||||
if namespace not in namespaces:
|
|
||||||
namespaces.append(namespace)
|
|
||||||
yield namespace
|
|
||||||
|
|
||||||
|
|
||||||
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> None:
|
|
||||||
date = datetime.now(tz = timezone.utc).timestamp()
|
|
||||||
value = serialize_value(value, value_type)
|
|
||||||
|
|
||||||
self._rd.set(
|
|
||||||
self.get_key_name(namespace, key),
|
|
||||||
f'{value_type}:{date}:{value}'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def delete(self, namespace: str, key: str) -> None:
|
|
||||||
self._rd.delete(self.get_key_name(namespace, key))
|
|
||||||
|
|
||||||
|
|
||||||
def setup(self) -> None:
|
|
||||||
options = {
|
|
||||||
'client_name': f'ActivityRelay_{self.app.config.domain}',
|
|
||||||
'decode_responses': True,
|
|
||||||
'username': self.app.config.rd_user,
|
|
||||||
'password': self.app.config.rd_pass,
|
|
||||||
'db': self.app.config.rd_database
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.path.exists(self.app.config.rd_host):
|
|
||||||
options['unix_socket_path'] = self.app.config.rd_host
|
|
||||||
|
|
||||||
else:
|
|
||||||
options['host'] = self.app.config.rd_host
|
|
||||||
options['port'] = self.app.config.rd_port
|
|
||||||
|
|
||||||
self._rd = Redis(**options)
|
|
|
@ -19,21 +19,12 @@ DEFAULTS: dict[str, Any] = {
|
||||||
'domain': 'relay.example.com',
|
'domain': 'relay.example.com',
|
||||||
'workers': len(os.sched_getaffinity(0)),
|
'workers': len(os.sched_getaffinity(0)),
|
||||||
'db_type': 'sqlite',
|
'db_type': 'sqlite',
|
||||||
'ca_type': 'database',
|
|
||||||
'sq_path': 'relay.sqlite3',
|
'sq_path': 'relay.sqlite3',
|
||||||
|
|
||||||
'pg_host': '/var/run/postgresql',
|
'pg_host': '/var/run/postgresql',
|
||||||
'pg_port': 5432,
|
'pg_port': 5432,
|
||||||
'pg_user': getpass.getuser(),
|
'pg_user': getpass.getuser(),
|
||||||
'pg_pass': None,
|
'pg_pass': None,
|
||||||
'pg_name': 'activityrelay',
|
'pg_name': 'activityrelay'
|
||||||
|
|
||||||
'rd_host': 'localhost',
|
|
||||||
'rd_port': 6379,
|
|
||||||
'rd_user': None,
|
|
||||||
'rd_pass': None,
|
|
||||||
'rd_database': 0,
|
|
||||||
'rd_prefix': 'activityrelay'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if IS_DOCKER:
|
if IS_DOCKER:
|
||||||
|
@ -49,22 +40,13 @@ class Config:
|
||||||
self.domain = None
|
self.domain = None
|
||||||
self.workers = None
|
self.workers = None
|
||||||
self.db_type = None
|
self.db_type = None
|
||||||
self.ca_type = None
|
|
||||||
self.sq_path = None
|
self.sq_path = None
|
||||||
|
|
||||||
self.pg_host = None
|
self.pg_host = None
|
||||||
self.pg_port = None
|
self.pg_port = None
|
||||||
self.pg_user = None
|
self.pg_user = None
|
||||||
self.pg_pass = None
|
self.pg_pass = None
|
||||||
self.pg_name = None
|
self.pg_name = None
|
||||||
|
|
||||||
self.rd_host = None
|
|
||||||
self.rd_port = None
|
|
||||||
self.rd_user = None
|
|
||||||
self.rd_pass = None
|
|
||||||
self.rd_database = None
|
|
||||||
self.rd_prefix = None
|
|
||||||
|
|
||||||
if load:
|
if load:
|
||||||
try:
|
try:
|
||||||
self.load()
|
self.load()
|
||||||
|
@ -110,7 +92,6 @@ class Config:
|
||||||
with self.path.open('r', encoding = 'UTF-8') as fd:
|
with self.path.open('r', encoding = 'UTF-8') as fd:
|
||||||
config = yaml.load(fd, **options)
|
config = yaml.load(fd, **options)
|
||||||
pgcfg = config.get('postgresql', {})
|
pgcfg = config.get('postgresql', {})
|
||||||
rdcfg = config.get('redis', {})
|
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
raise ValueError('Config is empty')
|
raise ValueError('Config is empty')
|
||||||
|
@ -125,25 +106,18 @@ class Config:
|
||||||
self.set('port', config.get('port', DEFAULTS['port']))
|
self.set('port', config.get('port', DEFAULTS['port']))
|
||||||
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
|
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
|
||||||
|
|
||||||
self.set('workers', config.get('workers', DEFAULTS['workers']))
|
|
||||||
self.set('domain', config.get('domain', DEFAULTS['domain']))
|
self.set('domain', config.get('domain', DEFAULTS['domain']))
|
||||||
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
|
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
|
||||||
self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
|
|
||||||
|
|
||||||
for key in DEFAULTS:
|
for key in DEFAULTS:
|
||||||
if key.startswith('pg'):
|
if not key.startswith('pg'):
|
||||||
try:
|
continue
|
||||||
self.set(key, pgcfg[key[3:]])
|
|
||||||
|
|
||||||
except KeyError:
|
try:
|
||||||
continue
|
self.set(key, pgcfg[key[3:]])
|
||||||
|
|
||||||
elif key.startswith('rd'):
|
except KeyError:
|
||||||
try:
|
continue
|
||||||
self.set(key, rdcfg[key[3:]])
|
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
@ -158,9 +132,7 @@ class Config:
|
||||||
'listen': self.listen,
|
'listen': self.listen,
|
||||||
'port': self.port,
|
'port': self.port,
|
||||||
'domain': self.domain,
|
'domain': self.domain,
|
||||||
'workers': self.workers,
|
|
||||||
'database_type': self.db_type,
|
'database_type': self.db_type,
|
||||||
'cache_type': self.ca_type,
|
|
||||||
'sqlite_path': self.sq_path,
|
'sqlite_path': self.sq_path,
|
||||||
'postgres': {
|
'postgres': {
|
||||||
'host': self.pg_host,
|
'host': self.pg_host,
|
||||||
|
@ -168,14 +140,6 @@ class Config:
|
||||||
'user': self.pg_user,
|
'user': self.pg_user,
|
||||||
'pass': self.pg_pass,
|
'pass': self.pg_pass,
|
||||||
'name': self.pg_name
|
'name': self.pg_name
|
||||||
},
|
|
||||||
'redis': {
|
|
||||||
'host': self.rd_host,
|
|
||||||
'port': self.rd_port,
|
|
||||||
'user': self.rd_user,
|
|
||||||
'pass': self.rd_pass,
|
|
||||||
'database': self.rd_database,
|
|
||||||
'refix': self.rd_prefix
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -77,64 +77,3 @@ RETURNING *
|
||||||
-- name: del-domain-whitelist
|
-- name: del-domain-whitelist
|
||||||
DELETE FROM whitelist
|
DELETE FROM whitelist
|
||||||
WHERE domain = :domain
|
WHERE domain = :domain
|
||||||
|
|
||||||
|
|
||||||
-- cache functions --
|
|
||||||
|
|
||||||
-- name: create-cache-table-sqlite
|
|
||||||
CREATE TABLE IF NOT EXISTS cache (
|
|
||||||
id INTEGER PRIMARY KEY UNIQUE,
|
|
||||||
namespace TEXT NOT NULL,
|
|
||||||
key TEXT NOT NULL,
|
|
||||||
"value" TEXT,
|
|
||||||
type TEXT DEFAULT 'str',
|
|
||||||
updated TIMESTAMP NOT NULL,
|
|
||||||
UNIQUE(namespace, key)
|
|
||||||
)
|
|
||||||
|
|
||||||
-- name: create-cache-table-postgres
|
|
||||||
CREATE TABLE IF NOT EXISTS cache (
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
namespace TEXT NOT NULL,
|
|
||||||
key TEXT NOT NULL,
|
|
||||||
"value" TEXT,
|
|
||||||
type TEXT DEFAULT 'str',
|
|
||||||
updated TIMESTAMP NOT NULL,
|
|
||||||
UNIQUE(namespace, key)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
-- name: get-cache-item
|
|
||||||
SELECT * FROM cache
|
|
||||||
WHERE namespace = :namespace and key = :key
|
|
||||||
|
|
||||||
|
|
||||||
-- name: get-cache-keys
|
|
||||||
SELECT key FROM cache
|
|
||||||
WHERE namespace = :namespace
|
|
||||||
|
|
||||||
|
|
||||||
-- name: get-cache-namespaces
|
|
||||||
SELECT DISTINCT namespace FROM cache
|
|
||||||
|
|
||||||
|
|
||||||
-- name: set-cache-item
|
|
||||||
INSERT INTO cache (namespace, key, value, type, updated)
|
|
||||||
VALUES (:namespace, :key, :value, :type, :date)
|
|
||||||
ON CONFLICT (namespace, key) DO
|
|
||||||
UPDATE SET value = :value, type = :type, updated = :date
|
|
||||||
RETURNING *
|
|
||||||
|
|
||||||
|
|
||||||
-- name: del-cache-item
|
|
||||||
DELETE FROM cache
|
|
||||||
WHERE namespace = :namespace and key = :key
|
|
||||||
|
|
||||||
|
|
||||||
-- name: del-cache-namespace
|
|
||||||
DELETE FROM cache
|
|
||||||
WHERE namespace = :namespace
|
|
||||||
|
|
||||||
|
|
||||||
-- name: del-cache-all
|
|
||||||
DELETE FROM cache
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
@ -8,6 +7,7 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||||
from aputils.objects import Nodeinfo, WellKnownNodeinfo
|
from aputils.objects import Nodeinfo, WellKnownNodeinfo
|
||||||
|
from cachetools import LRUCache
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
@ -16,11 +16,7 @@ from . import logger as logging
|
||||||
from .misc import MIMETYPES, Message, get_app
|
from .misc import MIMETYPES, Message, get_app
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from aputils import Signer
|
|
||||||
from tinysql import Row
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from .application import Application
|
|
||||||
from .cache import Cache
|
|
||||||
|
|
||||||
|
|
||||||
HEADERS = {
|
HEADERS = {
|
||||||
|
@ -30,7 +26,8 @@ HEADERS = {
|
||||||
|
|
||||||
|
|
||||||
class HttpClient:
|
class HttpClient:
|
||||||
def __init__(self, limit: int = 100, timeout: int = 10):
|
def __init__(self, limit: int = 100, timeout: int = 10, cache_size: int = 1024):
|
||||||
|
self.cache = LRUCache(cache_size)
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self._conn = None
|
self._conn = None
|
||||||
|
@ -46,21 +43,6 @@ class HttpClient:
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def app(self) -> Application:
|
|
||||||
return get_app()
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache(self) -> Cache:
|
|
||||||
return self.app.cache
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def signer(self) -> Signer:
|
|
||||||
return self.app.signer
|
|
||||||
|
|
||||||
|
|
||||||
async def open(self) -> None:
|
async def open(self) -> None:
|
||||||
if self._session:
|
if self._session:
|
||||||
return
|
return
|
||||||
|
@ -92,8 +74,8 @@ class HttpClient:
|
||||||
async def get(self, # pylint: disable=too-many-branches
|
async def get(self, # pylint: disable=too-many-branches
|
||||||
url: str,
|
url: str,
|
||||||
sign_headers: bool = False,
|
sign_headers: bool = False,
|
||||||
loads: callable = json.loads,
|
loads: callable | None = None,
|
||||||
force: bool = False) -> dict | None:
|
force: bool = False) -> Message | dict | None:
|
||||||
|
|
||||||
await self.open()
|
await self.open()
|
||||||
|
|
||||||
|
@ -103,20 +85,13 @@ class HttpClient:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if not force:
|
if not force and url in self.cache:
|
||||||
try:
|
return self.cache[url]
|
||||||
item = self.cache.get('request', url)
|
|
||||||
|
|
||||||
if not item.older_than(48):
|
|
||||||
return loads(item.value)
|
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
logging.verbose('Failed to fetch cached data for url: %s', url)
|
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
|
|
||||||
if sign_headers:
|
if sign_headers:
|
||||||
self.signer.sign_headers('GET', url, algorithm = 'original')
|
get_app().signer.sign_headers('GET', url, algorithm = 'original')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.debug('Fetching resource: %s', url)
|
logging.debug('Fetching resource: %s', url)
|
||||||
|
@ -126,22 +101,32 @@ class HttpClient:
|
||||||
if resp.status == 202:
|
if resp.status == 202:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
data = await resp.read()
|
if resp.status != 200:
|
||||||
|
logging.verbose('Received error when requesting %s: %i', url, resp.status)
|
||||||
|
logging.debug(await resp.read())
|
||||||
|
return None
|
||||||
|
|
||||||
if resp.status != 200:
|
if loads:
|
||||||
logging.verbose('Received error when requesting %s: %i', url, resp.status)
|
message = await resp.json(loads=loads)
|
||||||
logging.debug(await resp.read())
|
|
||||||
return None
|
|
||||||
|
|
||||||
message = loads(data)
|
elif resp.content_type == MIMETYPES['activity']:
|
||||||
self.cache.set('request', url, data.decode('utf-8'), 'str')
|
message = await resp.json(loads = Message.parse)
|
||||||
logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
|
|
||||||
|
|
||||||
return message
|
elif resp.content_type == MIMETYPES['json']:
|
||||||
|
message = await resp.json()
|
||||||
|
|
||||||
|
else:
|
||||||
|
logging.verbose('Invalid Content-Type for "%s": %s', url, resp.content_type)
|
||||||
|
logging.debug('Response: %s', await resp.read())
|
||||||
|
return None
|
||||||
|
|
||||||
|
logging.debug('%s >> resp %s', url, message.to_json(4))
|
||||||
|
|
||||||
|
self.cache[url] = message
|
||||||
|
return message
|
||||||
|
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
logging.verbose('Failed to parse JSON')
|
logging.verbose('Failed to parse JSON')
|
||||||
return None
|
|
||||||
|
|
||||||
except ClientSSLError:
|
except ClientSSLError:
|
||||||
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
|
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
|
||||||
|
@ -155,9 +140,13 @@ class HttpClient:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
|
async def post(self, url: str, message: Message) -> None:
|
||||||
await self.open()
|
await self.open()
|
||||||
|
|
||||||
|
# todo: cache inboxes to avoid opening a db connection
|
||||||
|
with get_app().database.connection() as conn:
|
||||||
|
instance = conn.get_inbox(url)
|
||||||
|
|
||||||
## Using the old algo by default is probably a better idea right now
|
## Using the old algo by default is probably a better idea right now
|
||||||
# pylint: disable=consider-ternary-expression
|
# pylint: disable=consider-ternary-expression
|
||||||
if instance and instance['software'] in {'mastodon'}:
|
if instance and instance['software'] in {'mastodon'}:
|
||||||
|
|
|
@ -70,6 +70,7 @@ error: Callable = logging.error
|
||||||
critical: Callable = logging.critical
|
critical: Callable = logging.critical
|
||||||
|
|
||||||
|
|
||||||
|
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
|
||||||
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
|
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -78,15 +79,22 @@ try:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
env_log_file = None
|
env_log_file = None
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
log_level = LogLevel[env_log_level]
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
print('Invalid log level:', env_log_level)
|
||||||
|
log_level = LogLevel['INFO']
|
||||||
|
|
||||||
|
|
||||||
handlers = [logging.StreamHandler()]
|
handlers = [logging.StreamHandler()]
|
||||||
|
|
||||||
if env_log_file:
|
if env_log_file:
|
||||||
handlers.append(logging.FileHandler(env_log_file))
|
handlers.append(logging.FileHandler(env_log_file))
|
||||||
|
|
||||||
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level = LogLevel.INFO,
|
level = log_level,
|
||||||
format = '[%(asctime)s] %(levelname)s: %(message)s',
|
format = '[%(asctime)s] %(levelname)s: %(message)s',
|
||||||
datefmt = '%Y-%m-%d %H:%M:%S',
|
|
||||||
handlers = handlers
|
handlers = handlers
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ from .application import Application
|
||||||
from .compat import RelayConfig, RelayDatabase
|
from .compat import RelayConfig, RelayDatabase
|
||||||
from .database import get_database
|
from .database import get_database
|
||||||
from .database.connection import RELAY_SOFTWARE
|
from .database.connection import RELAY_SOFTWARE
|
||||||
from .misc import IS_DOCKER, Message
|
from .misc import IS_DOCKER, Message, check_open_port
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from tinysql import Row
|
from tinysql import Row
|
||||||
|
@ -51,13 +51,6 @@ SOFTWARE = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_alphanumeric(text: str) -> str:
|
|
||||||
if not text.isalnum():
|
|
||||||
raise click.BadParameter('String not alphanumeric')
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
|
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
|
||||||
@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config')
|
@click.option('--config', '-c', default='relay.yaml', help='path to the relay\'s config')
|
||||||
@click.version_option(version=__version__, prog_name='ActivityRelay')
|
@click.version_option(version=__version__, prog_name='ActivityRelay')
|
||||||
|
@ -70,11 +63,6 @@ def cli(ctx: click.Context, config: str) -> None:
|
||||||
cli_setup.callback()
|
cli_setup.callback()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(
|
|
||||||
'[DEPRECATED] Running the relay without the "run" command will be removed in the ' +
|
|
||||||
'future.'
|
|
||||||
)
|
|
||||||
|
|
||||||
cli_run.callback()
|
cli_run.callback()
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,9 +113,8 @@ def cli_setup(ctx: click.Context) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.obj.config.pg_host = click.prompt(
|
ctx.obj.config.pg_host = click.prompt(
|
||||||
'What IP address, hostname, or unix socket does the server listen on?',
|
'What IP address or hostname does the server listen on?',
|
||||||
default = ctx.obj.config.pg_host,
|
default = ctx.obj.config.pg_host
|
||||||
type = int
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.obj.config.pg_port = click.prompt(
|
ctx.obj.config.pg_port = click.prompt(
|
||||||
|
@ -148,48 +135,6 @@ def cli_setup(ctx: click.Context) -> None:
|
||||||
default = ctx.obj.config.pg_pass or ""
|
default = ctx.obj.config.pg_pass or ""
|
||||||
) or None
|
) or None
|
||||||
|
|
||||||
ctx.obj.config.ca_type = click.prompt(
|
|
||||||
'Which caching backend?',
|
|
||||||
default = ctx.obj.config.ca_type,
|
|
||||||
type = click.Choice(['database', 'redis'], case_sensitive = False)
|
|
||||||
)
|
|
||||||
|
|
||||||
if ctx.obj.config.ca_type == 'redis':
|
|
||||||
ctx.obj.config.rd_host = click.prompt(
|
|
||||||
'What IP address, hostname, or unix socket does the server listen on?',
|
|
||||||
default = ctx.obj.config.rd_host
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.obj.config.rd_port = click.prompt(
|
|
||||||
'What port does the server listen on?',
|
|
||||||
default = ctx.obj.config.rd_port,
|
|
||||||
type = int
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.obj.config.rd_user = click.prompt(
|
|
||||||
'Which user will authenticate with the server',
|
|
||||||
default = ctx.obj.config.rd_user
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.obj.config.rd_pass = click.prompt(
|
|
||||||
'User password',
|
|
||||||
hide_input = True,
|
|
||||||
show_default = False,
|
|
||||||
default = ctx.obj.config.rd_pass or ""
|
|
||||||
) or None
|
|
||||||
|
|
||||||
ctx.obj.config.rd_database = click.prompt(
|
|
||||||
'Which database number to use?',
|
|
||||||
default = ctx.obj.config.rd_database,
|
|
||||||
type = int
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.obj.config.rd_prefix = click.prompt(
|
|
||||||
'What text should each cache key be prefixed with?',
|
|
||||||
default = ctx.obj.config.rd_database,
|
|
||||||
type = check_alphanumeric
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.obj.config.save()
|
ctx.obj.config.save()
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
|
@ -205,9 +150,8 @@ def cli_setup(ctx: click.Context) -> None:
|
||||||
|
|
||||||
|
|
||||||
@cli.command('run')
|
@cli.command('run')
|
||||||
@click.option('--dev', '-d', is_flag = True, help = 'Enable worker reloading on code change')
|
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli_run(ctx: click.Context, dev: bool = False) -> None:
|
def cli_run(ctx: click.Context) -> None:
|
||||||
'Run the relay'
|
'Run the relay'
|
||||||
|
|
||||||
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
|
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
|
||||||
|
@ -234,7 +178,11 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
|
||||||
click.echo(pip_command)
|
click.echo(pip_command)
|
||||||
return
|
return
|
||||||
|
|
||||||
ctx.obj.run(dev)
|
if not check_open_port(ctx.obj.config.listen, ctx.obj.config.port):
|
||||||
|
click.echo(f'Error: A server is already running on port {ctx.obj.config.port}')
|
||||||
|
return
|
||||||
|
|
||||||
|
ctx.obj.run()
|
||||||
|
|
||||||
|
|
||||||
@cli.command('convert')
|
@cli.command('convert')
|
||||||
|
@ -416,7 +364,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||||
actor = actor
|
actor = actor
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(inbox, message, None, inbox_data))
|
asyncio.run(http.post(inbox, message))
|
||||||
click.echo(f'Sent follow message to actor: {actor}')
|
click.echo(f'Sent follow message to actor: {actor}')
|
||||||
|
|
||||||
|
|
||||||
|
@ -457,7 +405,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(inbox, message, inbox_data))
|
asyncio.run(http.post(inbox, message))
|
||||||
click.echo(f'Sent unfollow message to: {actor}')
|
click.echo(f'Sent unfollow message to: {actor}')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,9 @@ from functools import cached_property
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from collections.abc import Awaitable, Coroutine, Generator
|
from collections.abc import Coroutine, Generator
|
||||||
from tinysql import Connection
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from .application import Application
|
from .application import Application
|
||||||
from .cache import Cache
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .database import Database
|
from .database import Database
|
||||||
from .http_client import HttpClient
|
from .http_client import HttpClient
|
||||||
|
@ -236,21 +234,13 @@ class Response(AiohttpResponse):
|
||||||
|
|
||||||
class View(AbstractView):
|
class View(AbstractView):
|
||||||
def __await__(self) -> Generator[Response]:
|
def __await__(self) -> Generator[Response]:
|
||||||
if self.request.method not in METHODS:
|
if (self.request.method) not in METHODS:
|
||||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
|
||||||
|
|
||||||
if not (handler := self.handlers.get(self.request.method)):
|
if not (handler := self.handlers.get(self.request.method)):
|
||||||
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
|
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) from None
|
||||||
|
|
||||||
return self._run_handler(handler).__await__()
|
return handler(self.request, **self.request.match_info).__await__()
|
||||||
|
|
||||||
|
|
||||||
async def _run_handler(self, handler: Awaitable) -> Response:
|
|
||||||
with self.database.config.connection_class(self.database) as conn:
|
|
||||||
# todo: remove on next tinysql release
|
|
||||||
conn.open()
|
|
||||||
|
|
||||||
return await handler(self.request, conn, **self.request.match_info)
|
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
|
@ -278,11 +268,6 @@ class View(AbstractView):
|
||||||
return self.request.app
|
return self.request.app
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache(self) -> Cache:
|
|
||||||
return self.app.cache
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> HttpClient:
|
def client(self) -> HttpClient:
|
||||||
return self.app.client
|
return self.app.client
|
||||||
|
|
|
@ -1,15 +1,20 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import tinysql
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
from cachetools import LRUCache
|
||||||
|
|
||||||
from . import logger as logging
|
from . import logger as logging
|
||||||
from .database.connection import Connection
|
|
||||||
from .misc import Message
|
from .misc import Message
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from .views import ActorView
|
from .views import ActorView
|
||||||
|
|
||||||
|
|
||||||
|
cache = LRUCache(1024)
|
||||||
|
|
||||||
|
|
||||||
def person_check(actor: str, software: str) -> bool:
|
def person_check(actor: str, software: str) -> bool:
|
||||||
# pleroma and akkoma may use Person for the actor type for some reason
|
# pleroma and akkoma may use Person for the actor type for some reason
|
||||||
# akkoma changed this in 3.6.0
|
# akkoma changed this in 3.6.0
|
||||||
|
@ -23,87 +28,83 @@ def person_check(actor: str, software: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def handle_relay(view: ActorView, conn: Connection) -> None:
|
async def handle_relay(view: ActorView) -> None:
|
||||||
try:
|
if view.message.object_id in cache:
|
||||||
view.cache.get('handle-relay', view.message.object_id)
|
|
||||||
logging.verbose('already relayed %s', view.message.object_id)
|
logging.verbose('already relayed %s', view.message.object_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
message = Message.new_announce(view.config.domain, view.message.object_id)
|
message = Message.new_announce(view.config.domain, view.message.object_id)
|
||||||
|
cache[view.message.object_id] = message.id
|
||||||
logging.debug('>> relay: %s', message)
|
logging.debug('>> relay: %s', message)
|
||||||
|
|
||||||
for inbox in conn.distill_inboxes(view.message):
|
with view.database.connection() as conn:
|
||||||
view.app.push_message(inbox, message, view.instance)
|
for inbox in conn.distill_inboxes(view.message):
|
||||||
|
view.app.push_message(inbox, message)
|
||||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_forward(view: ActorView, conn: Connection) -> None:
|
async def handle_forward(view: ActorView) -> None:
|
||||||
try:
|
if view.message.id in cache:
|
||||||
view.cache.get('handle-relay', view.message.object_id)
|
logging.verbose('already forwarded %s', view.message.id)
|
||||||
logging.verbose('already forwarded %s', view.message.object_id)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
message = Message.new_announce(view.config.domain, view.message)
|
message = Message.new_announce(view.config.domain, view.message)
|
||||||
|
cache[view.message.id] = message.id
|
||||||
logging.debug('>> forward: %s', message)
|
logging.debug('>> forward: %s', message)
|
||||||
|
|
||||||
for inbox in conn.distill_inboxes(view.message):
|
with view.database.connection() as conn:
|
||||||
view.app.push_message(inbox, message, view.instance)
|
for inbox in conn.distill_inboxes(view.message):
|
||||||
|
view.app.push_message(inbox, message)
|
||||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_follow(view: ActorView, conn: Connection) -> None:
|
async def handle_follow(view: ActorView) -> None:
|
||||||
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
|
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
|
||||||
software = nodeinfo.sw_name if nodeinfo else None
|
software = nodeinfo.sw_name if nodeinfo else None
|
||||||
|
|
||||||
# reject if software used by actor is banned
|
with view.database.connection() as conn:
|
||||||
if conn.get_software_ban(software):
|
# reject if software used by actor is banned
|
||||||
view.app.push_message(
|
if conn.get_software_ban(software):
|
||||||
view.actor.shared_inbox,
|
view.app.push_message(
|
||||||
Message.new_response(
|
view.actor.shared_inbox,
|
||||||
host = view.config.domain,
|
Message.new_response(
|
||||||
actor = view.actor.id,
|
host = view.config.domain,
|
||||||
followid = view.message.id,
|
actor = view.actor.id,
|
||||||
accept = False
|
followid = view.message.id,
|
||||||
|
accept = False
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
logging.verbose(
|
logging.verbose(
|
||||||
'Rejected follow from actor for using specific software: actor=%s, software=%s',
|
'Rejected follow from actor for using specific software: actor=%s, software=%s',
|
||||||
view.actor.id,
|
view.actor.id,
|
||||||
software
|
software
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
## reject if the actor is not an instance actor
|
|
||||||
if person_check(view.actor, software):
|
|
||||||
view.app.push_message(
|
|
||||||
view.actor.shared_inbox,
|
|
||||||
Message.new_response(
|
|
||||||
host = view.config.domain,
|
|
||||||
actor = view.actor.id,
|
|
||||||
followid = view.message.id,
|
|
||||||
accept = False
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
|
return
|
||||||
return
|
|
||||||
|
|
||||||
if conn.get_inbox(view.actor.shared_inbox):
|
## reject if the actor is not an instance actor
|
||||||
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
|
if person_check(view.actor, software):
|
||||||
|
view.app.push_message(
|
||||||
|
view.actor.shared_inbox,
|
||||||
|
Message.new_response(
|
||||||
|
host = view.config.domain,
|
||||||
|
actor = view.actor.id,
|
||||||
|
followid = view.message.id,
|
||||||
|
accept = False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
|
||||||
with conn.transaction():
|
return
|
||||||
view.instance = conn.put_inbox(
|
|
||||||
|
if conn.get_inbox(view.actor.shared_inbox):
|
||||||
|
data = {'followid': view.message.id}
|
||||||
|
statement = tinysql.Update('inboxes', data, inbox = view.actor.shared_inbox)
|
||||||
|
|
||||||
|
with conn.query(statement):
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
conn.put_inbox(
|
||||||
view.actor.domain,
|
view.actor.domain,
|
||||||
view.actor.shared_inbox,
|
view.actor.shared_inbox,
|
||||||
view.actor.id,
|
view.actor.id,
|
||||||
|
@ -111,37 +112,35 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
|
||||||
software
|
software
|
||||||
)
|
)
|
||||||
|
|
||||||
view.app.push_message(
|
|
||||||
view.actor.shared_inbox,
|
|
||||||
Message.new_response(
|
|
||||||
host = view.config.domain,
|
|
||||||
actor = view.actor.id,
|
|
||||||
followid = view.message.id,
|
|
||||||
accept = True
|
|
||||||
),
|
|
||||||
view.instance
|
|
||||||
)
|
|
||||||
|
|
||||||
# Are Akkoma and Pleroma the only two that expect a follow back?
|
|
||||||
# Ignoring only Mastodon for now
|
|
||||||
if software != 'mastodon':
|
|
||||||
view.app.push_message(
|
view.app.push_message(
|
||||||
view.actor.shared_inbox,
|
view.actor.shared_inbox,
|
||||||
Message.new_follow(
|
Message.new_response(
|
||||||
host = view.config.domain,
|
host = view.config.domain,
|
||||||
actor = view.actor.id
|
actor = view.actor.id,
|
||||||
),
|
followid = view.message.id,
|
||||||
view.instance
|
accept = True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Are Akkoma and Pleroma the only two that expect a follow back?
|
||||||
|
# Ignoring only Mastodon for now
|
||||||
|
if software != 'mastodon':
|
||||||
|
view.app.push_message(
|
||||||
|
view.actor.shared_inbox,
|
||||||
|
Message.new_follow(
|
||||||
|
host = view.config.domain,
|
||||||
|
actor = view.actor.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def handle_undo(view: ActorView, conn: Connection) -> None:
|
|
||||||
|
async def handle_undo(view: ActorView) -> None:
|
||||||
## If the object is not a Follow, forward it
|
## If the object is not a Follow, forward it
|
||||||
if view.message.object['type'] != 'Follow':
|
if view.message.object['type'] != 'Follow':
|
||||||
await handle_forward(view, conn)
|
await handle_forward(view)
|
||||||
return
|
return
|
||||||
|
|
||||||
with conn.transaction():
|
with view.database.connection() as conn:
|
||||||
if not conn.del_inbox(view.actor.id):
|
if not conn.del_inbox(view.actor.id):
|
||||||
logging.verbose(
|
logging.verbose(
|
||||||
'Failed to delete "%s" with follow ID "%s"',
|
'Failed to delete "%s" with follow ID "%s"',
|
||||||
|
@ -155,8 +154,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
|
||||||
host = view.config.domain,
|
host = view.config.domain,
|
||||||
actor = view.actor.id,
|
actor = view.actor.id,
|
||||||
follow = view.message
|
follow = view.message
|
||||||
),
|
)
|
||||||
view.instance
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -170,7 +168,7 @@ processors = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def run_processor(view: ActorView, conn: Connection) -> None:
|
async def run_processor(view: ActorView) -> None:
|
||||||
if view.message.type not in processors:
|
if view.message.type not in processors:
|
||||||
logging.verbose(
|
logging.verbose(
|
||||||
'Message type "%s" from actor cannot be handled: %s',
|
'Message type "%s" from actor cannot be handled: %s',
|
||||||
|
@ -181,19 +179,20 @@ async def run_processor(view: ActorView, conn: Connection) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if view.instance:
|
if view.instance:
|
||||||
with conn.transaction():
|
if not view.instance['software']:
|
||||||
if not view.instance['software']:
|
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
||||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
with view.database.connection() as conn:
|
||||||
view.instance = conn.update_inbox(
|
view.instance = conn.update_inbox(
|
||||||
view.instance['inbox'],
|
view.instance['inbox'],
|
||||||
software = nodeinfo.sw_name
|
software = nodeinfo.sw_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if not view.instance['actor']:
|
if not view.instance['actor']:
|
||||||
|
with view.database.connection() as conn:
|
||||||
view.instance = conn.update_inbox(
|
view.instance = conn.update_inbox(
|
||||||
view.instance['inbox'],
|
view.instance['inbox'],
|
||||||
actor = view.actor.id
|
actor = view.actor.id
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
|
logging.verbose('New "%s" from actor: %s', view.message.type, view.actor.id)
|
||||||
await processors[view.message.type](view, conn)
|
await processors[view.message.type](view)
|
||||||
|
|
101
relay/views.py
101
relay/views.py
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import subprocess
|
import subprocess
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
|
@ -11,7 +12,6 @@ from pathlib import Path
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from . import logger as logging
|
from . import logger as logging
|
||||||
from .database.connection import Connection
|
|
||||||
from .misc import Message, Response, View
|
from .misc import Message, Response, View
|
||||||
from .processors import run_processor
|
from .processors import run_processor
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ if typing.TYPE_CHECKING:
|
||||||
from aiohttp.web import Request
|
from aiohttp.web import Request
|
||||||
from aputils.signer import Signer
|
from aputils.signer import Signer
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from tinysql import Row
|
|
||||||
|
|
||||||
|
|
||||||
VIEWS = []
|
VIEWS = []
|
||||||
|
@ -75,16 +74,17 @@ def register_route(*paths: str) -> Callable:
|
||||||
|
|
||||||
@register_route('/')
|
@register_route('/')
|
||||||
class HomeView(View):
|
class HomeView(View):
|
||||||
async def get(self, request: Request, conn: Connection) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
config = conn.get_config_all()
|
with self.database.connection() as conn:
|
||||||
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
config = conn.get_config_all()
|
||||||
|
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
||||||
|
|
||||||
text = HOME_TEMPLATE.format(
|
text = HOME_TEMPLATE.format(
|
||||||
host = self.config.domain,
|
host = self.config.domain,
|
||||||
note = config['note'],
|
note = config['note'],
|
||||||
count = len(inboxes),
|
count = len(inboxes),
|
||||||
targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
|
targets = '<br>'.join(inbox['domain'] for inbox in inboxes)
|
||||||
)
|
)
|
||||||
|
|
||||||
return Response.new(text, ctype='html')
|
return Response.new(text, ctype='html')
|
||||||
|
|
||||||
|
@ -98,11 +98,11 @@ class ActorView(View):
|
||||||
self.signature: Signature = None
|
self.signature: Signature = None
|
||||||
self.message: Message = None
|
self.message: Message = None
|
||||||
self.actor: Message = None
|
self.actor: Message = None
|
||||||
self.instance: Row = None
|
self.instance: dict[str, str] = None
|
||||||
self.signer: Signer = None
|
self.signer: Signer = None
|
||||||
|
|
||||||
|
|
||||||
async def get(self, request: Request, conn: Connection) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
data = Message.new_actor(
|
data = Message.new_actor(
|
||||||
host = self.config.domain,
|
host = self.config.domain,
|
||||||
pubkey = self.app.signer.pubkey
|
pubkey = self.app.signer.pubkey
|
||||||
|
@ -111,36 +111,37 @@ class ActorView(View):
|
||||||
return Response.new(data, ctype='activity')
|
return Response.new(data, ctype='activity')
|
||||||
|
|
||||||
|
|
||||||
async def post(self, request: Request, conn: Connection) -> Response:
|
async def post(self, request: Request) -> Response:
|
||||||
if response := await self.get_post_data():
|
if response := await self.get_post_data():
|
||||||
return response
|
return response
|
||||||
|
|
||||||
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
with self.database.connection() as conn:
|
||||||
config = conn.get_config_all()
|
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
||||||
|
config = conn.get_config_all()
|
||||||
|
|
||||||
## reject if the actor isn't whitelisted while the whiltelist is enabled
|
## reject if the actor isn't whitelisted while the whiltelist is enabled
|
||||||
if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
|
if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
|
||||||
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
|
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
|
||||||
return Response.new_error(403, 'access denied', 'json')
|
return Response.new_error(403, 'access denied', 'json')
|
||||||
|
|
||||||
## reject if actor is banned
|
## reject if actor is banned
|
||||||
if conn.get_domain_ban(self.actor.domain):
|
if conn.get_domain_ban(self.actor.domain):
|
||||||
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
|
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
|
||||||
return Response.new_error(403, 'access denied', 'json')
|
return Response.new_error(403, 'access denied', 'json')
|
||||||
|
|
||||||
## reject if activity type isn't 'Follow' and the actor isn't following
|
## reject if activity type isn't 'Follow' and the actor isn't following
|
||||||
if self.message.type != 'Follow' and not self.instance:
|
if self.message.type != 'Follow' and not self.instance:
|
||||||
logging.verbose(
|
logging.verbose(
|
||||||
'Rejected actor for trying to post while not following: %s',
|
'Rejected actor for trying to post while not following: %s',
|
||||||
self.actor.id
|
self.actor.id
|
||||||
)
|
)
|
||||||
|
|
||||||
return Response.new_error(401, 'access denied', 'json')
|
return Response.new_error(401, 'access denied', 'json')
|
||||||
|
|
||||||
logging.debug('>> payload %s', self.message.to_json(4))
|
logging.debug('>> payload %s', self.message.to_json(4))
|
||||||
|
|
||||||
await run_processor(self, conn)
|
asyncio.ensure_future(run_processor(self))
|
||||||
return Response.new(status = 202)
|
return Response.new(status = 202)
|
||||||
|
|
||||||
|
|
||||||
async def get_post_data(self) -> Response | None:
|
async def get_post_data(self) -> Response | None:
|
||||||
|
@ -167,11 +168,7 @@ class ActorView(View):
|
||||||
logging.verbose('actor not in message')
|
logging.verbose('actor not in message')
|
||||||
return Response.new_error(400, 'no actor in message', 'json')
|
return Response.new_error(400, 'no actor in message', 'json')
|
||||||
|
|
||||||
self.actor = await self.client.get(
|
self.actor = await self.client.get(self.signature.keyid, sign_headers = True)
|
||||||
self.signature.keyid,
|
|
||||||
sign_headers = True,
|
|
||||||
loads = Message.parse
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.actor:
|
if not self.actor:
|
||||||
# ld signatures aren't handled atm, so just ignore it
|
# ld signatures aren't handled atm, so just ignore it
|
||||||
|
@ -230,7 +227,7 @@ class ActorView(View):
|
||||||
|
|
||||||
@register_route('/.well-known/webfinger')
|
@register_route('/.well-known/webfinger')
|
||||||
class WebfingerView(View):
|
class WebfingerView(View):
|
||||||
async def get(self, request: Request, conn: Connection) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
try:
|
try:
|
||||||
subject = request.query['resource']
|
subject = request.query['resource']
|
||||||
|
|
||||||
|
@ -251,18 +248,18 @@ class WebfingerView(View):
|
||||||
|
|
||||||
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
|
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
|
||||||
class NodeinfoView(View):
|
class NodeinfoView(View):
|
||||||
# pylint: disable=no-self-use
|
async def get(self, request: Request, niversion: str) -> Response:
|
||||||
async def get(self, request: Request, conn: Connection, niversion: str) -> Response:
|
with self.database.connection() as conn:
|
||||||
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
inboxes = conn.execute('SELECT * FROM inboxes').all()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'name': 'activityrelay',
|
'name': 'activityrelay',
|
||||||
'version': VERSION,
|
'version': VERSION,
|
||||||
'protocols': ['activitypub'],
|
'protocols': ['activitypub'],
|
||||||
'open_regs': not conn.get_config('whitelist-enabled'),
|
'open_regs': not conn.get_config('whitelist-enabled'),
|
||||||
'users': 1,
|
'users': 1,
|
||||||
'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
|
'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
|
||||||
}
|
}
|
||||||
|
|
||||||
if niversion == '2.1':
|
if niversion == '2.1':
|
||||||
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
|
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
|
||||||
|
@ -272,6 +269,6 @@ class NodeinfoView(View):
|
||||||
|
|
||||||
@register_route('/.well-known/nodeinfo')
|
@register_route('/.well-known/nodeinfo')
|
||||||
class WellknownNodeinfoView(View):
|
class WellknownNodeinfoView(View):
|
||||||
async def get(self, request: Request, conn: Connection) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
data = WellKnownNodeinfo.new_template(self.config.domain)
|
data = WellKnownNodeinfo.new_template(self.config.domain)
|
||||||
return Response.new(data, ctype = 'json')
|
return Response.new(data, ctype = 'json')
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
aiohttp>=3.9.1
|
aiohttp>=3.9.1
|
||||||
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
|
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.6a.tar.gz
|
||||||
|
cachetools>=5.2.0
|
||||||
click>=8.1.2
|
click>=8.1.2
|
||||||
gunicorn==21.1.0
|
|
||||||
hiredis==2.3.2
|
|
||||||
pyyaml>=6.0
|
pyyaml>=6.0
|
||||||
redis==5.0.1
|
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.3.tar.gz
|
||||||
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/0.2.4.tar.gz
|
|
||||||
|
|
||||||
importlib_resources==6.1.1;python_version<'3.9'
|
importlib_resources==6.1.1;python_version<'3.9'
|
||||||
|
|
Loading…
Reference in a new issue