Compare commits

..

No commits in common. "49917fcc4e40d2ce34b21c4aeaea5ebe5f261198" and "189ac887a98b6eec39314338bd9d4bcb5733a39b" have entirely different histories.

31 changed files with 687 additions and 640 deletions

2
.gitignore vendored
View file

@ -98,5 +98,3 @@ ENV/
*.yaml *.yaml
*.jsonld *.jsonld
*.sqlite3 *.sqlite3
test*.py

View file

@ -1,4 +1,4 @@
flake8 == 7.0.0 flake8 == 7.0.0
mypy == 1.9.0
pyinstaller == 6.3.0 pyinstaller == 6.3.0
pylint == 3.0
watchdog == 4.0.0 watchdog == 4.0.0

View file

@ -1,19 +1,15 @@
# Configuration # Configuration
## Config File ## General
These options are stored in the configuration file (usually relay.yaml) ### Domain
### General
#### Domain
Hostname the relay will be hosted on. Hostname the relay will be hosted on.
domain: relay.example.com domain: relay.example.com
#### Listener ### Listener
The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc) The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc)
is running on the same host, it is recommended to change `listen` to `localhost` if the reverse is running on the same host, it is recommended to change `listen` to `localhost` if the reverse
@ -23,7 +19,7 @@ proxy is on the same host.
port: 8080 port: 8080
#### Push Workers ### Push Workers
The number of processes to spawn for pushing messages to subscribed instances. Leave it at 0 to The number of processes to spawn for pushing messages to subscribed instances. Leave it at 0 to
automatically detect how many processes should be spawned. automatically detect how many processes should be spawned.
@ -31,21 +27,21 @@ automatically detect how many processes should be spawned.
workers: 0 workers: 0
#### Database type ### Database type
SQL database backend to use. Valid values are `sqlite` or `postgres`. SQL database backend to use. Valid values are `sqlite` or `postgres`.
database_type: sqlite database_type: sqlite
#### Cache type ### Cache type
Cache backend to use. Valid values are `database` or `redis` Cache backend to use. Valid values are `database` or `redis`
cache_type: database cache_type: database
#### Sqlite File Path ### Sqlite File Path
Path to the sqlite database file. If the path is not absolute, it is relative to the config file. Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
directory. directory.
@ -53,7 +49,7 @@ directory.
sqlite_path: relay.jsonld sqlite_path: relay.jsonld
### Postgresql ## Postgresql
In order to use the Postgresql backend, the user and database need to be created first. In order to use the Postgresql backend, the user and database need to be created first.
@ -61,130 +57,80 @@ In order to use the Postgresql backend, the user and database need to be created
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay" sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
#### Database Name ### Database Name
Name of the database to use. Name of the database to use.
name: activityrelay name: activityrelay
#### Host ### Host
Hostname, IP address, or unix socket the server is hosted on. Hostname, IP address, or unix socket the server is hosted on.
host: /var/run/postgresql host: /var/run/postgresql
#### Port ### Port
Port number the server is listening on. Port number the server is listening on.
port: 5432 port: 5432
#### Username ### Username
User to use when logging into the server. User to use when logging into the server.
user: null user: null
#### Password ### Password
Password for the specified user. Password for the specified user.
pass: null pass: null
### Redis ## Redis
#### Host ### Host
Hostname, IP address, or unix socket the server is hosted on. Hostname, IP address, or unix socket the server is hosted on.
host: /var/run/postgresql host: /var/run/postgresql
#### Port ### Port
Port number the server is listening on. Port number the server is listening on.
port: 5432 port: 5432
#### Username ### Username
User to use when logging into the server. User to use when logging into the server.
user: null user: null
#### Password ### Password
Password for the specified user. Password for the specified user.
pass: null pass: null
#### Database Number ### Database Number
Number of the database to use. Number of the database to use.
database: 0 database: 0
#### Prefix ### Prefix
Text to prefix every key with. It cannot contain a `:` character. Text to prefix every key with. It cannot contain a `:` character.
prefix: activityrelay prefix: activityrelay
## Database Config
These options are stored in the database and can be changed via CLI, API, or the web interface.
### Approval Required
When enabled, instances that try to follow the relay will have to be manually approved by an admin.
approval-required: false
### Log Level
Maximum level of messages to log.
Valid values: `DEBUG`, `VERBOSE`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
log-level: INFO
### Name
Name of your relay's instance. It will be displayed at the top of web pages and in API endpoints.
name: ActivityRelay
### Note
Short blurb that will be displayed on the relay's home and in API endpoints if set. Can be in
markdown format.
note: null
### Theme
Color theme to use for the web pages.
Valid values: `Default`, `Pink`, `Blue`
theme: Default
### Whitelist Enabled
When enabled, only instances on the whitelist can join. Any instances currently subscribed and not
in the whitelist when this is enabled can still post.
whitelist-enabled: False

View file

@ -3,13 +3,54 @@ requires = ["setuptools","wheel"]
build-backend = 'setuptools.build_meta' build-backend = 'setuptools.build_meta'
[tool.mypy] [tool.pylint.main]
show_traceback = true jobs = 0
install_types = true persistent = true
pretty = true load-plugins = [
disallow_untyped_decorators = true "pylint.extensions.code_style",
warn_redundant_casts = true "pylint.extensions.comparison_placement",
warn_unreachable = true "pylint.extensions.confusing_elif",
warn_unused_ignores = true "pylint.extensions.for_any_all",
ignore_missing_imports = true "pylint.extensions.consider_ternary_expression",
follow_imports = "silent" "pylint.extensions.bad_builtin",
"pylint.extensions.dict_init_mutate",
"pylint.extensions.check_elif",
"pylint.extensions.empty_comment",
"pylint.extensions.private_import",
"pylint.extensions.redefined_variable_type",
"pylint.extensions.no_self_use",
"pylint.extensions.overlapping_exceptions",
"pylint.extensions.set_membership",
"pylint.extensions.typing"
]
[tool.pylint.design]
max-args = 10
max-attributes = 100
[tool.pylint.format]
indent-str = "\t"
indent-after-paren = 1
max-line-length = 100
single-line-if-stmt = true
[tool.pylint.messages_control]
disable = [
"fixme",
"broad-exception-caught",
"cyclic-import",
"global-statement",
"invalid-name",
"missing-module-docstring",
"too-few-public-methods",
"too-many-public-methods",
"too-many-return-statements",
"wrong-import-order",
"missing-function-docstring",
"missing-class-docstring",
"consider-using-namedtuple-or-dataclass",
"confusing-consecutive-elif"
]

View file

@ -26,15 +26,16 @@ from .views.api import handle_api_path
from .views.frontend import handle_frontend_path from .views.frontend import handle_frontend_path
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Coroutine
from bsql import Database, Row from tinysql import Database, Row
from .cache import Cache from .cache import Cache
from .misc import Message, Response from .misc import Message, Response
class Application(web.Application): # pylint: disable=unsubscriptable-object
DEFAULT: Application | None = None
class Application(web.Application):
DEFAULT: Application = None
def __init__(self, cfgpath: str | None, dev: bool = False): def __init__(self, cfgpath: str | None, dev: bool = False):
web.Application.__init__(self, web.Application.__init__(self,
@ -63,13 +64,14 @@ class Application(web.Application):
self['workers'] = [] self['workers'] = []
self.cache.setup() self.cache.setup()
self.on_cleanup.append(handle_cleanup) # type: ignore
# self.on_response_prepare.append(handle_access_log)
self.on_cleanup.append(handle_cleanup)
for path, view in VIEWS: for path, view in VIEWS:
self.router.add_view(path, view) self.router.add_view(path, view)
setup_swagger( setup_swagger(self,
self,
ui_version = 3, ui_version = 3,
swagger_from_file = get_resource('data/swagger.yaml') swagger_from_file = get_resource('data/swagger.yaml')
) )
@ -163,7 +165,6 @@ class Application(web.Application):
self.set_signal_handler(True) self.set_signal_handler(True)
self['client'].open()
self['database'].connect() self['database'].connect()
self['cache'].setup() self['cache'].setup()
self['cleanup_thread'] = CacheCleanupThread(self) self['cleanup_thread'] = CacheCleanupThread(self)
@ -178,8 +179,7 @@ class Application(web.Application):
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"') runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(runner,
runner,
host = self.config.listen, host = self.config.listen,
port = self.config.port, port = self.config.port,
reuse_address = True reuse_address = True
@ -193,7 +193,7 @@ class Application(web.Application):
await site.stop() await site.stop()
for worker in self['workers']: for worker in self['workers']: # pylint: disable=not-an-iterable
worker.stop() worker.stop()
self.set_signal_handler(False) self.set_signal_handler(False)
@ -247,7 +247,6 @@ class PushWorker(multiprocessing.Process):
async def handle_queue(self) -> None: async def handle_queue(self) -> None:
client = HttpClient() client = HttpClient()
client.open()
while not self.shutdown.is_set(): while not self.shutdown.is_set():
try: try:
@ -257,7 +256,7 @@ class PushWorker(multiprocessing.Process):
except Empty: except Empty:
pass pass
# make sure an exception doesn't bring down the worker ## make sure an exception doesn't bring down the worker
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
@ -265,7 +264,7 @@ class PushWorker(multiprocessing.Process):
@web.middleware @web.middleware
async def handle_response_headers(request: web.Request, handler: Callable) -> Response: async def handle_response_headers(request: web.Request, handler: Coroutine) -> Response:
resp = await handler(request) resp = await handler(request)
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'

View file

@ -13,16 +13,15 @@ from .database import get_database
from .misc import Message, boolean from .misc import Message, boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from blib import Database
from collections.abc import Callable, Iterator
from typing import Any from typing import Any
from collections.abc import Callable, Iterator
from .application import Application from .application import Application
# todo: implement more caching backends # todo: implement more caching backends
BACKENDS: dict[str, type[Cache]] = {} BACKENDS: dict[str, Cache] = {}
CONVERTERS: dict[str, tuple[Callable, Callable]] = { CONVERTERS: dict[str, tuple[Callable, Callable]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
@ -72,7 +71,7 @@ class Item:
data.value = deserialize_value(data.value, data.value_type) data.value = deserialize_value(data.value, data.value_type)
if not isinstance(data.updated, datetime): if not isinstance(data.updated, datetime):
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc)
return data return data
@ -144,7 +143,7 @@ class Cache(ABC):
item.namespace, item.namespace,
item.key, item.key,
item.value, item.value,
item.value_type item.type
) )
@ -159,7 +158,7 @@ class SqlCache(Cache):
def __init__(self, app: Application): def __init__(self, app: Application):
Cache.__init__(self, app) Cache.__init__(self, app)
self._db: Database = None self._db = None
def get(self, namespace: str, key: str) -> Item: def get(self, namespace: str, key: str) -> Item:
@ -258,7 +257,7 @@ class RedisCache(Cache):
def __init__(self, app: Application): def __init__(self, app: Application):
Cache.__init__(self, app) Cache.__init__(self, app)
self._rd: Redis = None # type: ignore self._rd = None
@property @property
@ -276,7 +275,7 @@ class RedisCache(Cache):
if not (raw_value := self._rd.get(key_name)): if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}') raise KeyError(f'{namespace}:{key}')
value_type, updated, value = raw_value.split(':', 2) # type: ignore value_type, updated, value = raw_value.split(':', 2)
return Item.from_data( return Item.from_data(
namespace, namespace,
key, key,
@ -303,7 +302,7 @@ class RedisCache(Cache):
yield namespace yield namespace
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item: def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> None:
date = datetime.now(tz = timezone.utc).timestamp() date = datetime.now(tz = timezone.utc).timestamp()
value = serialize_value(value, value_type) value = serialize_value(value, value_type)
@ -312,8 +311,6 @@ class RedisCache(Cache):
f'{value_type}:{date}:{value}' f'{value_type}:{date}:{value}'
) )
return self.get(namespace, key)
def delete(self, namespace: str, key: str) -> None: def delete(self, namespace: str, key: str) -> None:
self._rd.delete(self.get_key_name(namespace, key)) self._rd.delete(self.get_key_name(namespace, key))
@ -353,7 +350,7 @@ class RedisCache(Cache):
options['host'] = self.app.config.rd_host options['host'] = self.app.config.rd_host
options['port'] = self.app.config.rd_port options['port'] = self.app.config.rd_port
self._rd = Redis(**options) # type: ignore self._rd = Redis(**options)
def close(self) -> None: def close(self) -> None:
@ -361,4 +358,4 @@ class RedisCache(Cache):
return return
self._rd.close() self._rd.close()
self._rd = None # type: ignore self._rd = None

View file

@ -9,12 +9,16 @@ from functools import cached_property
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
from .misc import boolean from . import logger as logging
from .misc import Message, boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any from typing import Any
# pylint: disable=duplicate-code
class RelayConfig(dict): class RelayConfig(dict):
def __init__(self, path: str): def __init__(self, path: str):
dict.__init__(self, {}) dict.__init__(self, {})
@ -42,7 +46,7 @@ class RelayConfig(dict):
@property @property
def db(self) -> Path: def db(self) -> RelayDatabase:
return Path(self['db']).expanduser().resolve() return Path(self['db']).expanduser().resolve()
@ -180,3 +184,121 @@ class RelayDatabase(dict):
except json.decoder.JSONDecodeError as e: except json.decoder.JSONDecodeError as e:
if self.config.db.stat().st_size > 0: if self.config.db.stat().st_size > 0:
raise e from None raise e from None
def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd:
json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
if (inbox := self['relay-list'].get(domain)):
return inbox
if fail:
raise KeyError(domain)
return None
def add_inbox(self,
inbox: str,
followid: str | None = None,
software: str | None = None) -> dict[str, str]:
assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname
if (instance := self.get_inbox(domain)):
if followid:
instance['followid'] = followid
if software:
instance['software'] = software
return instance
self['relay-list'][domain] = {
'domain': domain,
'inbox': inbox,
'followid': followid,
'software': software
}
logging.verbose('Added inbox to database: %s', inbox)
return self['relay-list'][domain]
def del_inbox(self,
domain: str,
followid: str = None,
fail: bool = False) -> bool:
if not (data := self.get_inbox(domain, fail=False)):
if fail:
raise KeyError(domain)
return False
if not data['followid'] or not followid or data['followid'] == followid:
del self['relay-list'][data['domain']]
logging.verbose('Removed inbox from database: %s', data['inbox'])
return True
if fail:
raise ValueError('Follow IDs do not match')
logging.debug('Follow ID does not match: db = %s, object = %s', data['followid'], followid)
return False
def get_request(self, domain: str, fail: bool = True) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
try:
return self['follow-requests'][domain]
except KeyError as e:
if fail:
raise e
return None
def add_request(self, actor: str, inbox: str, followid: str) -> None:
domain = urlparse(inbox).hostname
try:
request = self.get_request(domain)
request['followid'] = followid
except KeyError:
pass
self['follow-requests'][domain] = {
'actor': actor,
'inbox': inbox,
'followid': followid
}
def del_request(self, domain: str) -> None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
del self['follow-requests'][domain]
def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for domain, instance in self['relay-list'].items():
if domain not in src_domains:
yield instance['inbox']

View file

@ -6,14 +6,13 @@ import platform
import typing import typing
import yaml import yaml
from dataclasses import asdict, dataclass, fields
from pathlib import Path from pathlib import Path
from platformdirs import user_config_dir from platformdirs import user_config_dir
from .misc import IS_DOCKER from .misc import IS_DOCKER
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Self from typing import Any
if platform.system() == 'Windows': if platform.system() == 'Windows':
@ -24,44 +23,61 @@ else:
CORE_COUNT = len(os.sched_getaffinity(0)) CORE_COUNT = len(os.sched_getaffinity(0))
DOCKER_VALUES = { DEFAULTS: dict[str, Any] = {
'listen': '0.0.0.0', 'listen': '0.0.0.0',
'port': 8080, 'port': 8080,
'sq_path': '/data/relay.jsonld' 'domain': 'relay.example.com',
'workers': CORE_COUNT,
'db_type': 'sqlite',
'ca_type': 'database',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay',
'rd_host': 'localhost',
'rd_port': 6379,
'rd_user': None,
'rd_pass': None,
'rd_database': 0,
'rd_prefix': 'activityrelay'
} }
if IS_DOCKER:
class NOVALUE: DEFAULTS['sq_path'] = '/data/relay.jsonld'
pass
@dataclass(init = False)
class Config: class Config:
listen: str = '0.0.0.0' def __init__(self, path: str, load: bool = False):
port: int = 8080 if path:
domain: str = 'relay.example.com' self.path = Path(path).expanduser().resolve()
workers: int = CORE_COUNT
db_type: str = 'sqlite'
ca_type: str = 'database'
sq_path: str = 'relay.sqlite3'
pg_host: str = '/var/run/postgresql' else:
pg_port: int = 5432 self.path = Config.get_config_dir()
pg_user: str = getpass.getuser()
pg_pass: str | None = None
pg_name: str = 'activityrelay'
rd_host: str = 'localhost' self.listen = None
rd_port: int = 6470 self.port = None
rd_user: str | None = None self.domain = None
rd_pass: str | None = None self.workers = None
rd_database: int = 0 self.db_type = None
rd_prefix: str = 'activityrelay' self.ca_type = None
self.sq_path = None
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
def __init__(self, path: str | None = None, load: bool = False): self.rd_host = None
self.path = Config.get_config_dir(path) self.rd_port = None
self.reset() self.rd_user = None
self.rd_pass = None
self.rd_database = None
self.rd_prefix = None
if load: if load:
try: try:
@ -71,36 +87,22 @@ class Config:
self.save() self.save()
@classmethod
def KEYS(cls: type[Self]) -> list[str]:
return list(cls.__dataclass_fields__)
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | None:
for field in fields(cls):
if field.name == key:
return field.default # type: ignore
raise KeyError(key)
@staticmethod @staticmethod
def get_config_dir(path: str | None = None) -> Path: def get_config_dir(path: str | None = None) -> Path:
if path: if path:
return Path(path).expanduser().resolve() return Path(path).expanduser().resolve()
paths = ( dirs = (
Path("relay.yaml").resolve(), Path("relay.yaml").resolve(),
Path(user_config_dir("activityrelay"), "relay.yaml"), Path(user_config_dir("activityrelay"), "relay.yaml"),
Path("/etc/activityrelay/relay.yaml") Path("/etc/activityrelay/relay.yaml")
) )
for cfgfile in paths: for directory in dirs:
if cfgfile.exists(): if directory.exists():
return cfgfile return directory
return paths[0] return dirs[0]
@property @property
@ -128,6 +130,7 @@ class Config:
def load(self) -> None: def load(self) -> None:
self.reset() self.reset()
options = {} options = {}
try: try:
@ -138,85 +141,95 @@ class Config:
with self.path.open('r', encoding = 'UTF-8') as fd: with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options) config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {})
rdcfg = config.get('redis', {})
if not config: if not config:
raise ValueError('Config is empty') raise ValueError('Config is empty')
pgcfg = config.get('postgresql', {}) if IS_DOCKER:
rdcfg = config.get('redis', {}) self.listen = '0.0.0.0'
self.port = 8080
self.sq_path = '/data/relay.jsonld'
for key in type(self).KEYS(): else:
if IS_DOCKER and key in {'listen', 'port', 'sq_path'}: self.set('listen', config.get('listen', DEFAULTS['listen']))
self.set(key, DOCKER_VALUES[key]) self.set('port', config.get('port', DEFAULTS['port']))
continue self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
self.set('workers', config.get('workers', DEFAULTS['workers']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
for key in DEFAULTS:
if key.startswith('pg'): if key.startswith('pg'):
self.set(key, pgcfg.get(key[3:], NOVALUE)) try:
continue self.set(key, pgcfg[key[3:]])
except KeyError:
continue
elif key.startswith('rd'): elif key.startswith('rd'):
self.set(key, rdcfg.get(key[3:], NOVALUE)) try:
continue self.set(key, rdcfg[key[3:]])
cfgkey = key except KeyError:
continue
if key == 'db_type':
cfgkey = 'database_type'
elif key == 'ca_type':
cfgkey = 'cache_type'
elif key == 'sq_path':
cfgkey = 'sqlite_path'
self.set(key, config.get(cfgkey, NOVALUE))
def reset(self) -> None: def reset(self) -> None:
for field in fields(self): for key, value in DEFAULTS.items():
setattr(self, field.name, field.default) setattr(self, key, value)
def save(self) -> None: def save(self) -> None:
self.path.parent.mkdir(exist_ok = True, parents = True) self.path.parent.mkdir(exist_ok = True, parents = True)
data: dict[str, Any] = {}
for key, value in asdict(self).items():
if key.startswith('pg_'):
if 'postgres' not in data:
data['postgres'] = {}
data['postgres'][key[3:]] = value
continue
if key.startswith('rd_'):
if 'redis' not in data:
data['redis'] = {}
data['redis'][key[3:]] = value
continue
if key == 'db_type':
key = 'database_type'
elif key == 'ca_type':
key = 'cache_type'
elif key == 'sq_path':
key = 'sqlite_path'
data[key] = value
with self.path.open('w', encoding = 'utf-8') as fd: with self.path.open('w', encoding = 'utf-8') as fd:
yaml.dump(data, fd, sort_keys = False) yaml.dump(self.to_dict(), fd, sort_keys = False)
def set(self, key: str, value: Any) -> None: def set(self, key: str, value: Any) -> None:
if key not in type(self).KEYS(): if key not in DEFAULTS:
raise KeyError(key) raise KeyError(key)
if value is NOVALUE: if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
return if (value := int(value)) < 1:
if key == 'port':
value = 8080
elif key == 'pg_port':
value = 5432
elif key == 'workers':
value = len(os.sched_getaffinity(0))
setattr(self, key, value) setattr(self, key, value)
def to_dict(self) -> dict[str, Any]:
return {
'listen': self.listen,
'port': self.port,
'domain': self.domain,
'workers': self.workers,
'database_type': self.db_type,
'cache_type': self.ca_type,
'sqlite_path': self.sq_path,
'postgres': {
'host': self.pg_host,
'port': self.pg_port,
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
},
'redis': {
'host': self.rd_host,
'port': self.rd_port,
'user': self.rd_user,
'pass': self.rd_pass,
'database': self.rd_database,
'refix': self.rd_prefix
}
}

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import bsql import bsql
import typing import typing
from .config import THEMES, ConfigData from .config import CONFIG_DEFAULTS, THEMES, get_default_value
from .connection import RELAY_SOFTWARE, Connection from .connection import RELAY_SOFTWARE, Connection
from .schema import TABLES, VERSIONS, migrate_0 from .schema import TABLES, VERSIONS, migrate_0
@ -11,7 +11,7 @@ from .. import logger as logging
from ..misc import get_resource from ..misc import get_resource
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from ..config import Config from .config import Config
def get_database(config: Config, migrate: bool = True) -> bsql.Database: def get_database(config: Config, migrate: bool = True) -> bsql.Database:
@ -46,7 +46,7 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database:
migrate_0(conn) migrate_0(conn)
return db return db
if (schema_ver := conn.get_config('schema-version')) < ConfigData.DEFAULT('schema-version'): if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver) logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS.items(): for ver, func in VERSIONS.items():

View file

@ -1,16 +1,14 @@
from __future__ import annotations from __future__ import annotations
import json
import typing import typing
from dataclasses import Field, asdict, dataclass, fields
from .. import logger as logging from .. import logger as logging
from ..misc import boolean from ..misc import boolean
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from bsql import Row from collections.abc import Callable
from collections.abc import Callable, Sequence from typing import Any
from typing import Any, Self
THEMES = { THEMES = {
@ -61,101 +59,40 @@ THEMES = {
} }
} }
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
'schema-version': ('int', 20240310),
'private-key': ('str', None),
'approval-required': ('bool', False),
'log-level': ('loglevel', logging.LogLevel.INFO),
'name': ('str', 'ActivityRelay'),
'note': ('str', 'Make a note about your instance here.'),
'theme': ('str', 'default'),
'whitelist-enabled': ('bool', False)
}
# serializer | deserializer # serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
'bool': (str, boolean), 'bool': (str, boolean),
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse) 'json': (json.dumps, json.loads),
'loglevel': (lambda x: x.name, logging.LogLevel.parse)
} }
@dataclass() def get_default_value(key: str) -> Any:
class ConfigData: return CONFIG_DEFAULTS[key][1]
schema_version: int = 20240310
private_key: str = ''
approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO
name: str = 'ActivityRelay'
note: str = ''
theme: str = 'default'
whitelist_enabled: bool = False
def __getitem__(self, key: str) -> Any: def get_default_type(key: str) -> str:
if (value := getattr(self, key.replace('-', '_'), None)) is None: return CONFIG_DEFAULTS[key][0]
raise KeyError(key)
return value
def __setitem__(self, key: str, value: Any) -> None: def serialize(key: str, value: Any) -> str:
self.set(key, value) type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][0](value)
@classmethod def deserialize(key: str, value: str) -> Any:
def KEYS(cls: type[Self]) -> Sequence[str]: type_name = get_default_type(key)
return list(cls.__dataclass_fields__) return CONFIG_CONVERT[type_name][1](value)
@staticmethod
def SYSTEM_KEYS() -> Sequence[str]:
return ('schema-version', 'schema_version', 'private-key', 'private_key')
@classmethod
def USER_KEYS(cls: type[Self]) -> Sequence[str]:
return tuple(key for key in cls.KEYS() if key not in cls.SYSTEM_KEYS())
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore
@classmethod
def FIELD(cls: type[Self], key: str) -> Field:
for field in fields(cls):
if field.name == key.replace('-', '_'):
return field
raise KeyError(key)
@classmethod
def from_rows(cls: type[Self], rows: Sequence[Row]) -> Self:
data = cls()
set_schema_version = False
for row in rows:
data.set(row['key'], row['value'])
if row['key'] == 'schema-version':
set_schema_version = True
if not set_schema_version:
data.schema_version = 0
return data
def get(self, key: str, default: Any = None, serialize: bool = False) -> Any:
field = type(self).FIELD(key)
value = getattr(self, field.name, None)
if not serialize:
return value
converter = CONFIG_CONVERT[str(field.type)][0]
return converter(value)
def set(self, key: str, value: Any) -> None:
field = type(self).FIELD(key)
converter = CONFIG_CONVERT[str(field.type)][1]
setattr(self, field.name, converter(value))
def to_dict(self) -> dict[str, Any]:
return {key.replace('_', '-'): value for key, value in asdict(self).items()}

View file

@ -9,18 +9,22 @@ from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from .config import ( from .config import (
CONFIG_DEFAULTS,
THEMES, THEMES,
ConfigData get_default_type,
get_default_value,
serialize,
deserialize
) )
from .. import logger as logging from .. import logger as logging
from ..misc import boolean, get_app from ..misc import boolean, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator, Sequence from collections.abc import Iterator
from bsql import Row from bsql import Row
from typing import Any from typing import Any
from ..application import Application from .application import Application
from ..misc import Message from ..misc import Message
@ -54,57 +58,73 @@ class Connection(SqlConnection):
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
with self.run('get-config', {'key': key}) as cur: with self.run('get-config', {'key': key}) as cur:
if not (row := cur.one()): if not (row := cur.one()):
return ConfigData.DEFAULT(key) return get_default_value(key)
data = ConfigData() if row['value']:
data.set(row['key'], row['value']) return deserialize(row['key'], row['value'])
return data.get(key)
return None
def get_config_all(self) -> ConfigData: def get_config_all(self) -> dict[str, Any]:
with self.run('get-config-all', None) as cur: with self.run('get-config-all', None) as cur:
return ConfigData.from_rows(tuple(cur.all())) db_config = {row['key']: row['value'] for row in cur}
config = {}
for key, data in CONFIG_DEFAULTS.items():
try:
config[key] = deserialize(key, db_config[key])
except KeyError:
if key == 'schema-version':
config[key] = 0
else:
config[key] = data[1]
return config
def put_config(self, key: str, value: Any) -> Any: def put_config(self, key: str, value: Any) -> Any:
field = ConfigData.FIELD(key) if key not in CONFIG_DEFAULTS:
key = field.name.replace('_', '-') raise KeyError(key)
if key == 'private_key': if key == 'private-key':
self.app.signer = value self.app.signer = value
elif key == 'log_level': elif key == 'log-level':
value = logging.LogLevel.parse(value) value = logging.LogLevel.parse(value)
logging.set_level(value) logging.set_level(value)
elif key in {'approval-required', 'whitelist-enabled'}: elif key == 'whitelist-enabled':
value = boolean(value) value = boolean(value)
elif key == 'theme': elif key == 'theme':
if value not in THEMES: if value not in THEMES:
raise ValueError(f'"{value}" is not a valid theme') raise ValueError(f'"{value}" is not a valid theme')
data = ConfigData()
data.set(key, value)
params = { params = {
'key': key, 'key': key,
'value': data.get(key, serialize = True), 'value': serialize(key, value) if value is not None else None,
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type 'type': get_default_type(key)
} }
with self.run('put-config', params): with self.run('put-config', params):
pass return value
def get_inbox(self, value: str) -> Row: def get_inbox(self, value: str) -> Row:
with self.run('get-inbox', {'value': value}) as cur: with self.run('get-inbox', {'value': value}) as cur:
return cur.one() # type: ignore return cur.one()
def get_inboxes(self) -> Sequence[Row]: def get_inboxes(self) -> tuple[Row]:
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur: with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur:
return tuple(cur.all()) return tuple(cur.all())
@ -117,7 +137,7 @@ class Connection(SqlConnection):
software: str | None = None, software: str | None = None,
accepted: bool = True) -> Row: accepted: bool = True) -> Row:
params: dict[str, Any] = { params = {
'inbox': inbox, 'inbox': inbox,
'actor': actor, 'actor': actor,
'followid': followid, 'followid': followid,
@ -133,14 +153,14 @@ class Connection(SqlConnection):
params['created'] = datetime.now(tz = timezone.utc) params['created'] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur: with self.run('put-inbox', params) as cur:
return cur.one() # type: ignore return cur.one()
for key, value in tuple(params.items()): for key, value in tuple(params.items()):
if value is None: if value is None:
del params[key] del params[key]
with self.update('inboxes', params, domain = domain) as cur: with self.update('inboxes', params, domain = domain) as cur:
return cur.one() # type: ignore return cur.one()
def del_inbox(self, value: str) -> bool: def del_inbox(self, value: str) -> bool:
@ -159,7 +179,7 @@ class Connection(SqlConnection):
return row return row
def get_requests(self) -> Sequence[Row]: def get_requests(self) -> tuple[Row]:
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur: with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
return tuple(cur.all()) return tuple(cur.all())
@ -177,17 +197,17 @@ class Connection(SqlConnection):
} }
with self.run('put-inbox-accept', params) as cur: with self.run('put-inbox-accept', params) as cur:
return cur.one() # type: ignore return cur.one()
def get_user(self, value: str) -> Row: def get_user(self, value: str) -> Row:
with self.run('get-user', {'value': value}) as cur: with self.run('get-user', {'value': value}) as cur:
return cur.one() # type: ignore return cur.one()
def get_user_by_token(self, code: str) -> Row: def get_user_by_token(self, code: str) -> Row:
with self.run('get-user-by-token', {'code': code}) as cur: with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one() # type: ignore return cur.one()
def put_user(self, username: str, password: str, handle: str | None = None) -> Row: def put_user(self, username: str, password: str, handle: str | None = None) -> Row:
@ -199,7 +219,7 @@ class Connection(SqlConnection):
} }
with self.run('put-user', data) as cur: with self.run('put-user', data) as cur:
return cur.one() # type: ignore return cur.one()
def del_user(self, username: str) -> None: def del_user(self, username: str) -> None:
@ -214,7 +234,7 @@ class Connection(SqlConnection):
def get_token(self, code: str) -> Row: def get_token(self, code: str) -> Row:
with self.run('get-token', {'code': code}) as cur: with self.run('get-token', {'code': code}) as cur:
return cur.one() # type: ignore return cur.one()
def put_token(self, username: str) -> Row: def put_token(self, username: str) -> Row:
@ -225,7 +245,7 @@ class Connection(SqlConnection):
} }
with self.run('put-token', data) as cur: with self.run('put-token', data) as cur:
return cur.one() # type: ignore return cur.one()
def del_token(self, code: str) -> None: def del_token(self, code: str) -> None:
@ -238,7 +258,7 @@ class Connection(SqlConnection):
domain = urlparse(domain).netloc domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur: with self.run('get-domain-ban', {'domain': domain}) as cur:
return cur.one() # type: ignore return cur.one()
def put_domain_ban(self, def put_domain_ban(self,
@ -254,7 +274,7 @@ class Connection(SqlConnection):
} }
with self.run('put-domain-ban', params) as cur: with self.run('put-domain-ban', params) as cur:
return cur.one() # type: ignore return cur.one()
def update_domain_ban(self, def update_domain_ban(self,
@ -293,7 +313,7 @@ class Connection(SqlConnection):
def get_software_ban(self, name: str) -> Row: def get_software_ban(self, name: str) -> Row:
with self.run('get-software-ban', {'name': name}) as cur: with self.run('get-software-ban', {'name': name}) as cur:
return cur.one() # type: ignore return cur.one()
def put_software_ban(self, def put_software_ban(self,
@ -309,7 +329,7 @@ class Connection(SqlConnection):
} }
with self.run('put-software-ban', params) as cur: with self.run('put-software-ban', params) as cur:
return cur.one() # type: ignore return cur.one()
def update_software_ban(self, def update_software_ban(self,
@ -348,7 +368,7 @@ class Connection(SqlConnection):
def get_domain_whitelist(self, domain: str) -> Row: def get_domain_whitelist(self, domain: str) -> Row:
with self.run('get-domain-whitelist', {'domain': domain}) as cur: with self.run('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one() # type: ignore return cur.one()
def put_domain_whitelist(self, domain: str) -> Row: def put_domain_whitelist(self, domain: str) -> Row:
@ -358,7 +378,7 @@ class Connection(SqlConnection):
} }
with self.run('put-domain-whitelist', params) as cur: with self.run('put-domain-whitelist', params) as cur:
return cur.one() # type: ignore return cur.one()
def del_domain_whitelist(self, domain: str) -> bool: def del_domain_whitelist(self, domain: str) -> bool:

View file

@ -2,13 +2,12 @@ from __future__ import annotations
import typing import typing
from bsql import Column, Table, Tables from bsql import Column, Connection, Table, Tables
from .config import ConfigData from .config import get_default_value
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Callable
from .connection import Connection
VERSIONS: dict[int, Callable] = {} VERSIONS: dict[int, Callable] = {}
@ -72,7 +71,7 @@ def migration(func: Callable) -> Callable:
def migrate_0(conn: Connection) -> None: def migrate_0(conn: Connection) -> None:
conn.create_tables() conn.create_tables()
conn.put_config('schema-version', ConfigData.DEFAULT('schema-version')) conn.put_config('schema-version', get_default_value('schema-version'))
@migration @migration

View file

@ -15,7 +15,7 @@ try:
from watchdog.events import PatternMatchingEventHandler from watchdog.events import PatternMatchingEventHandler
except ImportError: except ImportError:
class PatternMatchingEventHandler: # type: ignore class PatternMatchingEventHandler:
pass pass
@ -45,15 +45,9 @@ def cli_install():
@cli.command('lint') @cli.command('lint')
@click.argument('path', required = False, default = 'relay') @click.argument('path', required = False, default = 'relay')
@click.option('--strict', '-s', is_flag = True, help = 'Enable strict mode for mypy') def cli_lint(path):
def cli_lint(path: str, strict: bool) -> None: subprocess.run([sys.executable, '-m', 'flake8', path], check = False)
cmd: list[str] = [sys.executable, '-m', 'mypy'] subprocess.run([sys.executable, '-m', 'pylint', path], check = False)
if strict:
cmd.append('--strict')
subprocess.run([*cmd, path], check = False)
subprocess.run([sys.executable, '-m', 'flake8', path])
@cli.command('build') @cli.command('build')
@ -152,6 +146,7 @@ class WatchHandler(PatternMatchingEventHandler):
self.kill_proc() self.kill_proc()
# pylint: disable=consider-using-with
self.proc = subprocess.Popen(self.cmd, stdin = subprocess.PIPE) self.proc = subprocess.Popen(self.cmd, stdin = subprocess.PIPE)
self.last_restart = timestamp self.last_restart = timestamp

View file

@ -11,7 +11,7 @@
%title << {{config.name}}: {{page}} %title << {{config.name}}: {{page}}
%meta(charset="UTF-8") %meta(charset="UTF-8")
%meta(name="viewport" content="width=device-width, initial-scale=1") %meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css") %link(rel="stylesheet" type="text/css" href="/theme/{{theme_name}}.css")
%link(rel="stylesheet" type="text/css" href="/style.css") %link(rel="stylesheet" type="text/css" href="/style.css")
-block head -block head

View file

@ -8,18 +8,18 @@
%input(id = "name" name="name" placeholder="Relay Name" value="{{config.name or ''}}") %input(id = "name" name="name" placeholder="Relay Name" value="{{config.name or ''}}")
%label(for="description") << Description %label(for="description") << Description
%textarea(id="description" name="note" value="{{config.note or ''}}") << {{config.note}} %textarea(id="description" name="note" value="{{config.note}}") << {{config.note}}
%label(for="theme") << Color Theme %label(for="theme") << Color Theme
=func.new_select("theme", config.theme, themes) =func.new_select("theme", config.theme, themes)
%label(for="log-level") << Log Level %label(for="log-level") << Log Level
=func.new_select("log-level", config.log_level.name, levels) =func.new_select("log-level", config["log-level"].name, levels)
%label(for="whitelist-enabled") << Whitelist %label(for="whitelist-enabled") << Whitelist
=func.new_checkbox("whitelist-enabled", config.whitelist_enabled) =func.new_checkbox("whitelist-enabled", config["whitelist-enabled"])
%label(for="approval-required") << Approval Required %label(for="approval-required") << Approval Required
=func.new_checkbox("approval-required", config.approval_required) =func.new_checkbox("approval-required", config["approval-required"])
%input(type="submit" value="Save") %input(type="submit" value="Save")

View file

@ -1,9 +1,8 @@
-extends "base.haml" -extends "base.haml"
-set page = "Home" -set page = "Home"
-block content -block content
-if config.note .section
.section -markdown -> =config.note
-markdown -> =config.note
.section .section
%p %p
@ -13,12 +12,12 @@
You may subscribe to this relay with the address: You may subscribe to this relay with the address:
%a(href="https://{{domain}}/actor") << https://{{domain}}/actor</a> %a(href="https://{{domain}}/actor") << https://{{domain}}/actor</a>
-if config.approval_required -if config["approval-required"]
%p.section.message %p.section.message
Follow requests require approval. You will need to wait for an admin to accept or deny Follow requests require approval. You will need to wait for an admin to accept or deny
your request. your request.
-elif config.whitelist_enabled -elif config["whitelist-enabled"]
%p.section.message %p.section.message
The whitelist is enabled on this instance. Ask the admin to add your instance before The whitelist is enabled on this instance. Ask the admin to add your instance before
joining. joining.

View file

@ -28,14 +28,6 @@ form input[type="submit"] {
margin: 0 auto; margin: 0 auto;
} }
legend {
background-color: var(--section-background);
padding: 5px;
border: 1px solid var(--border);
border-radius: 5px;
font-size: 10pt;
}
p { p {
line-height: 1em; line-height: 1em;
margin: 0px; margin: 0px;

View file

@ -7,7 +7,7 @@ import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils import JsonBase, Nodeinfo, WellKnownNodeinfo from aputils.objects import Nodeinfo, WellKnownNodeinfo
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from urllib.parse import urlparse from urllib.parse import urlparse
@ -17,13 +17,12 @@ from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aputils import Signer from aputils import Signer
from bsql import Row from tinysql import Row
from typing import Any from typing import Any
from .application import Application from .application import Application
from .cache import Cache from .cache import Cache
T = typing.TypeVar('T', bound = JsonBase)
HEADERS = { HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9', 'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}' 'User-Agent': f'ActivityRelay/{__version__}'
@ -34,12 +33,12 @@ class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10): def __init__(self, limit: int = 100, timeout: int = 10):
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
self._conn: TCPConnector | None = None self._conn = None
self._session: ClientSession | None = None self._session = None
async def __aenter__(self) -> HttpClient: async def __aenter__(self) -> HttpClient:
self.open() await self.open()
return self return self
@ -62,7 +61,7 @@ class HttpClient:
return self.app.signer return self.app.signer
def open(self) -> None: async def open(self) -> None:
if self._session: if self._session:
return return
@ -80,19 +79,23 @@ class HttpClient:
async def close(self) -> None: async def close(self) -> None:
if self._session: if not self._session:
await self._session.close() return
if self._conn: await self._session.close()
await self._conn.close() await self._conn.close()
self._conn = None self._conn = None
self._session = None self._session = None
async def _get(self, url: str, sign_headers: bool, force: bool) -> dict[str, Any] | None: async def get(self, # pylint: disable=too-many-branches
if not self._session: url: str,
raise RuntimeError('Client not open') sign_headers: bool = False,
loads: callable = json.loads,
force: bool = False) -> dict | None:
await self.open()
try: try:
url, _ = url.split('#', 1) url, _ = url.split('#', 1)
@ -102,8 +105,10 @@ class HttpClient:
if not force: if not force:
try: try:
if not (item := self.cache.get('request', url)).older_than(48): item = self.cache.get('request', url)
return json.loads(item.value)
if not item.older_than(48):
return loads(item.value)
except KeyError: except KeyError:
logging.verbose('No cached data for url: %s', url) logging.verbose('No cached data for url: %s', url)
@ -116,22 +121,23 @@ class HttpClient:
try: try:
logging.debug('Fetching resource: %s', url) logging.debug('Fetching resource: %s', url)
async with self._session.get(url, headers = headers) as resp: async with self._session.get(url, headers=headers) as resp:
# Not expecting a response with 202s, so just return ## Not expecting a response with 202s, so just return
if resp.status == 202: if resp.status == 202:
return None return None
data = await resp.text() data = await resp.read()
if resp.status != 200: if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status) logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(data) logging.debug(data)
return None return None
self.cache.set('request', url, data, 'str') message = loads(data)
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4)) self.cache.set('request', url, data.decode('utf-8'), 'str')
logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
return json.loads(data) return message
except JSONDecodeError: except JSONDecodeError:
logging.verbose('Failed to parse JSON') logging.verbose('Failed to parse JSON')
@ -149,26 +155,17 @@ class HttpClient:
return None return None
async def get(self, url: str, sign_headers: bool, cls: type[T], force: bool = False) -> T | None:
if not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "aputils.JsonBase"')
if (data := (await self._get(url, sign_headers, force))) is None:
return None
return cls.parse(data)
async def post(self, url: str, message: Message, instance: Row | None = None) -> None: async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
if not self._session: await self.open()
raise RuntimeError('Client not open')
# Using the old algo by default is probably a better idea right now ## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if instance and instance['software'] in {'mastodon'}: if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019' algorithm = 'hs2019'
else: else:
algorithm = 'original' algorithm = 'original'
# pylint: enable=consider-ternary-expression
headers = {'Content-Type': 'application/activity+json'} headers = {'Content-Type': 'application/activity+json'}
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm)) headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
@ -176,7 +173,7 @@ class HttpClient:
try: try:
logging.verbose('Sending "%s" to %s', message.type, url) logging.verbose('Sending "%s" to %s', message.type, url)
async with self._session.post(url, headers = headers, data = message.to_json()) as resp: async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
# Not expecting a response, so just return # Not expecting a response, so just return
if resp.status in {200, 202}: if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url) logging.verbose('Successfully sent "%s" to %s', message.type, url)
@ -201,11 +198,10 @@ class HttpClient:
nodeinfo_url = None nodeinfo_url = None
wk_nodeinfo = await self.get( wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo', f'https://{domain}/.well-known/nodeinfo',
False, loads = WellKnownNodeinfo.parse
WellKnownNodeinfo
) )
if wk_nodeinfo is None: if not wk_nodeinfo:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain) logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None return None
@ -216,14 +212,14 @@ class HttpClient:
except KeyError: except KeyError:
pass pass
if nodeinfo_url is None: if not nodeinfo_url:
logging.verbose('Failed to fetch nodeinfo url for %s', domain) logging.verbose('Failed to fetch nodeinfo url for %s', domain)
return None return None
return await self.get(nodeinfo_url, False, Nodeinfo) return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
async def get(*args: Any, **kwargs: Any) -> Any: async def get(*args: Any, **kwargs: Any) -> Message | dict | None:
async with HttpClient() as client: async with HttpClient() as client:
return await client.get(*args, **kwargs) return await client.get(*args, **kwargs)

View file

@ -9,7 +9,7 @@ from pathlib import Path
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Self from typing import Any
class LogLevel(IntEnum): class LogLevel(IntEnum):
@ -26,13 +26,7 @@ class LogLevel(IntEnum):
@classmethod @classmethod
def parse(cls: type[Self], data: Any) -> Self: def parse(cls: type[IntEnum], data: object) -> IntEnum:
try:
data = int(data)
except ValueError:
pass
if isinstance(data, cls): if isinstance(data, cls):
return data return data
@ -76,15 +70,15 @@ error: Callable = logging.error
critical: Callable = logging.critical critical: Callable = logging.critical
env_log_level: Path | str | None = os.environ.get('LOG_LEVEL', 'INFO').upper() env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
try: try:
env_log_file: Path | None = Path(os.environ['LOG_FILE']).expanduser().resolve() env_log_file = Path(os.environ['LOG_FILE']).expanduser().resolve()
except KeyError: except KeyError:
env_log_file = None env_log_file = None
handlers: list[Any] = [logging.StreamHandler()] handlers = [logging.StreamHandler()]
if env_log_file: if env_log_file:
handlers.append(logging.FileHandler(env_log_file)) handlers.append(logging.FileHandler(env_log_file))

View file

@ -21,10 +21,19 @@ from .database import RELAY_SOFTWARE, get_database
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from bsql import Row from tinysql import Row
from typing import Any from typing import Any
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
CONFIG_IGNORE = (
'schema-version',
'private-key'
)
def check_alphanumeric(text: str) -> str: def check_alphanumeric(text: str) -> str:
if not text.isalnum(): if not text.isalnum():
raise click.BadParameter('String not alphanumeric') raise click.BadParameter('String not alphanumeric')
@ -41,7 +50,7 @@ def cli(ctx: click.Context, config: str | None) -> None:
if not ctx.invoked_subcommand: if not ctx.invoked_subcommand:
if ctx.obj.config.domain.endswith('example.com'): if ctx.obj.config.domain.endswith('example.com'):
cli_setup.callback() # type: ignore cli_setup.callback()
else: else:
click.echo( click.echo(
@ -49,7 +58,7 @@ def cli(ctx: click.Context, config: str | None) -> None:
'future.' 'future.'
) )
cli_run.callback() # type: ignore cli_run.callback()
@cli.command('setup') @cli.command('setup')
@ -175,7 +184,7 @@ def cli_setup(ctx: click.Context) -> None:
conn.put_config(key, value) conn.put_config(key, value)
if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'): if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback() # type: ignore cli_run.callback()
@cli.command('run') @cli.command('run')
@ -248,7 +257,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
conn.put_config('note', config['note']) conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled']) conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar( # type: ignore with click.progressbar(
database['relay-list'].values(), database['relay-list'].values(),
label = 'Inboxes'.ljust(15), label = 'Inboxes'.ljust(15),
width = 0 width = 0
@ -272,7 +281,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software'] software = inbox['software']
) )
with click.progressbar( # type: ignore with click.progressbar(
config['blocked_software'], config['blocked_software'],
label = 'Banned software'.ljust(15), label = 'Banned software'.ljust(15),
width = 0 width = 0
@ -284,7 +293,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None reason = 'relay' if software in RELAY_SOFTWARE else None
) )
with click.progressbar( # type: ignore with click.progressbar(
config['blocked_instances'], config['blocked_instances'],
label = 'Banned domains'.ljust(15), label = 'Banned domains'.ljust(15),
width = 0 width = 0
@ -293,7 +302,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software: for domain in banned_software:
conn.put_domain_ban(domain) conn.put_domain_ban(domain)
with click.progressbar( # type: ignore with click.progressbar(
config['whitelist'], config['whitelist'],
label = 'Whitelist'.ljust(15), label = 'Whitelist'.ljust(15),
width = 0 width = 0
@ -330,17 +339,10 @@ def cli_config_list(ctx: click.Context) -> None:
click.echo('Relay Config:') click.echo('Relay Config:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
config = conn.get_config_all() for key, value in conn.get_config_all().items():
if key not in CONFIG_IGNORE:
for key, value in config.to_dict().items(): key = f'{key}:'.ljust(20)
if key in type(config).SYSTEM_KEYS(): click.echo(f'- {key} {repr(value)}')
continue
if key == 'log-level':
value = value.name
key_str = f'{key}:'.ljust(20)
click.echo(f'- {key_str} {repr(value)}')
@cli_config.command('set') @cli_config.command('set')
@ -518,7 +520,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)' 'Unfollow an actor (Relay must be running)'
inbox_data: Row | None = None inbox_data: Row = None
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
@ -538,11 +540,6 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
actor = f'https://{actor}/actor' actor = f'https://{actor}/actor'
actor_data = asyncio.run(http.get(actor, sign_headers = True)) actor_data = asyncio.run(http.get(actor, sign_headers = True))
if not actor_data:
click.echo("Failed to fetch actor")
return
inbox = actor_data.shared_inbox inbox = actor_data.shared_inbox
message = Message.new_unfollow( message = Message.new_unfollow(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
@ -970,6 +967,7 @@ def cli_whitelist_import(ctx: click.Context) -> None:
def main() -> None: def main() -> None:
# pylint: disable=no-value-for-parameter
cli(prog_name='relay') cli(prog_name='relay')

View file

@ -8,31 +8,21 @@ import typing
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from datetime import datetime from datetime import datetime
from pathlib import Path
from uuid import uuid4 from uuid import uuid4
try: try:
from importlib.resources import files as pkgfiles from importlib.resources import files as pkgfiles
except ImportError: except ImportError:
from importlib_resources import files as pkgfiles # type: ignore from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from typing import Any, Self from pathlib import Path
from typing import Any
from .application import Application from .application import Application
T = typing.TypeVar('T')
ResponseType = typing.TypedDict('ResponseType', {
'status': int,
'headers': dict[str, typing.Any] | None,
'content_type': str,
'body': bytes | None,
'text': str | None
})
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING')) IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = { MIMETYPES = {
'activity': 'application/activity+json', 'activity': 'application/activity+json',
'css': 'text/css', 'css': 'text/css',
@ -102,7 +92,7 @@ def check_open_port(host: str, port: int) -> bool:
def get_app() -> Application: def get_app() -> Application:
from .application import Application from .application import Application # pylint: disable=import-outside-toplevel
if not Application.DEFAULT: if not Application.DEFAULT:
raise ValueError('No default application set') raise ValueError('No default application set')
@ -111,7 +101,7 @@ def get_app() -> Application:
def get_resource(path: str) -> Path: def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path) return pkgfiles('relay').joinpath(path)
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
@ -124,11 +114,11 @@ class JsonEncoder(json.JSONEncoder):
class Message(aputils.Message): class Message(aputils.Message):
@classmethod @classmethod
def new_actor(cls: type[Self], # type: ignore def new_actor(cls: type[Message], # pylint: disable=arguments-differ
host: str, host: str,
pubkey: str, pubkey: str,
description: str | None = None, description: str | None = None,
approves: bool = False) -> Self: approves: bool = False) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
@ -154,7 +144,7 @@ class Message(aputils.Message):
@classmethod @classmethod
def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self: def new_announce(cls: type[Message], host: str, obj: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -166,7 +156,7 @@ class Message(aputils.Message):
@classmethod @classmethod
def new_follow(cls: type[Self], host: str, actor: str) -> Self: def new_follow(cls: type[Message], host: str, actor: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow', 'type': 'Follow',
@ -178,7 +168,7 @@ class Message(aputils.Message):
@classmethod @classmethod
def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self: def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -190,7 +180,12 @@ class Message(aputils.Message):
@classmethod @classmethod
def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self: def new_response(cls: type[Message],
host: str,
actor: str,
followid: str,
accept: bool) -> Message:
return cls({ return cls({
'@context': 'https://www.w3.org/ns/activitystreams', '@context': 'https://www.w3.org/ns/activitystreams',
'id': f'https://{host}/activities/{uuid4()}', 'id': f'https://{host}/activities/{uuid4()}',
@ -213,18 +208,16 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new(cls: type[Self], def new(cls: type[Response],
body: str | bytes | dict | tuple | list | set = '', body: str | bytes | dict = '',
status: int = 200, status: int = 200,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
ctype: str = 'text') -> Self: ctype: str = 'text') -> Response:
kwargs: ResponseType = { kwargs = {
'status': status, 'status': status,
'headers': headers, 'headers': headers,
'content_type': MIMETYPES[ctype], 'content_type': MIMETYPES[ctype]
'body': None,
'text': None
} }
if isinstance(body, bytes): if isinstance(body, bytes):
@ -240,10 +233,10 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_error(cls: type[Self], def new_error(cls: type[Response],
status: int, status: int,
body: str | bytes | dict, body: str | bytes | dict,
ctype: str = 'text') -> Self: ctype: str = 'text') -> Response:
if ctype == 'json': if ctype == 'json':
body = {'error': body} body = {'error': body}
@ -252,14 +245,14 @@ class Response(AiohttpResponse):
@classmethod @classmethod
def new_redir(cls: type[Self], path: str) -> Self: def new_redir(cls: type[Response], path: str) -> Response:
body = f'Redirect to <a href="{path}">{path}</a>' body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, 302, {'Location': path}) return cls.new(body, 302, {'Location': path})
@property @property
def location(self) -> str: def location(self) -> str:
return self.headers.get('Location', '') return self.headers.get('Location')
@location.setter @location.setter

View file

@ -7,10 +7,10 @@ from .database import Connection
from .misc import Message from .misc import Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .views.activitypub import ActorView from .views import ActorView
def person_check(actor: Message, software: str | None) -> bool: def person_check(actor: str, software: str) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason # pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0 # akkoma changed this in 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay': if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
@ -65,7 +65,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
config = conn.get_config_all() config = conn.get_config_all()
# reject if software used by actor is banned # reject if software used by actor is banned
if software and conn.get_software_ban(software): if conn.get_software_ban(software):
logging.verbose('Rejected banned actor: %s', view.actor.id) logging.verbose('Rejected banned actor: %s', view.actor.id)
view.app.push_message( view.app.push_message(
@ -75,8 +75,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
actor = view.actor.id, actor = view.actor.id,
followid = view.message.id, followid = view.message.id,
accept = False accept = False
), )
view.instance
) )
logging.verbose( logging.verbose(
@ -87,7 +86,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
return return
# reject if the actor is not an instance actor ## reject if the actor is not an instance actor
if person_check(view.actor, software): if person_check(view.actor, software):
logging.verbose('Non-application actor tried to follow: %s', view.actor.id) logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
@ -106,7 +105,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
if not conn.get_domain_whitelist(view.actor.domain): if not conn.get_domain_whitelist(view.actor.domain):
# add request if approval-required is enabled # add request if approval-required is enabled
if config.approval_required: if config['approval-required']:
logging.verbose('New follow request fromm actor: %s', view.actor.id) logging.verbose('New follow request fromm actor: %s', view.actor.id)
with conn.transaction(): with conn.transaction():
@ -122,7 +121,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
return return
# reject if the actor isn't whitelisted while the whiltelist is enabled # reject if the actor isn't whitelisted while the whiltelist is enabled
if config.whitelist_enabled: if config['whitelist-enabled']:
logging.verbose('Rejected actor for not being in the whitelist: %s', view.actor.id) logging.verbose('Rejected actor for not being in the whitelist: %s', view.actor.id)
view.app.push_message( view.app.push_message(
@ -132,8 +131,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
actor = view.actor.id, actor = view.actor.id,
followid = view.message.id, followid = view.message.id,
accept = False accept = False
), )
view.instance
) )
return return
@ -173,7 +171,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
async def handle_undo(view: ActorView, conn: Connection) -> None: async def handle_undo(view: ActorView, conn: Connection) -> None:
# If the object is not a Follow, forward it ## If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow': if view.message.object['type'] != 'Follow':
await handle_forward(view, conn) await handle_forward(view, conn)
return return
@ -187,7 +185,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
logging.verbose( logging.verbose(
'Failed to delete "%s" with follow ID "%s"', 'Failed to delete "%s" with follow ID "%s"',
view.actor.id, view.actor.id,
view.message.object_id view.message.object['id']
) )
view.app.push_message( view.app.push_message(

View file

@ -52,37 +52,34 @@ class Template(Environment):
'domain': self.app.config.domain, 'domain': self.app.config.domain,
'version': __version__, 'version': __version__,
'config': config, 'config': config,
'theme_name': config['theme'] or 'Default',
**(context or {}) **(context or {})
} }
return self.get_template(path).render(new_context) return self.get_template(path).render(new_context)
def render_markdown(self, text: str) -> str:
return self._render_markdown(text) # type: ignore
class MarkdownExtension(Extension): class MarkdownExtension(Extension):
tags = {'markdown'} tags = {'markdown'}
extensions = ( extensions = {
'attr_list', 'attr_list',
'smarty', 'smarty',
'tables' 'tables'
) }
def __init__(self, environment: Environment): def __init__(self, environment: Environment):
Extension.__init__(self, environment) Extension.__init__(self, environment)
self._markdown = Markdown(extensions = MarkdownExtension.extensions) self._markdown = Markdown(extensions = MarkdownExtension.extensions)
environment.extend( environment.extend(
_render_markdown = self._render_markdown render_markdown = self._render_markdown
) )
def parse(self, parser: Parser) -> Node | list[Node]: def parse(self, parser: Parser) -> Node | list[Node]:
lineno = next(parser.stream).lineno lineno = next(parser.stream).lineno
body = parser.parse_statements( body = parser.parse_statements(
('name:endmarkdown',), ['name:endmarkdown'],
drop_needle = True drop_needle = True
) )
@ -91,5 +88,5 @@ class MarkdownExtension(Extension):
def _render_markdown(self, caller: Callable[[], str] | str) -> str: def _render_markdown(self, caller: Callable[[], str] | str) -> str:
text = caller if isinstance(caller, str) else caller() text = caller() if isinstance(caller, Callable) else caller
return self._markdown.convert(textwrap.dedent(text.strip('\n'))) return self._markdown.convert(textwrap.dedent(text.strip('\n')))

View file

@ -1,4 +1,4 @@
from __future__ import annotations from __future__ import annotations
from . import activitypub, api, frontend, misc from . import activitypub, api, frontend, misc
from .base import VIEWS, View from .base import VIEWS

View file

@ -12,21 +12,22 @@ from ..processors import run_processor
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from bsql import Row from tinysql import Row
# pylint: disable=unused-argument
@register_route('/actor', '/inbox') @register_route('/actor', '/inbox')
class ActorView(View): class ActorView(View):
signature: aputils.Signature
message: Message
actor: Message
instancce: Row
signer: aputils.Signer
def __init__(self, request: Request): def __init__(self, request: Request):
View.__init__(self, request) View.__init__(self, request)
self.signature: aputils.Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: Row = None
self.signer: aputils.Signer = None
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session(False) as conn: with self.database.session(False) as conn:
@ -35,8 +36,8 @@ class ActorView(View):
data = Message.new_actor( data = Message.new_actor(
host = self.config.domain, host = self.config.domain,
pubkey = self.app.signer.pubkey, pubkey = self.app.signer.pubkey,
description = self.app.template.render_markdown(config.note), description = self.app.template.render_markdown(config['note']),
approves = config.approval_required approves = config['approval-required']
) )
return Response.new(data, ctype='activity') return Response.new(data, ctype='activity')
@ -49,12 +50,12 @@ class ActorView(View):
with self.database.session() as conn: with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) self.instance = conn.get_inbox(self.actor.shared_inbox)
# reject if actor is banned ## reject if actor is banned
if conn.get_domain_ban(self.actor.domain): if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id) logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json') return Response.new_error(403, 'access denied', 'json')
# reject if activity type isn't 'Follow' and the actor isn't following ## reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance: if self.message.type != 'Follow' and not self.instance:
logging.verbose( logging.verbose(
'Rejected actor for trying to post while not following: %s', 'Rejected actor for trying to post while not following: %s',
@ -78,26 +79,28 @@ class ActorView(View):
return Response.new_error(400, 'missing signature header', 'json') return Response.new_error(400, 'missing signature header', 'json')
try: try:
message: Message | None = await self.request.json(loads = Message.parse) self.message = await self.request.json(loads = Message.parse)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
logging.verbose('Failed to parse inbox message') logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json') return Response.new_error(400, 'failed to parse message', 'json')
if message is None: if self.message is None:
logging.verbose('empty message') logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json') return Response.new_error(400, 'missing message', 'json')
self.message = message
if 'actor' not in self.message: if 'actor' not in self.message:
logging.verbose('actor not in message') logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json') return Response.new_error(400, 'no actor in message', 'json')
actor: Message | None = await self.client.get(self.signature.keyid, True, Message) self.actor = await self.client.get(
self.signature.keyid,
sign_headers = True,
loads = Message.parse
)
if actor is None: if not self.actor:
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete': if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose('Instance sent a delete which cannot be handled')
@ -106,8 +109,6 @@ class ActorView(View):
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json') return Response.new_error(400, 'failed to fetch actor', 'json')
self.actor = actor
try: try:
self.signer = self.actor.signer self.signer = self.actor.signer
@ -122,8 +123,6 @@ class ActorView(View):
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json') return Response.new_error(401, str(e), 'json')
return None
def validate_signature(self, body: bytes) -> None: def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()} headers = {key.lower(): value for key, value in self.request.headers.items()}
@ -151,6 +150,7 @@ class ActorView(View):
headers["(created)"] = self.signature.created headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature): if not self.signer._validate_signature(headers, self.signature):
raise aputils.SignatureFailureError("Signature does not match") raise aputils.SignatureFailureError("Signature does not match")

View file

@ -9,15 +9,23 @@ from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
from .. import __version__ from .. import __version__
from ..database import ConfigData from .. import logger as logging
from ..misc import Message, Response, get_app from ..database.config import CONFIG_DEFAULTS
from ..misc import Message, Response
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from collections.abc import Callable, Sequence from collections.abc import Coroutine
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( CONFIG_IGNORE = (
'schema-version',
'private-key'
)
CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE}
PUBLIC_API_PATHS: tuple[tuple[str, str]] = (
('GET', '/api/v1/relay'), ('GET', '/api/v1/relay'),
('GET', '/api/v1/instance'), ('GET', '/api/v1/instance'),
('POST', '/api/v1/token') ('POST', '/api/v1/token')
@ -32,11 +40,11 @@ def check_api_path(method: str, path: str) -> bool:
@web.middleware @web.middleware
async def handle_api_path(request: Request, handler: Callable) -> Response: async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Response:
try: try:
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
with get_app().database.session() as conn: with request.app.database.session() as conn:
request['user'] = conn.get_user_by_token(request['token']) request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError): except (KeyError, ValueError):
@ -53,6 +61,8 @@ async def handle_api_path(request: Request, handler: Callable) -> Response:
return await handler(request) return await handler(request)
# pylint: disable=no-self-use,unused-argument
@register_route('/api/v1/token') @register_route('/api/v1/token')
class Login(View): class Login(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
@ -92,14 +102,14 @@ class RelayInfo(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
config = conn.get_config_all() config = conn.get_config_all()
inboxes = [row['domain'] for row in conn.get_inboxes()] inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')]
data = { data = {
'domain': self.config.domain, 'domain': self.config.domain,
'name': config.name, 'name': config['name'],
'description': config.note, 'description': config['note'],
'version': __version__, 'version': __version__,
'whitelist_enabled': config.whitelist_enabled, 'whitelist_enabled': config['whitelist-enabled'],
'email': None, 'email': None,
'admin': None, 'admin': None,
'icon': None, 'icon': None,
@ -112,17 +122,12 @@ class RelayInfo(View):
@register_route('/api/v1/config') @register_route('/api/v1/config')
class Config(View): class Config(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
data = {}
with self.database.session() as conn: with self.database.session() as conn:
for key, value in conn.get_config_all().to_dict().items(): data = conn.get_config_all()
if key in ConfigData.SYSTEM_KEYS(): data['log-level'] = data['log-level'].name
continue
if key == 'log-level': for key in CONFIG_IGNORE:
value = value.name del data[key]
data[key] = value
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')
@ -133,7 +138,7 @@ class Config(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in CONFIG_VALID:
return Response.new_error(400, 'Invalid key', 'json') return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn: with self.database.session() as conn:
@ -148,11 +153,11 @@ class Config(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in CONFIG_VALID:
return Response.new_error(400, 'Invalid key', 'json') return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn: with self.database.session() as conn:
conn.put_config(data['key'], ConfigData.DEFAULT(data['key'])) conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1])
return Response.new({'message': 'Updated config'}, ctype = 'json') return Response.new({'message': 'Updated config'}, ctype = 'json')
@ -179,13 +184,19 @@ class Inbox(View):
return Response.new_error(404, 'Instance already in database', 'json') return Response.new_error(404, 'Instance already in database', 'json')
if not data.get('inbox'): if not data.get('inbox'):
actor_data: Message | None = await self.client.get(data['actor'], True, Message) try:
actor_data = await self.client.get(
data['actor'],
sign_headers = True,
loads = Message.parse
)
if actor_data is None: data['inbox'] = actor_data.shared_inbox
except Exception as e:
logging.error('Failed to fetch actor: %s', str(e))
return Response.new_error(500, 'Failed to fetch actor', 'json') return Response.new_error(500, 'Failed to fetch actor', 'json')
data['inbox'] = actor_data.shared_inbox
row = conn.put_inbox(**data) row = conn.put_inbox(**data)
return Response.new(row, ctype = 'json') return Response.new(row, ctype = 'json')

View file

@ -8,11 +8,11 @@ from aiohttp.web import HTTPMethodNotAllowed
from functools import cached_property from functools import cached_property
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from ..misc import Response, get_app from ..misc import Response
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from collections.abc import Callable, Generator, Sequence, Mapping from collections.abc import Callable, Coroutine, Generator
from bsql import Database from bsql import Database
from typing import Any, Self from typing import Any, Self
from ..application import Application from ..application import Application
@ -22,24 +22,20 @@ if typing.TYPE_CHECKING:
from ..template import Template from ..template import Template
VIEWS: list[tuple[str, type[View]]] = [] VIEWS = []
def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
return {key: str(value) for key, value in data.items()}
def register_route(*paths: str) -> Callable: def register_route(*paths: str) -> Callable:
def wrapper(view: type[View]) -> type[View]: def wrapper(view: View) -> View:
for path in paths: for path in paths:
VIEWS.append((path, view)) VIEWS.append([path, view])
return view return view
return wrapper return wrapper
class View(AbstractView): class View(AbstractView):
def __await__(self) -> Generator[Any, None, Response]: def __await__(self) -> Generator[Response]:
if self.request.method not in METHODS: if self.request.method not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods) raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
@ -50,22 +46,22 @@ class View(AbstractView):
@classmethod @classmethod
async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Response: async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Self:
view = cls(request) view = cls(request)
return await view.handlers[method](request, **kwargs) return await view.handlers[method](request, **kwargs)
async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response: async def _run_handler(self, handler: Coroutine, **kwargs: Any) -> Response:
return await handler(self.request, **self.request.match_info, **kwargs) return await handler(self.request, **self.request.match_info, **kwargs)
@cached_property @cached_property
def allowed_methods(self) -> Sequence[str]: def allowed_methods(self) -> tuple[str]:
return tuple(self.handlers.keys()) return tuple(self.handlers.keys())
@cached_property @cached_property
def handlers(self) -> dict[str, Callable[..., Any]]: def handlers(self) -> dict[str, Coroutine]:
data = {} data = {}
for method in METHODS: for method in METHODS:
@ -78,9 +74,10 @@ class View(AbstractView):
return data return data
# app components
@property @property
def app(self) -> Application: def app(self) -> Application:
return get_app() return self.request.app
@property @property
@ -113,17 +110,17 @@ class View(AbstractView):
optional: list[str]) -> dict[str, str] | Response: optional: list[str]) -> dict[str, str] | Response:
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}: if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
post_data = convert_data(await self.request.post()) post_data = await self.request.post()
elif self.request.content_type == 'application/json': elif self.request.content_type == 'application/json':
try: try:
post_data = convert_data(await self.request.json()) post_data = await self.request.json()
except JSONDecodeError: except JSONDecodeError:
return Response.new_error(400, 'Invalid JSON data', 'json') return Response.new_error(400, 'Invalid JSON data', 'json')
else: else:
post_data = convert_data(await self.request.query) # type: ignore post_data = self.request.query
data = {} data = {}
@ -135,6 +132,6 @@ class View(AbstractView):
return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json') return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
for key in optional: for key in optional:
data[key] = post_data.get(key, '') data[key] = post_data.get(key)
return data return data

View file

@ -8,30 +8,36 @@ from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
from ..database import THEMES, ConfigData from ..database import CONFIG_DEFAULTS, THEMES
from ..logger import LogLevel from ..logger import LogLevel
from ..misc import ACTOR_FORMATS, Message, Response, get_app from ..misc import ACTOR_FORMATS, Message, Response
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from collections.abc import Callable from collections.abc import Coroutine
from typing import Any
# pylint: disable=no-self-use
UNAUTH_ROUTES = { UNAUTH_ROUTES = {
'/', '/',
'/login' '/login'
} }
CONFIG_IGNORE = (
'schema-version',
'private-key'
)
@web.middleware @web.middleware
async def handle_frontend_path(request: web.Request, handler: Callable) -> Response: async def handle_frontend_path(request: web.Request, handler: Coroutine) -> Response:
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'): if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
request['token'] = request.cookies.get('user-token') request['token'] = request.cookies.get('user-token')
request['user'] = None request['user'] = None
if request['token']: if request['token']:
with get_app().database.session(False) as conn: with request.app.database.session(False) as conn:
request['user'] = conn.get_user_by_token(request['token']) request['user'] = conn.get_user_by_token(request['token'])
if request['user'] and request.path == '/login': if request['user'] and request.path == '/login':
@ -43,11 +49,13 @@ async def handle_frontend_path(request: web.Request, handler: Callable) -> Respo
return await handler(request) return await handler(request)
# pylint: disable=unused-argument
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
context: dict[str, Any] = { context = {
'instances': tuple(conn.get_inboxes()) 'instances': tuple(conn.get_inboxes())
} }
@ -128,7 +136,7 @@ class AdminInstances(View):
message: str | None = None) -> Response: message: str | None = None) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
context: dict[str, Any] = { context = {
'instances': tuple(conn.get_inboxes()), 'instances': tuple(conn.get_inboxes()),
'requests': tuple(conn.get_requests()) 'requests': tuple(conn.get_requests())
} }
@ -144,8 +152,7 @@ class AdminInstances(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
post = await request.post() data = await request.post()
data: dict[str, str] = {key: value for key, value in post.items()} # type: ignore
if not data.get('actor') and not data.get('domain'): if not data.get('actor') and not data.get('domain'):
return await self.get(request, error = 'Missing actor and/or domain') return await self.get(request, error = 'Missing actor and/or domain')
@ -155,21 +162,13 @@ class AdminInstances(View):
if not data.get('software'): if not data.get('software'):
nodeinfo = await self.client.fetch_nodeinfo(data['domain']) nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
if nodeinfo is None:
return await self.get(request, error = 'Failed to fetch nodeinfo')
data['software'] = nodeinfo.sw_name data['software'] = nodeinfo.sw_name
if not data.get('actor') and data['software'] in ACTOR_FORMATS: if not data.get('actor') and data['software'] in ACTOR_FORMATS:
data['actor'] = ACTOR_FORMATS[data['software']].format(domain = data['domain']) data['actor'] = ACTOR_FORMATS[data['software']].format(domain = data['domain'])
if not data.get('inbox') and data['actor']: if not data.get('inbox') and data['actor']:
actor: Message | None = await self.client.get(data['actor'], True, Message) actor = await self.client.get(data['actor'], sign_headers = True, loads = Message.parse)
if actor is None:
return await self.get(request, error = 'Failed to fetch actor')
data['inbox'] = actor.shared_inbox data['inbox'] = actor.shared_inbox
with self.database.session(True) as conn: with self.database.session(True) as conn:
@ -249,7 +248,7 @@ class AdminWhitelist(View):
message: str | None = None) -> Response: message: str | None = None) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
context: dict[str, Any] = { context = {
'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC')) 'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC'))
} }
@ -299,7 +298,7 @@ class AdminDomainBans(View):
message: str | None = None) -> Response: message: str | None = None) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
context: dict[str, Any] = { context = {
'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC')) 'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC'))
} }
@ -357,7 +356,7 @@ class AdminSoftwareBans(View):
message: str | None = None) -> Response: message: str | None = None) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
context: dict[str, Any] = { context = {
'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC')) 'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC'))
} }
@ -415,7 +414,7 @@ class AdminUsers(View):
message: str | None = None) -> Response: message: str | None = None) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
context: dict[str, Any] = { context = {
'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC')) 'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC'))
} }
@ -463,26 +462,29 @@ class AdminUsersDelete(View):
@register_route('/admin/config') @register_route('/admin/config')
class AdminConfig(View): class AdminConfig(View):
async def get(self, request: Request, message: str | None = None) -> Response: async def get(self, request: Request, message: str | None = None) -> Response:
context: dict[str, Any] = { context = {
'themes': tuple(THEMES.keys()), 'themes': tuple(THEMES.keys()),
'levels': tuple(level.name for level in LogLevel), 'levels': tuple(level.name for level in LogLevel),
'message': message 'message': message
} }
data = self.template.render('page/admin-config.haml', self, **context) data = self.template.render('page/admin-config.haml', self, **context)
return Response.new(data, ctype = 'html') return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
form = dict(await request.post()) form = dict(await request.post())
data = ConfigData()
for key in ConfigData.USER_KEYS():
data.set(key, form.get(key.replace('_', '-')))
with self.database.session(True) as conn: with self.database.session(True) as conn:
for key, value in data.to_dict().items(): for key in CONFIG_DEFAULTS:
if key in ConfigData.SYSTEM_KEYS(): value = form.get(key)
if key == 'whitelist-enabled':
value = bool(value)
elif key.lower() in CONFIG_IGNORE:
continue
if value is None:
continue continue
conn.put_config(key, value) conn.put_config(key, value)
@ -501,7 +503,7 @@ class StyleCss(View):
class ThemeCss(View): class ThemeCss(View):
async def get(self, request: Request, theme: str) -> Response: async def get(self, request: Request, theme: str) -> Response:
try: try:
context: dict[str, Any] = { context = {
'theme': THEMES[theme] 'theme': THEMES[theme]
} }

View file

@ -27,26 +27,31 @@ if Path(__file__).parent.parent.joinpath('.git').exists():
pass pass
# pylint: disable=unused-argument
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}') @register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View): class NodeinfoView(View):
# pylint: disable=no-self-use
async def get(self, request: Request, niversion: str) -> Response: async def get(self, request: Request, niversion: str) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
inboxes = conn.get_inboxes() inboxes = conn.get_inboxes()
nodeinfo = aputils.Nodeinfo.new( data = {
name = 'activityrelay', 'name': 'activityrelay',
version = VERSION, 'version': VERSION,
protocols = ['activitypub'], 'protocols': ['activitypub'],
open_regs = not conn.get_config('whitelist-enabled'), 'open_regs': not conn.get_config('whitelist-enabled'),
users = 1, 'users': 1,
repo = 'https://git.pleroma.social/pleroma/relay' if niversion == '2.1' else None, 'metadata': {
metadata = {
'approval_required': conn.get_config('approval-required'), 'approval_required': conn.get_config('approval-required'),
'peers': [inbox['domain'] for inbox in inboxes] 'peers': [inbox['domain'] for inbox in inboxes]
} }
) }
return Response.new(nodeinfo, ctype = 'json') if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
return Response.new(aputils.Nodeinfo.new(**data), ctype = 'json')
@register_route('/.well-known/nodeinfo') @register_route('/.well-known/nodeinfo')

View file

@ -1,14 +1,14 @@
aiohttp >= 3.9.1 aiohttp>=3.9.1
aiohttp-swagger[performance] == 1.0.16 aiohttp-swagger[performance]==1.0.16
aputils @ https://git.barkshark.xyz/barkshark/aputils/archive/0.1.7.tar.gz aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.7.tar.gz
argon2-cffi == 23.1.0 argon2-cffi==23.1.0
barkshark-sql @ https://git.barkshark.xyz/barkshark/bsql/archive/0.1.2.tar.gz barkshark-sql@https://git.barkshark.xyz/barkshark/bsql/archive/0.1.2.tar.gz
click >= 8.1.2 click>=8.1.2
hamlish-jinja @ https://git.barkshark.xyz/barkshark/hamlish-jinja/archive/0.3.5.tar.gz hamlish-jinja@https://git.barkshark.xyz/barkshark/hamlish-jinja/archive/0.3.5.tar.gz
hiredis == 2.3.2 hiredis==2.3.2
markdown == 3.5.2 markdown==3.5.2
platformdirs == 4.2.0 platformdirs==4.2.0
pyyaml >= 6.0 pyyaml>=6.0
redis == 5.0.1 redis==5.0.1
importlib_resources == 6.1.1; python_version < '3.9' importlib_resources==6.1.1;python_version<'3.9'

View file

@ -44,8 +44,6 @@ console_scripts =
[flake8] [flake8]
extend-ignore = E128,E251,E261,E303,W191 select = F401
max-line-length = 100
indent-size = 4
per-file-ignores = per-file-ignores =
__init__.py: F401 __init__.py: F401