Merge branch 'dev' into 'main'

version 0.3.1

See merge request pleroma/relay!58
This commit is contained in:
Izalia Mae 2024-04-02 18:31:22 +00:00
commit dec7c6a674
57 changed files with 2973 additions and 1651 deletions

4
.gitignore vendored
View file

@ -12,7 +12,7 @@ __pycache__/
env/
build/
develop-eggs/
dist/
dist*/
downloads/
eggs/
.eggs/
@ -98,3 +98,5 @@ ENV/
*.yaml
*.jsonld
*.sqlite3
test*.py

View file

@ -1,28 +1,20 @@
FROM python:3-alpine
# install build deps for pycryptodome and other c-based python modules
RUN apk add alpine-sdk autoconf automake libtool gcc
FROM python:3.12-alpine
# add env var to let the relay know it's in a container
ENV DOCKER_RUNNING=true
# setup various container properties
VOLUME ["/data"]
CMD ["python", "-m", "relay"]
CMD ["python3", "-m", "relay", "run"]
EXPOSE 8080/tcp
WORKDIR /opt/activityrelay
# only copy necessary files
COPY relay ./relay
COPY pyproject.toml ./
# install and update important python modules
RUN pip3 install -U setuptools wheel pip
# only copy necessary files
COPY relay ./relay
COPY LICENSE .
COPY README.md .
COPY requirements.txt .
COPY setup.cfg .
COPY setup.py .
COPY .git ./.git
# install relay deps
RUN pip3 install -r requirements.txt
RUN pip3 install `python3 -c "import tomllib; print(' '.join(dep.replace(' ', '') for dep in tomllib.load(open('pyproject.toml', 'rb'))['project']['dependencies']))"`

View file

@ -1,4 +0,0 @@
flake8 == 7.0.0
pyinstaller == 6.3.0
pylint == 3.0
watchdog == 4.0.0

216
dev.py Executable file
View file

@ -0,0 +1,216 @@
#!/usr/bin/env python3
import click
import platform
import shutil
import subprocess
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
from relay import __version__, logger as logging
from tempfile import TemporaryDirectory
from typing import Sequence
try:
from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler
except ImportError:
class PatternMatchingEventHandler: # type: ignore
pass
REPO = Path(__file__).parent
IGNORE_EXT = {
'.py',
'.pyc'
}
@click.group('cli')
def cli():
'Useful commands for development'
@cli.command('install')
def cli_install():
cmd = [
sys.executable, '-m', 'pip', 'install',
'-r', 'requirements.txt',
'-r', 'dev-requirements.txt'
]
subprocess.run(cmd, check = False)
@cli.command('lint')
@click.argument('path', required = False, type = Path, default = REPO.joinpath('relay'))
@click.option('--strict', '-s', is_flag = True, help = 'Enable strict mode for mypy')
@click.option('--watch', '-w', is_flag = True,
help = 'Automatically, re-run the linters on source change')
def cli_lint(path: Path, strict: bool, watch: bool) -> None:
path = path.expanduser().resolve()
if watch:
handle_run_watcher([sys.executable, "-m", "relay.dev", "lint", str(path)], wait = True)
return
flake8 = [sys.executable, '-m', 'flake8', "dev.py", str(path)]
mypy = [sys.executable, '-m', 'mypy', "dev.py", str(path)]
if strict:
mypy.append('--strict')
click.echo('----- flake8 -----')
subprocess.run(flake8)
click.echo('\n\n----- mypy -----')
subprocess.run(mypy)
@cli.command('clean')
def cli_clean():
dirs = {
'dist',
'build',
'dist-pypi'
}
for directory in dirs:
shutil.rmtree(directory, ignore_errors = True)
for path in REPO.glob('*.egg-info'):
shutil.rmtree(path)
for path in REPO.glob('*.spec'):
path.unlink()
@cli.command('build')
def cli_build():
with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [
sys.executable, '-m', 'PyInstaller',
'--collect-data', 'relay',
'--collect-data', 'aiohttp_swagger',
'--hidden-import', 'pg8000',
'--hidden-import', 'sqlite3',
'--name', f'activityrelay-{__version__}-{platform.system().lower()}-{arch}',
'--workpath', tmp,
'--onefile', 'relay/__main__.py',
]
if platform.system() == 'Windows':
cmd.append('--console')
# putting the spec path on a different drive than the source dir breaks
if str(REPO)[0] == tmp[0]:
cmd.extend(['--specpath', tmp])
else:
cmd.append('--strip')
cmd.extend(['--specpath', tmp])
subprocess.run(cmd, check = False)
@cli.command('run')
@click.option('--dev', '-d', is_flag = True)
def cli_run(dev: bool):
print('Starting process watcher')
cmd = [sys.executable, '-m', 'relay', 'run']
if dev:
cmd.append('-d')
handle_run_watcher(cmd)
def handle_run_watcher(*commands: Sequence[str], wait: bool = False):
handler = WatchHandler(*commands, wait = wait)
handler.run_procs()
watcher = Observer()
watcher.schedule(handler, str(REPO), recursive=True)
watcher.start()
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
pass
handler.kill_procs()
watcher.stop()
watcher.join()
class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py']
def __init__(self, *commands: Sequence[str], wait: bool = False):
PatternMatchingEventHandler.__init__(self)
self.commands: Sequence[Sequence[str]] = commands
self.wait: bool = wait
self.procs: list[subprocess.Popen] = []
self.last_restart: datetime = datetime.now()
def kill_procs(self):
for proc in self.procs:
if proc.poll() is not None:
continue
logging.info(f'Terminating process {proc.pid}')
proc.terminate()
sec = 0.0
while proc.poll() is None:
time.sleep(0.1)
sec += 0.1
if sec >= 5:
logging.error('Failed to terminate. Killing process...')
proc.kill()
break
logging.info('Process terminated')
def run_procs(self, restart: bool = False):
if restart:
if datetime.now() - timedelta(seconds = 3) < self.last_restart:
return
self.kill_procs()
self.last_restart = datetime.now()
if self.wait:
self.procs = []
for cmd in self.commands:
logging.info('Running command: %s', ' '.join(cmd))
subprocess.run(cmd)
else:
self.procs = list(subprocess.Popen(cmd) for cmd in self.commands)
pids = (str(proc.pid) for proc in self.procs)
logging.info('Started processes with PIDs: %s', ', '.join(pids))
def on_any_event(self, event):
if event.event_type not in ['modified', 'created', 'deleted']:
return
self.run_procs(restart = True)
if __name__ == '__main__':
cli()

View file

@ -2,9 +2,21 @@
case $1 in
install)
if [[ -z ${2#$} ]]; then
host=127.0.0.1
else
host=$2
fi
if [[ -z ${3#$} ]]; then
port=8080
else
port=$3
fi
docker build -f Dockerfile -t activityrelay . && \
docker volume create activityrelay-data && \
docker run -it -p 8080:8080 -v activityrelay-data:/data --name activityrelay activityrelay
docker run -it -p target=8080,published=${host}:${port} -v activityrelay-data:/data --name activityrelay activityrelay
;;
uninstall)
@ -22,6 +34,10 @@ case $1 in
docker stop activityrelay
;;
restart)
docker restart activityrelay
;;
manage)
shift
docker exec -it activityrelay python3 -m relay "$@"
@ -54,13 +70,14 @@ case $1 in
COLS="%-22s %s\n"
echo "Valid commands:"
printf "$COLS" "- start" "Run the relay in the background"
printf "$COLS" "- stop" "Stop the relay"
printf "$COLS" "- manage <cmd> [args]" "Run a relay management command"
printf "$COLS" "- edit" "Edit the relay's config in \$EDITOR"
printf "$COLS" "- shell" "Drop into a bash shell on the running container"
printf "$COLS" "- rescue" "Drop into a bash shell on a temp container with the data volume mounted"
printf "$COLS" "- install" "Build the image, create a new container and volume, and run relay setup"
printf "$COLS" "- install [address] [port]" "Build the image, create a new container and volume, and run relay setup"
printf "$COLS" "- uninstall" "Delete the relay image, container, and volume"
;;
esac

View file

@ -1,15 +1,19 @@
# Configuration
## General
## Config File
### Domain
These options are stored in the configuration file (usually relay.yaml)
### General
#### Domain
Hostname the relay will be hosted on.
domain: relay.example.com
### Listener
#### Listener
The address and port the relay will listen on. If the reverse proxy (nginx, apache, caddy, etc)
is running on the same host, it is recommended to change `listen` to `localhost` if the reverse
@ -19,7 +23,7 @@ proxy is on the same host.
port: 8080
### Push Workers
#### Push Workers
The number of processes to spawn for pushing messages to subscribed instances. Leave it at 0 to
automatically detect how many processes should be spawned.
@ -27,21 +31,21 @@ automatically detect how many processes should be spawned.
workers: 0
### Database type
#### Database type
SQL database backend to use. Valid values are `sqlite` or `postgres`.
database_type: sqlite
### Cache type
#### Cache type
Cache backend to use. Valid values are `database` or `redis`
cache_type: database
### Sqlite File Path
#### Sqlite File Path
Path to the sqlite database file. If the path is not absolute, it is relative to the config file.
directory.
@ -49,7 +53,7 @@ directory.
sqlite_path: relay.jsonld
## Postgresql
### Postgresql
In order to use the Postgresql backend, the user and database need to be created first.
@ -57,80 +61,132 @@ In order to use the Postgresql backend, the user and database need to be created
sudo -u postgres psql -c "CREATE DATABASE activityrelay OWNER activityrelay"
### Database Name
#### Database Name
Name of the database to use.
name: activityrelay
### Host
#### Host
Hostname, IP address, or unix socket the server is hosted on.
host: /var/run/postgresql
### Port
#### Port
Port number the server is listening on.
port: 5432
### Username
#### Username
User to use when logging into the server.
user: null
### Password
#### Password
Password for the specified user.
pass: null
## Redis
### Redis
### Host
#### Host
Hostname, IP address, or unix socket the server is hosted on.
host: /var/run/postgresql
### Port
#### Port
Port number the server is listening on.
port: 5432
### Username
#### Username
User to use when logging into the server.
user: null
### Password
#### Password
Password for the specified user.
pass: null
### Database Number
#### Database Number
Number of the database to use.
database: 0
### Prefix
#### Prefix
Text to prefix every key with. It cannot contain a `:` character.
prefix: activityrelay
## Database Config
These options are stored in the database and can be changed via CLI, API, or the web interface.
### Approval Required
When enabled, instances that try to follow the relay will have to be manually approved by an admin.
approval-required: false
### Log Level
Maximum level of messages to log.
Note: Changing this setting via CLI does not actually take effect until restart.
Valid values: `DEBUG`, `VERBOSE`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
log-level: INFO
### Name
Name of your relay's instance. It will be displayed at the top of web pages and in API endpoints.
name: ActivityRelay
### Note
Short blurb that will be displayed on the relay's home and in API endpoints if set. Can be in
markdown format.
note: null
### Theme
Color theme to use for the web pages.
Valid values: `Default`, `Pink`, `Blue`
theme: Default
### Whitelist Enabled
When enabled, only instances on the whitelist can join. Any instances currently subscribed and not
in the whitelist when this is enabled can still post.
whitelist-enabled: False

View file

@ -13,9 +13,9 @@ the [official pipx docs](https://pypa.github.io/pipx/installation/) for more in-
python3 -m pip install pipx
Now simply install ActivityRelay directly from git
Now simply install ActivityRelay from pypi
pipx install git+https://git.pleroma.social/pleroma/relay@0.3.0
pipx install activityrelay
Or from a cloned git repo.
@ -36,10 +36,9 @@ be installed via [pyenv](https://github.com/pyenv/pyenv).
## Pip
The instructions for installation via pip are very similar to pipx. Installation can be done from
git
The instructions for installation via pip are very similar to pipx
python3 -m pip install git+https://git.pleroma.social/pleroma/relay@0.3.0
python3 -m pip install activityrelay
or a cloned git repo.
@ -58,10 +57,11 @@ And start the relay when finished
Installation and management via Docker can be handled with the `docker.sh` script. To install
ActivityRelay, run the install command. Once the image is built and the container is created,
your will be asked to fill out some config options for your relay.
you will be asked to fill out some config options for your relay. An address and port can be
specified to change what the relay listens on.
./docker.sh install
./docker.sh install 0.0.0.0 6942
Finally start it up. It will be listening on TCP port 8080.
Finally start it up. It will be listening on TCP localhost:8080 by default.
./docker.sh start

View file

@ -1,6 +1,3 @@
relay.example.org {
gzip
proxy / 127.0.0.1:8080 {
transparent
}
relay.example.com {
reverse_proxy / http://localhost:8080
}

View file

@ -4,6 +4,7 @@ Description=ActivityPub Relay
[Service]
WorkingDirectory=/home/relay/relay
ExecStart=/usr/bin/python3 -m relay run
Environment="IS_SYSTEMD=1"
[Install]
WantedBy=multi-user.target

View file

@ -1,56 +1,90 @@
[build-system]
requires = ["setuptools","wheel"]
build-backend = 'setuptools.build_meta'
requires = ["setuptools>=61.2"]
build-backend = "setuptools.build_meta"
[project]
name = "ActivityRelay"
description = "Generic LitePub relay (works with all LitePub consumers and Mastodon)"
license = {text = "AGPLv3"}
classifiers = [
"Environment :: Console",
"License :: OSI Approved :: GNU Affero General Public License v3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"activitypub-utils == 0.2.1",
"aiohttp >= 3.9.1",
"aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0",
"barkshark-sql == 0.1.2",
"click >= 8.1.2",
"hiredis == 2.3.2",
"jinja2-haml == 0.3.5",
"markdown == 3.5.2",
"platformdirs == 4.2.0",
"pyyaml >= 6.0",
"redis == 5.0.1",
"importlib_resources == 6.1.1; python_version < '3.9'"
]
requires-python = ">=3.8"
dynamic = ["version"]
[tool.pylint.main]
jobs = 0
persistent = true
load-plugins = [
"pylint.extensions.code_style",
"pylint.extensions.comparison_placement",
"pylint.extensions.confusing_elif",
"pylint.extensions.for_any_all",
"pylint.extensions.consider_ternary_expression",
"pylint.extensions.bad_builtin",
"pylint.extensions.dict_init_mutate",
"pylint.extensions.check_elif",
"pylint.extensions.empty_comment",
"pylint.extensions.private_import",
"pylint.extensions.redefined_variable_type",
"pylint.extensions.no_self_use",
"pylint.extensions.overlapping_exceptions",
"pylint.extensions.set_membership",
"pylint.extensions.typing"
[project.readme]
file = "README.md"
content-type = "text/markdown; charset=UTF-8"
[project.urls]
Documentation = "https://git.pleroma.social/pleroma/relay/-/blob/main/docs/index.md"
Source = "https://git.pleroma.social/pleroma/relay"
Tracker = "https://git.pleroma.social/pleroma/relay/-/issues"
[project.scripts]
activityrelay = "relay.manage:main"
[project.optional-dependencies]
dev = [
"flake8 == 7.0.0",
"mypy == 1.9.0",
"pyinstaller == 6.3.0",
"watchdog == 4.0.0",
"typing_extensions >= 4.10.0; python_version < '3.11.0'"
]
[tool.pylint.design]
max-args = 10
max-attributes = 100
[tool.pylint.format]
indent-str = "\t"
indent-after-paren = 1
max-line-length = 100
single-line-if-stmt = true
[tool.pylint.messages_control]
disable = [
"fixme",
"broad-exception-caught",
"cyclic-import",
"global-statement",
"invalid-name",
"missing-module-docstring",
"too-few-public-methods",
"too-many-public-methods",
"too-many-return-statements",
"wrong-import-order",
"missing-function-docstring",
"missing-class-docstring",
"consider-using-namedtuple-or-dataclass",
"confusing-consecutive-elif"
[tool.setuptools]
zip-safe = false
packages = [
"relay",
"relay.database",
"relay.views",
]
include-package-data = true
license-files = ["LICENSE"]
[tool.setuptools.package-data]
relay = [
"data/*",
"frontend/*",
"frontend/page/*",
"frontend/static/*"
]
[tool.setuptools.dynamic]
version = {attr = "relay.__version__"}
[tool.setuptools.dynamic.optional-dependencies]
dev = {file = ["dev-requirements.txt"]}
[tool.mypy]
show_traceback = true
install_types = true
pretty = true
disallow_untyped_decorators = true
warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true
ignore_missing_imports = true
follow_imports = "silent"

View file

@ -1 +1 @@
__version__ = '0.3.0'
__version__ = '0.3.1'

View file

@ -8,9 +8,12 @@ import traceback
import typing
from aiohttp import web
from aiohttp.web import StaticResource
from aiohttp_swagger import setup_swagger
from aputils.signer import Signer
from datetime import datetime, timedelta
from mimetypes import guess_type
from pathlib import Path
from queue import Empty
from threading import Event, Thread
@ -26,18 +29,33 @@ from .views.api import handle_api_path
from .views.frontend import handle_frontend_path
if typing.TYPE_CHECKING:
from collections.abc import Coroutine
from tinysql import Database, Row
from collections.abc import Callable
from bsql import Database, Row
from .cache import Cache
from .misc import Message, Response
# pylint: disable=unsubscriptable-object
def get_csp(request: web.Request) -> str:
data = [
"default-src 'none'",
f"script-src 'nonce-{request['hash']}'",
f"style-src 'self' 'nonce-{request['hash']}'",
"form-action 'self'",
"connect-src 'self'",
"img-src 'self'",
"object-src 'none'",
"frame-ancestors 'none'",
f"manifest-src 'self' https://{request.app['config'].domain}"
]
return '; '.join(data) + ';'
class Application(web.Application):
DEFAULT: Application = None
DEFAULT: Application | None = None
def __init__(self, cfgpath: str | None, dev: bool = False):
def __init__(self, cfgpath: Path | None, dev: bool = False):
web.Application.__init__(self,
middlewares = [
handle_api_path,
@ -48,7 +66,7 @@ class Application(web.Application):
Application.DEFAULT = self
self['running'] = None
self['running'] = False
self['signer'] = None
self['start_time'] = None
self['cleanup_thread'] = None
@ -64,14 +82,13 @@ class Application(web.Application):
self['workers'] = []
self.cache.setup()
# self.on_response_prepare.append(handle_access_log)
self.on_cleanup.append(handle_cleanup)
self.on_cleanup.append(handle_cleanup) # type: ignore
for path, view in VIEWS:
self.router.add_view(path, view)
setup_swagger(self,
setup_swagger(
self,
ui_version = 3,
swagger_from_file = get_resource('data/swagger.yaml')
)
@ -111,6 +128,11 @@ class Application(web.Application):
self['signer'] = Signer(value, self.config.keyid)
@property
def template(self) -> Template:
return self['template']
@property
def uptime(self) -> timedelta:
if not self['start_time']:
@ -121,10 +143,20 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
def push_message(self, inbox: str, message: Message | bytes, instance: Row) -> None:
self['push_queue'].put((inbox, message, instance))
def register_static_routes(self) -> None:
if self['dev']:
static = StaticResource('/static', get_resource('frontend/static'))
else:
static = CachedStaticResource('/static', get_resource('frontend/static'))
self.router.register_resource(static)
def run(self) -> None:
if self["running"]:
return
@ -137,6 +169,8 @@ class Application(web.Application):
logging.error(f'A server is already running on {host}:{port}')
return
self.register_static_routes()
logging.info(f'Starting webserver at {domain} ({host}:{port})')
asyncio.run(self.handle_run())
@ -160,6 +194,7 @@ class Application(web.Application):
self.set_signal_handler(True)
self['client'].open()
self['database'].connect()
self['cache'].setup()
self['cleanup_thread'] = CacheCleanupThread(self)
@ -174,7 +209,8 @@ class Application(web.Application):
runner = web.AppRunner(self, access_log_format='%{X-Forwarded-For}i "%r" %s %b "%{User-Agent}i"')
await runner.setup()
site = web.TCPSite(runner,
site = web.TCPSite(
runner,
host = self.config.listen,
port = self.config.port,
reuse_address = True
@ -188,7 +224,7 @@ class Application(web.Application):
await site.stop()
for worker in self['workers']: # pylint: disable=not-an-iterable
for worker in self['workers']:
worker.stop()
self.set_signal_handler(False)
@ -201,6 +237,39 @@ class Application(web.Application):
self['cache'].close()
class CachedStaticResource(StaticResource):
def __init__(self, prefix: str, path: Path):
StaticResource.__init__(self, prefix, path)
self.cache: dict[str, bytes] = {}
for filename in path.rglob('*'):
if filename.is_dir():
continue
rel_path = str(filename.relative_to(path))
with filename.open('rb') as fd:
logging.debug('Loading static resource "%s"', rel_path)
self.cache[rel_path] = fd.read()
async def _handle(self, request: web.Request) -> web.StreamResponse:
rel_url = request.match_info['filename']
if Path(rel_url).anchor:
raise web.HTTPForbidden()
try:
return web.Response(
body = self.cache[rel_url],
content_type = guess_type(rel_url)[0]
)
except KeyError:
raise web.HTTPNotFound()
class CacheCleanupThread(Thread):
def __init__(self, app: Application):
Thread.__init__(self)
@ -242,16 +311,17 @@ class PushWorker(multiprocessing.Process):
async def handle_queue(self) -> None:
client = HttpClient()
client.open()
while not self.shutdown.is_set():
try:
inbox, message, instance = self.queue.get(block=True, timeout=0.25)
await client.post(inbox, message, instance)
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
asyncio.create_task(client.post(inbox, message, instance))
except Empty:
pass
await asyncio.sleep(0)
## make sure an exception doesn't bring down the worker
# make sure an exception doesn't bring down the worker
except Exception:
traceback.print_exc()
@ -259,10 +329,14 @@ class PushWorker(multiprocessing.Process):
@web.middleware
async def handle_response_headers(request: web.Request, handler: Coroutine) -> Response:
async def handle_response_headers(request: web.Request, handler: Callable) -> Response:
resp = await handler(request)
resp.headers['Server'] = 'ActivityRelay'
# Still have to figure out how csp headers work
if resp.content_type == 'text/html' and not request.path.startswith("/api"):
resp.headers['Content-Security-Policy'] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js')):
# cache for 2 weeks
resp.headers['Cache-Control'] = 'public,max-age=1209600,immutable'

View file

@ -13,15 +13,16 @@ from .database import get_database
from .misc import Message, boolean
if typing.TYPE_CHECKING:
from typing import Any
from blib import Database
from collections.abc import Callable, Iterator
from typing import Any
from .application import Application
# todo: implement more caching backends
BACKENDS: dict[str, Cache] = {}
BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[Callable, Callable]] = {
'str': (str, str),
'int': (str, int),
@ -71,7 +72,7 @@ class Item:
data.value = deserialize_value(data.value, data.value_type)
if not isinstance(data.updated, datetime):
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc)
data.updated = datetime.fromtimestamp(data.updated, tz = timezone.utc) # type: ignore
return data
@ -143,7 +144,7 @@ class Cache(ABC):
item.namespace,
item.key,
item.value,
item.type
item.value_type
)
@ -158,7 +159,7 @@ class SqlCache(Cache):
def __init__(self, app: Application):
Cache.__init__(self, app)
self._db = None
self._db: Database = None
def get(self, namespace: str, key: str) -> Item:
@ -257,7 +258,7 @@ class RedisCache(Cache):
def __init__(self, app: Application):
Cache.__init__(self, app)
self._rd = None
self._rd: Redis = None # type: ignore
@property
@ -275,7 +276,7 @@ class RedisCache(Cache):
if not (raw_value := self._rd.get(key_name)):
raise KeyError(f'{namespace}:{key}')
value_type, updated, value = raw_value.split(':', 2)
value_type, updated, value = raw_value.split(':', 2) # type: ignore
return Item.from_data(
namespace,
key,
@ -302,7 +303,7 @@ class RedisCache(Cache):
yield namespace
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> None:
def set(self, namespace: str, key: str, value: Any, value_type: str = 'key') -> Item:
date = datetime.now(tz = timezone.utc).timestamp()
value = serialize_value(value, value_type)
@ -311,6 +312,8 @@ class RedisCache(Cache):
f'{value_type}:{date}:{value}'
)
return self.get(namespace, key)
def delete(self, namespace: str, key: str) -> None:
self._rd.delete(self.get_key_name(namespace, key))
@ -350,7 +353,7 @@ class RedisCache(Cache):
options['host'] = self.app.config.rd_host
options['port'] = self.app.config.rd_port
self._rd = Redis(**options)
self._rd = Redis(**options) # type: ignore
def close(self) -> None:
@ -358,4 +361,4 @@ class RedisCache(Cache):
return
self._rd.close()
self._rd = None
self._rd = None # type: ignore

View file

@ -9,16 +9,12 @@ from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse
from . import logger as logging
from .misc import Message, boolean
from .misc import boolean
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any
# pylint: disable=duplicate-code
class RelayConfig(dict):
def __init__(self, path: str):
dict.__init__(self, {})
@ -46,7 +42,7 @@ class RelayConfig(dict):
@property
def db(self) -> RelayDatabase:
def db(self) -> Path:
return Path(self['db']).expanduser().resolve()
@ -184,121 +180,3 @@ class RelayDatabase(dict):
except json.decoder.JSONDecodeError as e:
if self.config.db.stat().st_size > 0:
raise e from None
def save(self) -> None:
with self.config.db.open('w', encoding = 'UTF-8') as fd:
json.dump(self, fd, indent=4)
def get_inbox(self, domain: str, fail: bool = False) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
if (inbox := self['relay-list'].get(domain)):
return inbox
if fail:
raise KeyError(domain)
return None
def add_inbox(self,
inbox: str,
followid: str | None = None,
software: str | None = None) -> dict[str, str]:
assert inbox.startswith('https'), 'Inbox must be a url'
domain = urlparse(inbox).hostname
if (instance := self.get_inbox(domain)):
if followid:
instance['followid'] = followid
if software:
instance['software'] = software
return instance
self['relay-list'][domain] = {
'domain': domain,
'inbox': inbox,
'followid': followid,
'software': software
}
logging.verbose('Added inbox to database: %s', inbox)
return self['relay-list'][domain]
def del_inbox(self,
domain: str,
followid: str = None,
fail: bool = False) -> bool:
if not (data := self.get_inbox(domain, fail=False)):
if fail:
raise KeyError(domain)
return False
if not data['followid'] or not followid or data['followid'] == followid:
del self['relay-list'][data['domain']]
logging.verbose('Removed inbox from database: %s', data['inbox'])
return True
if fail:
raise ValueError('Follow IDs do not match')
logging.debug('Follow ID does not match: db = %s, object = %s', data['followid'], followid)
return False
def get_request(self, domain: str, fail: bool = True) -> dict[str, str] | None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
try:
return self['follow-requests'][domain]
except KeyError as e:
if fail:
raise e
return None
def add_request(self, actor: str, inbox: str, followid: str) -> None:
domain = urlparse(inbox).hostname
try:
request = self.get_request(domain)
request['followid'] = followid
except KeyError:
pass
self['follow-requests'][domain] = {
'actor': actor,
'inbox': inbox,
'followid': followid
}
def del_request(self, domain: str) -> None:
if domain.startswith('http'):
domain = urlparse(domain).hostname
del self['follow-requests'][domain]
def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for domain, instance in self['relay-list'].items():
if domain not in src_domains:
yield instance['inbox']

View file

@ -6,6 +6,7 @@ import platform
import typing
import yaml
from dataclasses import asdict, dataclass, fields
from pathlib import Path
from platformdirs import user_config_dir
@ -14,6 +15,12 @@ from .misc import IS_DOCKER
if typing.TYPE_CHECKING:
from typing import Any
try:
from typing import Self
except ImportError:
from typing_extensions import Self
if platform.system() == 'Windows':
import multiprocessing
@ -23,61 +30,44 @@ else:
CORE_COUNT = len(os.sched_getaffinity(0))
DEFAULTS: dict[str, Any] = {
DOCKER_VALUES = {
'listen': '0.0.0.0',
'port': 8080,
'domain': 'relay.example.com',
'workers': CORE_COUNT,
'db_type': 'sqlite',
'ca_type': 'database',
'sq_path': 'relay.sqlite3',
'pg_host': '/var/run/postgresql',
'pg_port': 5432,
'pg_user': getpass.getuser(),
'pg_pass': None,
'pg_name': 'activityrelay',
'rd_host': 'localhost',
'rd_port': 6379,
'rd_user': None,
'rd_pass': None,
'rd_database': 0,
'rd_prefix': 'activityrelay'
'sq_path': '/data/relay.sqlite3'
}
if IS_DOCKER:
DEFAULTS['sq_path'] = '/data/relay.jsonld'
class NOVALUE:
pass
@dataclass(init = False)
class Config:
def __init__(self, path: str, load: bool = False):
if path:
self.path = Path(path).expanduser().resolve()
listen: str = '0.0.0.0'
port: int = 8080
domain: str = 'relay.example.com'
workers: int = CORE_COUNT
db_type: str = 'sqlite'
ca_type: str = 'database'
sq_path: str = 'relay.sqlite3'
else:
self.path = Config.get_config_dir()
pg_host: str = '/var/run/postgresql'
pg_port: int = 5432
pg_user: str = getpass.getuser()
pg_pass: str | None = None
pg_name: str = 'activityrelay'
self.listen = None
self.port = None
self.domain = None
self.workers = None
self.db_type = None
self.ca_type = None
self.sq_path = None
rd_host: str = 'localhost'
rd_port: int = 6470
rd_user: str | None = None
rd_pass: str | None = None
rd_database: int = 0
rd_prefix: str = 'activityrelay'
self.pg_host = None
self.pg_port = None
self.pg_user = None
self.pg_pass = None
self.pg_name = None
self.rd_host = None
self.rd_port = None
self.rd_user = None
self.rd_pass = None
self.rd_database = None
self.rd_prefix = None
def __init__(self, path: Path | None = None, load: bool = False):
self.path = Config.get_config_dir(path)
self.reset()
if load:
try:
@ -87,22 +77,39 @@ class Config:
self.save()
@staticmethod
def get_config_dir(path: str | None = None) -> Path:
if path:
return Path(path).expanduser().resolve()
@classmethod
def KEYS(cls: type[Self]) -> list[str]:
return list(cls.__dataclass_fields__)
dirs = (
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | None:
for field in fields(cls):
if field.name == key:
return field.default # type: ignore
raise KeyError(key)
@staticmethod
def get_config_dir(path: Path | str | None = None) -> Path:
if isinstance(path, str):
path = Path(path)
if path is not None:
return path.expanduser().resolve()
paths = (
Path("relay.yaml").resolve(),
Path(user_config_dir("activityrelay"), "relay.yaml"),
Path("/etc/activityrelay/relay.yaml")
)
for directory in dirs:
if directory.exists():
return directory
for cfgfile in paths:
if cfgfile.exists():
return cfgfile
return dirs[0]
return paths[0]
@property
@ -130,7 +137,6 @@ class Config:
def load(self) -> None:
self.reset()
options = {}
try:
@ -141,95 +147,85 @@ class Config:
with self.path.open('r', encoding = 'UTF-8') as fd:
config = yaml.load(fd, **options)
pgcfg = config.get('postgresql', {})
rdcfg = config.get('redis', {})
if not config:
raise ValueError('Config is empty')
if IS_DOCKER:
self.listen = '0.0.0.0'
self.port = 8080
self.sq_path = '/data/relay.jsonld'
pgcfg = config.get('postgresql', {})
rdcfg = config.get('redis', {})
else:
self.set('listen', config.get('listen', DEFAULTS['listen']))
self.set('port', config.get('port', DEFAULTS['port']))
self.set('sq_path', config.get('sqlite_path', DEFAULTS['sq_path']))
for key in type(self).KEYS():
if IS_DOCKER and key in {'listen', 'port', 'sq_path'}:
self.set(key, DOCKER_VALUES[key])
continue
self.set('workers', config.get('workers', DEFAULTS['workers']))
self.set('domain', config.get('domain', DEFAULTS['domain']))
self.set('db_type', config.get('database_type', DEFAULTS['db_type']))
self.set('ca_type', config.get('cache_type', DEFAULTS['ca_type']))
for key in DEFAULTS:
if key.startswith('pg'):
try:
self.set(key, pgcfg[key[3:]])
except KeyError:
self.set(key, pgcfg.get(key[3:], NOVALUE))
continue
elif key.startswith('rd'):
try:
self.set(key, rdcfg[key[3:]])
except KeyError:
self.set(key, rdcfg.get(key[3:], NOVALUE))
continue
cfgkey = key
if key == 'db_type':
cfgkey = 'database_type'
elif key == 'ca_type':
cfgkey = 'cache_type'
elif key == 'sq_path':
cfgkey = 'sqlite_path'
self.set(key, config.get(cfgkey, NOVALUE))
def reset(self) -> None:
for key, value in DEFAULTS.items():
setattr(self, key, value)
for field in fields(self):
setattr(self, field.name, field.default)
def save(self) -> None:
self.path.parent.mkdir(exist_ok = True, parents = True)
data: dict[str, Any] = {}
for key, value in asdict(self).items():
if key.startswith('pg_'):
if 'postgres' not in data:
data['postgres'] = {}
data['postgres'][key[3:]] = value
continue
if key.startswith('rd_'):
if 'redis' not in data:
data['redis'] = {}
data['redis'][key[3:]] = value
continue
if key == 'db_type':
key = 'database_type'
elif key == 'ca_type':
key = 'cache_type'
elif key == 'sq_path':
key = 'sqlite_path'
data[key] = value
with self.path.open('w', encoding = 'utf-8') as fd:
yaml.dump(self.to_dict(), fd, sort_keys = False)
yaml.dump(data, fd, sort_keys = False)
def set(self, key: str, value: Any) -> None:
if key not in DEFAULTS:
if key not in type(self).KEYS():
raise KeyError(key)
if key in {'port', 'pg_port', 'workers'} and not isinstance(value, int):
if (value := int(value)) < 1:
if key == 'port':
value = 8080
elif key == 'pg_port':
value = 5432
elif key == 'workers':
value = len(os.sched_getaffinity(0))
if value is NOVALUE:
return
setattr(self, key, value)
def to_dict(self) -> dict[str, Any]:
return {
'listen': self.listen,
'port': self.port,
'domain': self.domain,
'workers': self.workers,
'database_type': self.db_type,
'cache_type': self.ca_type,
'sqlite_path': self.sq_path,
'postgres': {
'host': self.pg_host,
'port': self.pg_port,
'user': self.pg_user,
'pass': self.pg_pass,
'name': self.pg_name
},
'redis': {
'host': self.rd_host,
'port': self.rd_port,
'user': self.rd_user,
'pass': self.rd_pass,
'database': self.rd_database,
'refix': self.rd_prefix
}
}

View file

@ -23,17 +23,26 @@ SELECT * FROM inboxes WHERE domain = :value or inbox = :value or actor = :value;
-- name: put-inbox
INSERT INTO inboxes (domain, actor, inbox, followid, software, created)
VALUES (:domain, :actor, :inbox, :followid, :software, :created)
ON CONFLICT (domain) DO UPDATE SET followid = :followid
INSERT INTO inboxes (domain, actor, inbox, followid, software, accepted, created)
VALUES (:domain, :actor, :inbox, :followid, :software, :accepted, :created)
ON CONFLICT (domain) DO
UPDATE SET followid = :followid, inbox = :inbox, software = :software, created = :created
RETURNING *;
-- name: put-inbox-accept
UPDATE inboxes SET accepted = :accepted WHERE domain = :domain RETURNING *;
-- name: del-inbox
DELETE FROM inboxes
WHERE domain = :value or inbox = :value or actor = :value;
-- name: get-request
SELECT * FROM inboxes WHERE accepted = 0 and domain = :domain;
-- name: get-user
SELECT * FROM users
WHERE username = :value or handle = :value;

View file

@ -13,6 +13,10 @@ schemes:
- https
securityDefinitions:
Cookie:
type: apiKey
in: cookie
name: user-token
Bearer:
type: apiKey
name: Authorization
@ -285,6 +289,50 @@ paths:
schema:
$ref: "#/definitions/Error"
/v1/request:
get:
tags:
- Follow Request
description: Get the list of follow requests
produces:
- application/json
responses:
"200":
description: List of instances
schema:
type: array
items:
$ref: "#/definitions/Instance"
post:
tags:
- Follow Request
description: Approve or deny a follow request
parameters:
- in: formData
name: domain
required: true
type: string
- in: formData
name: accept
required: true
type: boolean
consumes:
- application/json
- multipart/form-data
- application/x-www-form-urlencoded
produces:
- application/json
responses:
"200":
description: Follow request successfully accepted or denied
schema:
$ref: "#/definitions/Message"
"500":
description: Follow request does not exist
schema:
$ref: "#/definitions/Error"
/v1/domain_ban:
get:
tags:
@ -505,6 +553,104 @@ paths:
schema:
$ref: "#/definitions/Error"
/v1/user:
get:
tags:
- User
description: Get a list of all local users
produces:
- application/json
responses:
"200":
description: List of users
schema:
type: array
items:
$ref: "#/definitions/User"
post:
tags:
- User
description: Create a new user
parameters:
- in: formData
name: username
required: true
type: string
- in: formData
name: password
required: true
type: string
format: password
- in: formData
name: handle
required: false
type: string
format: email
produces:
- application/json
responses:
"200":
description: Newly created user
schema:
$ref: "#/definitions/User"
"404":
description: User already exists
schema:
$ref: "#/definitions/Error"
patch:
tags:
- User
description: Update a user's password or handle
parameters:
- in: formData
name: username
required: true
type: string
- in: formData
name: password
required: false
type: string
format: password
- in: formData
name: handle
required: false
type: string
format: email
produces:
- application/json
responses:
"200":
description: Updated user data
schema:
$ref: "#/definitions/User"
"404":
description: User does not exist
schema:
$ref: "#/definitions/Error"
delete:
tags:
- User
description: Delete a user
parameters:
- in: formData
name: username
required: true
type: string
produces:
- application/json
responses:
"202":
description: Successfully deleted user
schema:
$ref: "#/definitions/Message"
"404":
description: User not found
schema:
$ref: "#/definitions/Error"
/v1/whitelist:
get:
tags:
@ -672,6 +818,9 @@ definitions:
software:
description: Nodeinfo-formatted name of the instance's software
type: string
accepted:
description: Whether or not the follow request has been accepted
type: boolean
created:
description: Date the instance joined or was added
type: string
@ -701,6 +850,21 @@ definitions:
description: Character string used for authenticating with the api
type: string
User:
type: object
properties:
username:
description: Username of the account
type: string
handle:
description: Fediverse handle associated with the account
type: string
format: email
created:
description: Date the account was created
type: string
format: date-time
Whitelist:
type: object
properties:

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import bsql
import typing
from .config import CONFIG_DEFAULTS, THEMES, get_default_value
from .config import THEMES, ConfigData
from .connection import RELAY_SOFTWARE, Connection
from .schema import TABLES, VERSIONS, migrate_0
@ -11,7 +11,7 @@ from .. import logger as logging
from ..misc import get_resource
if typing.TYPE_CHECKING:
from .config import Config
from ..config import Config
def get_database(config: Config, migrate: bool = True) -> bsql.Database:
@ -46,13 +46,14 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database:
migrate_0(conn)
return db
if (schema_ver := conn.get_config('schema-version')) < get_default_value('schema-version'):
if (schema_ver := conn.get_config('schema-version')) < ConfigData.DEFAULT('schema-version'):
logging.info("Migrating database from version '%i'", schema_ver)
for ver, func in VERSIONS.items():
if schema_ver < ver:
func(conn)
conn.put_config('schema-version', ver)
logging.info("Updated database to %i", ver)
if (privkey := conn.get_config('private-key')):
conn.app.signer = privkey

View file

@ -1,15 +1,23 @@
from __future__ import annotations
import json
import typing
from dataclasses import Field, asdict, dataclass, fields
from .. import logger as logging
from ..misc import boolean
if typing.TYPE_CHECKING:
from collections.abc import Callable
from bsql import Row
from collections.abc import Callable, Sequence
from typing import Any
try:
from typing import Self
except ImportError:
from typing_extensions import Self
THEMES = {
'default': {
@ -59,39 +67,101 @@ THEMES = {
}
}
CONFIG_DEFAULTS: dict[str, tuple[str, Any]] = {
'schema-version': ('int', 20240206),
'private-key': ('str', None),
'log-level': ('loglevel', logging.LogLevel.INFO),
'name': ('str', 'ActivityRelay'),
'note': ('str', 'Make a note about your instance here.'),
'theme': ('str', 'default'),
'whitelist-enabled': ('bool', False)
}
# serializer | deserializer
CONFIG_CONVERT: dict[str, tuple[Callable, Callable]] = {
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'json': (json.dumps, json.loads),
'loglevel': (lambda x: x.name, logging.LogLevel.parse)
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse)
}
def get_default_value(key: str) -> Any:
return CONFIG_DEFAULTS[key][1]
@dataclass()
class ConfigData:
schema_version: int = 20240310
private_key: str = ''
approval_required: bool = False
log_level: logging.LogLevel = logging.LogLevel.INFO
name: str = 'ActivityRelay'
note: str = ''
theme: str = 'default'
whitelist_enabled: bool = False
def get_default_type(key: str) -> str:
return CONFIG_DEFAULTS[key][0]
def __getitem__(self, key: str) -> Any:
if (value := getattr(self, key.replace('-', '_'), None)) is None:
raise KeyError(key)
return value
def serialize(key: str, value: Any) -> str:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][0](value)
def __setitem__(self, key: str, value: Any) -> None:
self.set(key, value)
def deserialize(key: str, value: str) -> Any:
type_name = get_default_type(key)
return CONFIG_CONVERT[type_name][1](value)
@classmethod
def KEYS(cls: type[Self]) -> Sequence[str]:
return list(cls.__dataclass_fields__)
@staticmethod
def SYSTEM_KEYS() -> Sequence[str]:
return ('schema-version', 'schema_version', 'private-key', 'private_key')
@classmethod
def USER_KEYS(cls: type[Self]) -> Sequence[str]:
return tuple(key for key in cls.KEYS() if key not in cls.SYSTEM_KEYS())
@classmethod
def DEFAULT(cls: type[Self], key: str) -> str | int | bool:
return cls.FIELD(key.replace('-', '_')).default # type: ignore
@classmethod
def FIELD(cls: type[Self], key: str) -> Field:
for field in fields(cls):
if field.name == key.replace('-', '_'):
return field
raise KeyError(key)
@classmethod
def from_rows(cls: type[Self], rows: Sequence[Row]) -> Self:
data = cls()
set_schema_version = False
for row in rows:
data.set(row['key'], row['value'])
if row['key'] == 'schema-version':
set_schema_version = True
if not set_schema_version:
data.schema_version = 0
return data
def get(self, key: str, default: Any = None, serialize: bool = False) -> Any:
field = type(self).FIELD(key)
value = getattr(self, field.name, None)
if not serialize:
return value
converter = CONFIG_CONVERT[str(field.type)][0]
return converter(value)
def set(self, key: str, value: Any) -> None:
field = type(self).FIELD(key)
converter = CONFIG_CONVERT[str(field.type)][1]
setattr(self, field.name, converter(value))
def to_dict(self) -> dict[str, Any]:
return {key.replace('_', '-'): value for key, value in asdict(self).items()}

View file

@ -9,22 +9,18 @@ from urllib.parse import urlparse
from uuid import uuid4
from .config import (
CONFIG_DEFAULTS,
THEMES,
get_default_type,
get_default_value,
serialize,
deserialize
ConfigData
)
from .. import logger as logging
from ..misc import boolean, get_app
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Iterator, Sequence
from bsql import Row
from typing import Any
from .application import Application
from ..application import Application
from ..misc import Message
@ -46,54 +42,37 @@ class Connection(SqlConnection):
return get_app()
def distill_inboxes(self, message: Message) -> Iterator[str]:
def distill_inboxes(self, message: Message) -> Iterator[Row]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for inbox in self.execute('SELECT * FROM inboxes'):
if inbox['domain'] not in src_domains:
yield inbox['inbox']
for instance in self.get_inboxes():
if instance['domain'] not in src_domains:
yield instance
def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
key = key.replace('_', '-')
with self.run('get-config', {'key': key}) as cur:
if not (row := cur.one()):
return get_default_value(key)
return ConfigData.DEFAULT(key)
if row['value']:
return deserialize(row['key'], row['value'])
return None
data = ConfigData()
data.set(row['key'], row['value'])
return data.get(key)
def get_config_all(self) -> dict[str, Any]:
def get_config_all(self) -> ConfigData:
with self.run('get-config-all', None) as cur:
db_config = {row['key']: row['value'] for row in cur}
config = {}
for key, data in CONFIG_DEFAULTS.items():
try:
config[key] = deserialize(key, db_config[key])
except KeyError:
if key == 'schema-version':
config[key] = 0
else:
config[key] = data[1]
return config
return ConfigData.from_rows(tuple(cur.all()))
def put_config(self, key: str, value: Any) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
field = ConfigData.FIELD(key)
key = field.name.replace('_', '-')
if key == 'private-key':
self.app.signer = value
@ -102,73 +81,70 @@ class Connection(SqlConnection):
value = logging.LogLevel.parse(value)
logging.set_level(value)
elif key == 'whitelist-enabled':
elif key in {'approval-required', 'whitelist-enabled'}:
value = boolean(value)
elif key == 'theme':
if value not in THEMES:
raise ValueError(f'"{value}" is not a valid theme')
data = ConfigData()
data.set(key, value)
params = {
'key': key,
'value': serialize(key, value) if value is not None else None,
'type': get_default_type(key)
'value': data.get(key, serialize = True),
'type': 'LogLevel' if field.type == 'logging.LogLevel' else field.type
}
with self.run('put-config', params):
return value
pass
return data.get(key)
def get_inbox(self, value: str) -> Row:
with self.run('get-inbox', {'value': value}) as cur:
return cur.one()
return cur.one() # type: ignore
def get_inboxes(self) -> Sequence[Row]:
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur:
return tuple(cur.all())
def put_inbox(self,
domain: str,
inbox: str,
inbox: str | None = None,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
software: str | None = None,
accepted: bool = True) -> Row:
params = {
'domain': domain,
params: dict[str, Any] = {
'inbox': inbox,
'actor': actor,
'followid': followid,
'software': software,
'created': datetime.now(tz = timezone.utc)
'accepted': accepted
}
if not self.get_inbox(domain):
if not inbox:
raise ValueError("Missing inbox")
params['domain'] = domain
params['created'] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur:
return cur.one()
return cur.one() # type: ignore
for key, value in tuple(params.items()):
if value is None:
del params[key]
def update_inbox(self,
inbox: str,
actor: str | None = None,
followid: str | None = None,
software: str | None = None) -> Row:
if not (actor or followid or software):
raise ValueError('Missing "actor", "followid", and/or "software"')
data = {}
if actor:
data['actor'] = actor
if followid:
data['followid'] = followid
if software:
data['software'] = software
statement = Update('inboxes', data)
statement.set_where("inbox", inbox)
with self.query(statement):
return self.get_inbox(inbox)
with self.update('inboxes', params, domain = domain) as cur:
return cur.one() # type: ignore
def del_inbox(self, value: str) -> bool:
@ -179,17 +155,64 @@ class Connection(SqlConnection):
return cur.row_count == 1
def get_request(self, domain: str) -> Row:
with self.run('get-request', {'domain': domain}) as cur:
if not (row := cur.one()):
raise KeyError(domain)
return row
def get_requests(self) -> Sequence[Row]:
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
return tuple(cur.all())
def put_request_response(self, domain: str, accepted: bool) -> Row:
instance = self.get_request(domain)
if not accepted:
self.del_inbox(domain)
return instance
params = {
'domain': domain,
'accepted': accepted
}
with self.run('put-inbox-accept', params) as cur:
return cur.one() # type: ignore
def get_user(self, value: str) -> Row:
with self.run('get-user', {'value': value}) as cur:
return cur.one()
return cur.one() # type: ignore
def get_user_by_token(self, code: str) -> Row:
with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one()
return cur.one() # type: ignore
def put_user(self, username: str, password: str, handle: str | None = None) -> Row:
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row:
if self.get_user(username):
data: dict[str, str | datetime | None] = {}
if password:
data['hash'] = self.hasher.hash(password)
if handle:
data['handle'] = handle
stmt = Update("users", data)
stmt.set_where("username", username)
with self.query(stmt) as cur:
return cur.one() # type: ignore
if password is None:
raise ValueError('Password cannot be empty')
data = {
'username': username,
'hash': self.hasher.hash(password),
@ -198,7 +221,7 @@ class Connection(SqlConnection):
}
with self.run('put-user', data) as cur:
return cur.one()
return cur.one() # type: ignore
def del_user(self, username: str) -> None:
@ -213,7 +236,7 @@ class Connection(SqlConnection):
def get_token(self, code: str) -> Row:
with self.run('get-token', {'code': code}) as cur:
return cur.one()
return cur.one() # type: ignore
def put_token(self, username: str) -> Row:
@ -224,7 +247,7 @@ class Connection(SqlConnection):
}
with self.run('put-token', data) as cur:
return cur.one()
return cur.one() # type: ignore
def del_token(self, code: str) -> None:
@ -237,7 +260,7 @@ class Connection(SqlConnection):
domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur:
return cur.one()
return cur.one() # type: ignore
def put_domain_ban(self,
@ -253,7 +276,7 @@ class Connection(SqlConnection):
}
with self.run('put-domain-ban', params) as cur:
return cur.one()
return cur.one() # type: ignore
def update_domain_ban(self,
@ -292,7 +315,7 @@ class Connection(SqlConnection):
def get_software_ban(self, name: str) -> Row:
with self.run('get-software-ban', {'name': name}) as cur:
return cur.one()
return cur.one() # type: ignore
def put_software_ban(self,
@ -308,7 +331,7 @@ class Connection(SqlConnection):
}
with self.run('put-software-ban', params) as cur:
return cur.one()
return cur.one() # type: ignore
def update_software_ban(self,
@ -347,7 +370,7 @@ class Connection(SqlConnection):
def get_domain_whitelist(self, domain: str) -> Row:
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one()
return cur.one() # type: ignore
def put_domain_whitelist(self, domain: str) -> Row:
@ -357,7 +380,7 @@ class Connection(SqlConnection):
}
with self.run('put-domain-whitelist', params) as cur:
return cur.one()
return cur.one() # type: ignore
def del_domain_whitelist(self, domain: str) -> bool:

View file

@ -2,12 +2,13 @@ from __future__ import annotations
import typing
from bsql import Column, Connection, Table, Tables
from bsql import Column, Table, Tables
from .config import get_default_value
from .config import ConfigData
if typing.TYPE_CHECKING:
from collections.abc import Callable
from .connection import Connection
VERSIONS: dict[int, Callable] = {}
@ -25,6 +26,7 @@ TABLES: Tables = Tables(
Column('inbox', 'text', unique = True, nullable = False),
Column('followid', 'text'),
Column('software', 'text'),
Column('accepted', 'boolean'),
Column('created', 'timestamp', nullable = False)
),
Table(
@ -70,9 +72,15 @@ def migration(func: Callable) -> Callable:
def migrate_0(conn: Connection) -> None:
conn.create_tables()
conn.put_config('schema-version', get_default_value('schema-version'))
conn.put_config('schema-version', ConfigData.DEFAULT('schema-version'))
@migration
def migrate_20240206(conn: Connection) -> None:
conn.create_tables()
@migration
def migrate_20240310(conn: Connection) -> None:
conn.execute("ALTER TABLE inboxes ADD COLUMN accepted BOOLEAN")
conn.execute("UPDATE inboxes SET accepted = 1")

View file

@ -1,164 +0,0 @@
import click
import platform
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
from . import __version__
try:
from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler
except ImportError:
class PatternMatchingEventHandler:
pass
SCRIPT = Path(__file__).parent
REPO = SCRIPT.parent
IGNORE_EXT = {
'.py',
'.pyc'
}
@click.group('cli')
def cli():
'Useful commands for development'
@cli.command('install')
def cli_install():
cmd = [
sys.executable, '-m', 'pip', 'install',
'-r', 'requirements.txt',
'-r', 'dev-requirements.txt'
]
subprocess.run(cmd, check = False)
@cli.command('lint')
@click.argument('path', required = False, default = 'relay')
def cli_lint(path):
subprocess.run([sys.executable, '-m', 'flake8', path], check = False)
subprocess.run([sys.executable, '-m', 'pylint', path], check = False)
@cli.command('build')
def cli_build():
with TemporaryDirectory() as tmp:
arch = 'amd64' if sys.maxsize >= 2**32 else 'i386'
cmd = [
sys.executable, '-m', 'PyInstaller',
'--collect-data', 'relay',
'--collect-data', 'aiohttp_swagger',
'--hidden-import', 'pg8000',
'--hidden-import', 'sqlite3',
'--name', f'activityrelay-{__version__}-{platform.system().lower()}-{arch}',
'--workpath', tmp,
'--onefile', 'relay/__main__.py',
]
if platform.system() == 'Windows':
cmd.append('--console')
# putting the spec path on a different drive than the source dir breaks
if str(SCRIPT)[0] == tmp[0]:
cmd.extend(['--specpath', tmp])
else:
cmd.append('--strip')
cmd.extend(['--specpath', tmp])
subprocess.run(cmd, check = False)
@cli.command('run')
def cli_run():
print('Starting process watcher')
handler = WatchHandler()
handler.run_proc()
watcher = Observer()
watcher.schedule(handler, str(SCRIPT), recursive=True)
watcher.start()
try:
while True:
handler.proc.stdin.write(sys.stdin.read().encode('UTF-8'))
handler.proc.stdin.flush()
except KeyboardInterrupt:
pass
handler.kill_proc()
watcher.stop()
watcher.join()
class WatchHandler(PatternMatchingEventHandler):
patterns = ['*.py']
cmd = [sys.executable, '-m', 'relay', 'run', '-d']
def __init__(self):
PatternMatchingEventHandler.__init__(self)
self.proc = None
self.last_restart = None
def kill_proc(self):
if self.proc.poll() is not None:
return
print(f'Terminating process {self.proc.pid}')
self.proc.terminate()
sec = 0.0
while self.proc.poll() is None:
time.sleep(0.1)
sec += 0.1
if sec >= 5:
print('Failed to terminate. Killing process...')
self.proc.kill()
break
print('Process terminated')
def run_proc(self, restart=False):
timestamp = datetime.timestamp(datetime.now())
self.last_restart = timestamp if not self.last_restart else 0
if restart and self.proc.pid != '':
if timestamp - 3 < self.last_restart:
return
self.kill_proc()
# pylint: disable=consider-using-with
self.proc = subprocess.Popen(self.cmd, stdin = subprocess.PIPE)
self.last_restart = timestamp
print(f'Started process with PID {self.proc.pid}')
def on_any_event(self, event):
if event.event_type not in ['modified', 'created', 'deleted']:
return
self.run_proc(restart = True)
if __name__ == '__main__':
cli()

View file

@ -11,8 +11,10 @@
%title << {{config.name}}: {{page}}
%meta(charset="UTF-8")
%meta(name="viewport" content="width=device-width, initial-scale=1")
%link(rel="stylesheet" type="text/css" href="/theme/{{theme_name}}.css")
%link(rel="stylesheet" type="text/css" href="/style.css")
%link(rel="stylesheet" type="text/css" href="/theme/{{config.theme}}.css" nonce="{{view.request['hash']}}" class="theme")
%link(rel="stylesheet" type="text/css" href="/static/style.css" nonce="{{view.request['hash']}}")
%link(rel="manifest" href="/manifest.json")
%script(type="application/javascript" src="/static/api.js" nonce="{{view.request['hash']}}" defer)
-block head
%body
@ -35,22 +37,23 @@
-else
{{menu_item("Login", "/login")}}
%ul#notifications
#container
#header.section
%span#menu-open << &#8286;
%span.title-container
%a.title(href="/") -> =config.name
-if view.request.path not in ["/", "/login"]
.page -> =page
.empty
-if error
.error.section -> =error
%fieldset.error.section
%legend << Error
=error
-if message
.message.section -> =message
%fieldset.message.section
%legend << Message
=message
#content(class="page-{{page.lower().replace(' ', '_')}}")
-block content
@ -69,26 +72,3 @@
.version
%a(href="https://git.pleroma.social/pleroma/relay")
ActivityRelay/{{version}}
%script(type="application/javascript")
const body = document.getElementById("container")
const menu = document.getElementById("menu");
const menu_open = document.getElementById("menu-open");
const menu_close = document.getElementById("menu-close");
menu_open.addEventListener("click", (event) => {
var new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
menu.attributes.visible.nodeValue = new_value;
});
menu_close.addEventListener("click", (event) => {
menu.attributes.visible.nodeValue = "false"
});
body.addEventListener("click", (event) => {
if (event.target === menu_open) {
return;
}
menu.attributes.visible.nodeValue = "false";
});

View file

@ -0,0 +1,16 @@
-macro new_checkbox(name, checked)
-if checked
%input(id="{{name}}" type="checkbox" checked)
-else
%input(id="{{name}}" type="checkbox")
-macro new_select(name, selected, items)
%select(id="{{name}}")
-for item in items
-if item == selected
%option(value="{{item}}" selected) -> =item.title()
-else
%option(value="{{item}}") -> =item.title()

View file

@ -1,37 +1,29 @@
-extends "base.haml"
-set page="Config"
-block head
%script(type="application/javascript" src="/static/config.js" nonce="{{view.request['hash']}}" defer)
-import "functions.haml" as func
-block content
%form.section(action="/admin/config" method="POST")
%fieldset.section
%legend << Config
.grid-2col
%label(for="name") << Name
%input(id = "name" name="name" placeholder="Relay Name" value="{{config.name or ''}}")
%input(id = "name" placeholder="Relay Name" value="{{config.name or ''}}")
%label(for="description") << Description
%textarea(id="description" name="note" value="{{config.note}}") << {{config.note}}
%label(for="note") << Description
%textarea(id="note" value="{{config.note or ''}}") << {{config.note}}
%label(for="theme") << Color Theme
%select(id="theme" name="theme")
-for theme in themes
-if theme == config.theme
%option(value="{{theme}}" selected) -> =theme.title()
-else
%option(value="{{theme}}") -> =theme.title()
=func.new_select("theme", config.theme, themes)
%label(for="log-level") << Log Level
%select(id="log-level" name="log-level")
-for level in LogLevel
-if level == config["log-level"]
%option(value="{{level.name}}" selected) -> =level.name.title()
-else
%option(value="{{level.name}}") -> =level.name.title()
=func.new_select("log-level", config.log_level.name, levels)
%label(for="whitelist-enabled") << Whitelist
-if config["whitelist-enabled"]
%input(id="whitelist-enabled" name="whitelist-enabled" type="checkbox" checked)
=func.new_checkbox("whitelist-enabled", config.whitelist_enabled)
-else
%input(id="whitelist-enabled" name="whitelist-enabled" type="checkbox")
%input(type="submit" value="Save")
%label(for="approval-required") << Approval Required
=func.new_checkbox("approval-required", config.approval_required)

View file

@ -1,48 +1,53 @@
-extends "base.haml"
-set page="Domain Bans"
-block head
%script(type="application/javascript" src="/static/domain_ban.js" nonce="{{view.request['hash']}}" defer)
-block content
%details.section
%summary << Ban Domain
%form(action="/admin/domain_bans" method="POST")
#add-item
%label(for="domain") << Domain
%input(type="domain" id="domain" name="domain" placeholder="Domain")
%label(for="new-domain") << Domain
%input(type="domain" id="new-domain" placeholder="Domain")
%label(for="reason") << Ban Reason
%textarea(id="reason" name="reason") << {{""}}
%label(for="new-reason") << Ban Reason
%textarea(id="new-reason") << {{""}}
%label(for="note") << Admin Note
%textarea(id="note" name="note") << {{""}}
%label(for="new-note") << Admin Note
%textarea(id="new-note") << {{""}}
%input(type="submit" value="Ban Domain")
%input#new-ban(type="button" value="Ban Domain")
#data-table.section
%fieldset.section
%legend << Domain Bans
.data-table
%table
%thead
%tr
%td.domain << Instance
%td.domain << Domain
%td << Date
%td.remove
%tbody
-for ban in bans
%tr
%tr(id="{{ban.domain}}")
%td.domain
%details
%summary -> =ban.domain
%form(action="/admin/domain_bans" method="POST")
.grid-2col
.reason << Reason
%textarea.reason(id="reason" name="reason") << {{ban.reason or ""}}
%label.reason(for="{{ban.domain}}-reason") << Reason
%textarea.reason(id="{{ban.domain}}-reason") << {{ban.reason or ""}}
.note << Note
%textarea.note(id="note" name="note") << {{ban.note or ""}}
%label.note(for="{{ban.domain}}-note") << Note
%textarea.note(id="{{ban.domain}}-note") << {{ban.note or ""}}
%input(type="hidden" name="domain" value="{{ban.domain}}")
%input(type="submit" value="Update")
%input.update-ban(type="button" value="Update")
%td.date
=ban.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/domain_bans/delete/{{ban.domain}}" title="Unban domain") << &#10006;
%a(href="#" title="Unban domain") << &#10006;

View file

@ -1,26 +1,63 @@
-extends "base.haml"
-set page="Instances"
-block head
%script(type="application/javascript" src="/static/instance.js" nonce="{{view.request['hash']}}" defer)
-block content
%details.section
%summary << Add Instance
%form(action="/admin/instances" method="POST")
#add-item
%label(for="domain") << Domain
%input(type="domain" id="domain" name="domain" placeholder="Domain")
%label(for="new-actor") << Actor
%input(type="url" id="new-actor" placeholder="Actor URL")
%label(for="actor") << Actor URL
%input(type="url" id="actor" name="actor" placeholder="Actor URL")
%label(for="new-inbox") << Inbox
%input(type="url" id="new-inbox" placeholder="Inbox URL")
%label(for="inbox") << Inbox URL
%input(type="url" id="inbox" name="inbox" placeholder="Inbox URL")
%label(for="new-followid") << Follow ID
%input(type="url" id="new-followid" placeholder="Follow ID URL")
%label(for="software") << Software
%input(name="software" id="software" placeholder="software")
%label(for="new-software") << Software
%input(id="new-software" placeholder="software")
%input(type="submit" value="Add Instance")
%input#add-instance(type="button" value="Add Instance")
#data-table.section
%table
-if requests
%fieldset.section.requests
%legend << Follow Requests
.data-table
%table#requests
%thead
%tr
%td.instance << Instance
%td.software << Software
%td.date << Joined
%td.approve
%td.deny
%tbody
-for request in requests
%tr(id="{{request.domain}}")
%td.instance
%a(href="https://{{request.domain}}" target="_new") -> =request.domain
%td.software
=request.software or "n/a"
%td.date
=request.created.strftime("%Y-%m-%d")
%td.approve
%a(href="#" title="Approve Request") << &check;
%td.deny
%a(href="#" title="Deny Request") << &#10006;
%fieldset.section.instances
%legend << Instances
.data-table
%table#instances
%thead
%tr
%td.instance << Instance
@ -30,7 +67,7 @@
%tbody
-for instance in instances
%tr
%tr(id="{{instance.domain}}")
%td.instance
%a(href="https://{{instance.domain}}/" target="_new") -> =instance.domain
@ -41,4 +78,4 @@
=instance.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/instances/delete/{{instance.domain}}" title="Remove Instance") << &#10006;
%a(href="#" title="Remove Instance") << &#10006;

View file

@ -1,48 +1,53 @@
-extends "base.haml"
-set page="Software Bans"
-block head
%script(type="application/javascript" src="/static/software_ban.js" nonce="{{view.request['hash']}}" defer)
-block content
%details.section
%summary << Ban Software
%form(action="/admin/software_bans" method="POST")
#add-item
%label(for="name") << Name
%input(id="name" name="name" placeholder="Name")
%label(for="new-name") << Domain
%input(type="name" id="new-name" placeholder="Domain")
%label(for="reason") << Ban Reason
%textarea(id="reason" name="reason") << {{""}}
%label(for="new-reason") << Ban Reason
%textarea(id="new-reason") << {{""}}
%label(for="note") << Admin Note
%textarea(id="note" name="note") << {{""}}
%label(for="new-note") << Admin Note
%textarea(id="new-note") << {{""}}
%input(type="submit" value="Ban Software")
%input#new-ban(type="button" value="Ban Software")
#data-table.section
%table
%fieldset.section
%legend << Software Bans
.data-table
%table#bans
%thead
%tr
%td.name << Instance
%td.name << Name
%td << Date
%td.remove
%tbody
-for ban in bans
%tr
%tr(id="{{ban.name}}")
%td.name
%details
%summary -> =ban.name
%form(action="/admin/software_bans" method="POST")
.grid-2col
.reason << Reason
%textarea.reason(id="reason" name="reason") << {{ban.reason or ""}}
%label.reason(for="{{ban.name}}-reason") << Reason
%textarea.reason(id="{{ban.name}}-reason") << {{ban.reason or ""}}
.note << Note
%textarea.note(id="note" name="note") << {{ban.note or ""}}
%label.note(for="{{ban.name}}-note") << Note
%textarea.note(id="{{ban.name}}-note") << {{ban.note or ""}}
%input(type="hidden" name="name" value="{{ban.name}}")
%input(type="submit" value="Update")
%input.update-ban(type="button" value="Update")
%td.date
=ban.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/software_bans/delete/{{ban.name}}" title="Unban software") << &#10006;
%a(href="#" title="Unban name") << &#10006;

View file

@ -1,26 +1,32 @@
-extends "base.haml"
-set page="Users"
-block head
%script(type="application/javascript" src="/static/user.js" nonce="{{view.request['hash']}}" defer)
-block content
%details.section
%summary << Add User
%form(action="/admin/users", method="POST")
#add-item
%label(for="username") << Username
%input(id="username" name="username" placeholder="Username")
%label(for="new-username") << Username
%input(id="new-username" name="username" placeholder="Username" autocomplete="off")
%label(for="password") << Password
%input(type="password" id="password" name="password" placeholder="Password")
%label(for="new-password") << Password
%input(id="new-password" type="password" placeholder="Password" autocomplete="off")
%label(for="password2") << Password Again
%input(type="password" id="password2" name="password2" placeholder="Password Again")
%label(for="new-password2") << Password Again
%input(id="new-password2" type="password" placeholder="Password Again" autocomplete="off")
%label(for="handle") << Handle
%input(type="email" name="handle" id="handle" placeholder="handle")
%label(for="new-handle") << Handle
%input(id="new-handle" type="email" placeholder="handle" autocomplete="off")
%input(type="submit" value="Add User")
%input#new-user(type="button" value="Add User")
#data-table.section
%table
%fieldset.section
%legend << Users
.data-table
%table#users
%thead
%tr
%td.username << Username
@ -30,7 +36,7 @@
%tbody
-for user in users
%tr
%tr(id="{{user.username}}")
%td.username
=user.username
@ -41,4 +47,4 @@
=user.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/users/delete/{{user.username}}" title="Remove User") << &#10006;
%a(href="#" title="Remove User") << &#10006;

View file

@ -1,17 +1,22 @@
-extends "base.haml"
-set page="Whitelist"
-block head
%script(type="application/javascript" src="/static/whitelist.js" nonce="{{view.request['hash']}}" defer)
-block content
%details.section
%summary << Add Domain
%form(action="/admin/whitelist" method="POST")
#add-item
%label(for="domain") << Domain
%input(type="domain" id="domain" name="domain" placeholder="Domain")
%label(for="new-domain") << Domain
%input(type="domain" id="new-domain" placeholder="Domain")
%input(type="submit" value="Add Domain")
%input#new-item(type="button" value="Add Domain")
#data-table.section
%table
%fieldset.data-table.section
%legend << Whitelist
%table#whitelist
%thead
%tr
%td.domain << Domain
@ -20,7 +25,7 @@
%tbody
-for item in whitelist
%tr
%tr(id="{{item.domain}}")
%td.domain
=item.domain
@ -28,4 +33,4 @@
=item.created.strftime("%Y-%m-%d")
%td.remove
%a(href="/admin/whitelist/delete/{{item.domain}}" title="Remove whitlisted domain") << &#10006;
%a(href="#" title="Remove whitlisted domain") << &#10006;

View file

@ -1,10 +1,9 @@
-extends "base.haml"
-set page = "Home"
-block content
-if config.note
.section
-for line in config.note.splitlines()
-if line
%p -> =line
-markdown -> =config.note
.section
%p
@ -14,12 +13,24 @@
You may subscribe to this relay with the address:
%a(href="https://{{domain}}/actor") << https://{{domain}}/actor</a>
-if config["whitelist-enabled"]
%p.section.message
Note: The whitelist is enabled on this instance. Ask the admin to add your instance
before joining.
-if config.approval_required
%fieldset.section.message
%legend << Require Approval
#data-table.section
Follow requests require approval. You will need to wait for an admin to accept or deny
your request.
-elif config.whitelist_enabled
%fieldset.section.message
%legend << Whitelist Enabled
The whitelist is enabled on this instance. Ask the admin to add your instance before
joining.
%fieldset.section
%legend << Instances
.data-table
%table
%thead
%tr

View file

@ -1,7 +1,13 @@
-extends "base.haml"
-set page="Login"
-block head
%script(type="application/javascript" src="/static/login.js" nonce="{{view.request['hash']}}" defer)
-block content
%form.section(action="/login" method="POST")
%fieldset.section
%legend << Login
.grid-2col
%label(for="username") << Username
%input(id="username" name="username" placeholder="Username" value="{{username or ''}}")
@ -9,4 +15,4 @@
%label(for="password") << Password
%input(id="password" name="password" placeholder="Password" type="password")
%input(type="submit" value="Login")
%input.submit(type="button" value="Login")

View file

@ -0,0 +1,135 @@
// toast notifications
const notifications = document.querySelector("#notifications")
function remove_toast(toast) {
toast.classList.add("hide");
if (toast.timeoutId) {
clearTimeout(toast.timeoutId);
}
setTimeout(() => toast.remove(), 300);
}
function toast(text, type="error", timeout=5) {
const toast = document.createElement("li");
toast.className = `section ${type}`
toast.innerHTML = `<span class=".text">${text}</span><a href="#">&#10006;</span>`
toast.querySelector("a").addEventListener("click", async (event) => {
event.preventDefault();
await remove_toast(toast);
});
notifications.appendChild(toast);
toast.timeoutId = setTimeout(() => remove_toast(toast), timeout * 1000);
}
// menu
const body = document.getElementById("container")
const menu = document.getElementById("menu");
const menu_open = document.getElementById("menu-open");
const menu_close = document.getElementById("menu-close");
menu_open.addEventListener("click", (event) => {
var new_value = menu.attributes.visible.nodeValue === "true" ? "false" : "true";
menu.attributes.visible.nodeValue = new_value;
});
menu_close.addEventListener("click", (event) => {
menu.attributes.visible.nodeValue = "false"
});
body.addEventListener("click", (event) => {
if (event.target === menu_open) {
return;
}
menu.attributes.visible.nodeValue = "false";
});
// misc
function get_date_string(date) {
var year = date.getFullYear().toString();
var month = date.getMonth().toString();
var day = date.getDay().toString();
if (month.length === 1) {
month = "0" + month;
}
if (day.length === 1) {
day = "0" + day
}
return `${year}-${month}-${day}`;
}
function append_table_row(table, row_name, row) {
var table_row = table.insertRow(-1);
table_row.id = row_name;
index = 0;
for (var prop in row) {
if (Object.prototype.hasOwnProperty.call(row, prop)) {
var cell = table_row.insertCell(index);
cell.className = prop;
cell.innerHTML = row[prop];
index += 1;
}
}
return table_row;
}
async function request(method, path, body = null) {
var headers = {
"Accept": "application/json"
}
if (body !== null) {
headers["Content-Type"] = "application/json"
body = JSON.stringify(body)
}
const response = await fetch("/api/" + path, {
method: method,
mode: "cors",
cache: "no-store",
redirect: "follow",
body: body,
headers: headers
});
const message = await response.json();
if (Object.hasOwn(message, "error")) {
throw new Error(message.error);
}
if (Array.isArray(message)) {
message.forEach((msg) => {
if (Object.hasOwn(msg, "created")) {
msg.created = new Date(msg.created);
}
});
} else {
if (Object.hasOwn(message, "created")) {
message.created = new Date(message.created);
}
}
return message;
}

View file

@ -0,0 +1,40 @@
const elems = [
document.querySelector("#name"),
document.querySelector("#note"),
document.querySelector("#theme"),
document.querySelector("#log-level"),
document.querySelector("#whitelist-enabled"),
document.querySelector("#approval-required")
]
async function handle_config_change(event) {
params = {
key: event.target.id,
value: event.target.type === "checkbox" ? event.target.checked : event.target.value
}
try {
await request("POST", "v1/config", params);
} catch (error) {
toast(error);
return;
}
if (params.key === "name") {
document.querySelector("#header .title").innerHTML = params.value;
document.querySelector("title").innerHTML = params.value;
}
if (params.key === "theme") {
document.querySelector("link.theme").href = `/theme/${params.value}.css`;
}
toast("Updated config", "message");
}
for (const elem of elems) {
elem.addEventListener("change", handle_config_change);
}

View file

@ -0,0 +1,123 @@
function create_ban_object(domain, reason, note) {
var text = '<details>\n';
text += `<summary>${domain}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${domain}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${domain}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${domain}-note" class="note">Note</label>\n`;
text += `<textarea id="${domain}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var table = document.querySelector("table");
var elems = {
domain: document.getElementById("new-domain"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
domain: elems.domain.value.trim(),
reason: elems.reason.value.trim(),
note: elems.note.value.trim()
}
if (values.domain === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/domain_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("table"), ban.domain, {
domain: create_ban_object(ban.domain, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban domain">&#10006;</a>`
});
add_row_listeners(row);
elems.domain.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned domain", "message");
}
async function update_ban(domain) {
var row = document.getElementById(domain);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"domain": domain,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/domain_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated baned domain", "message");
}
async function unban(domain) {
try {
await request("DELETE", "v1/domain_ban", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Unbanned domain", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}

View file

@ -0,0 +1,145 @@
function add_instance_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_instance(row.id);
});
}
function add_request_listeners(row) {
row.querySelector(".approve a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, true);
});
row.querySelector(".deny a").addEventListener("click", async (event) => {
event.preventDefault();
await req_response(row.id, false);
});
}
async function add_instance() {
var elems = {
actor: document.getElementById("new-actor"),
inbox: document.getElementById("new-inbox"),
followid: document.getElementById("new-followid"),
software: document.getElementById("new-software")
}
var values = {
actor: elems.actor.value.trim(),
inbox: elems.inbox.value.trim(),
followid: elems.followid.value.trim(),
software: elems.software.value.trim()
}
if (values.actor === "") {
toast("Actor is required");
return;
}
try {
var instance = await request("POST", "v1/instance", values);
} catch (err) {
toast(err);
return
}
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
elems.actor.value = null;
elems.inbox.value = null;
elems.followid.value = null;
elems.software.value = null;
document.querySelector("details.section").open = false;
toast("Added instance", "message");
}
async function del_instance(domain) {
try {
await request("DELETE", "v1/instance", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
}
async function req_response(domain, accept) {
params = {
"domain": domain,
"accept": accept
}
try {
await request("POST", "v1/request", params);
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
if (document.getElementById("requests").rows.length < 2) {
document.querySelector("fieldset.requests").remove()
}
if (!accept) {
toast("Denied instance request", "message");
return;
}
instances = await request("GET", `v1/instance`, null);
instances.forEach((instance) => {
if (instance.domain === domain) {
row = append_table_row(document.getElementById("instances"), instance.domain, {
domain: `<a href="https://${instance.domain}/" target="_new">${instance.domain}</a>`,
software: instance.software,
date: get_date_string(instance.created),
remove: `<a href="#" title="Remove Instance">&#10006;</a>`
});
add_instance_listeners(row);
}
});
toast("Accepted instance request", "message");
}
document.querySelector("#add-instance").addEventListener("click", async (event) => {
await add_instance();
})
for (var row of document.querySelector("#instances").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_instance_listeners(row);
}
if (document.querySelector("#requests")) {
for (var row of document.querySelector("#requests").rows) {
if (!row.querySelector(".approve a")) {
continue;
}
add_request_listeners(row);
}
}

View file

@ -0,0 +1,29 @@
async function login(event) {
fields = {
username: document.querySelector("#username"),
password: document.querySelector("#password")
}
values = {
username: fields.username.value.trim(),
password: fields.password.value.trim()
}
if (values.username === "" | values.password === "") {
toast("Username and/or password field is blank");
return;
}
try {
await request("POST", "v1/token", values);
} catch (error) {
toast(error);
return;
}
document.location = "/";
}
document.querySelector(".submit").addEventListener("click", login);

View file

@ -0,0 +1,122 @@
function create_ban_object(name, reason, note) {
var text = '<details>\n';
text += `<summary>${name}</summary>\n`;
text += '<div class="grid-2col">\n';
text += `<label for="${name}-reason" class="reason">Reason</label>\n`;
text += `<textarea id="${name}-reason" class="reason">${reason}</textarea>\n`;
text += `<label for="${name}-note" class="note">Note</label>\n`;
text += `<textarea id="${name}-note" class="note">${note}</textarea>\n`;
text += `<input class="update-ban" type="button" value="Update">`;
text += '</details>';
return text;
}
function add_row_listeners(row) {
row.querySelector(".update-ban").addEventListener("click", async (event) => {
await update_ban(row.id);
});
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await unban(row.id);
});
}
async function ban() {
var elems = {
name: document.getElementById("new-name"),
reason: document.getElementById("new-reason"),
note: document.getElementById("new-note")
}
var values = {
name: elems.name.value.trim(),
reason: elems.reason.value,
note: elems.note.value
}
if (values.name === "") {
toast("Domain is required");
return;
}
try {
var ban = await request("POST", "v1/software_ban", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.getElementById("bans"), ban.name, {
name: create_ban_object(ban.name, ban.reason, ban.note),
date: get_date_string(ban.created),
remove: `<a href="#" title="Unban software">&#10006;</a>`
});
add_row_listeners(row);
elems.name.value = null;
elems.reason.value = null;
elems.note.value = null;
document.querySelector("details.section").open = false;
toast("Banned software", "message");
}
async function update_ban(name) {
var row = document.getElementById(name);
var elems = {
"reason": row.querySelector("textarea.reason"),
"note": row.querySelector("textarea.note")
}
var values = {
"name": name,
"reason": elems.reason.value,
"note": elems.note.value
}
try {
await request("PATCH", "v1/software_ban", values)
} catch (error) {
toast(error);
return;
}
row.querySelector("details").open = false;
toast("Updated software ban", "message");
}
async function unban(name) {
try {
await request("DELETE", "v1/software_ban", {"name": name});
} catch (error) {
toast(error);
return;
}
document.getElementById(name).remove();
toast("Unbanned software", "message");
}
document.querySelector("#new-ban").addEventListener("click", async (event) => {
await ban();
});
for (var row of document.querySelector("#bans").rows) {
if (!row.querySelector(".update-ban")) {
continue;
}
add_row_listeners(row);
}

View file

@ -23,11 +23,29 @@ details summary {
cursor: pointer;
}
fieldset {
margin-left: 0px;
margin-right: 0px;
}
fieldset > *:nth-child(2) {
margin-top: 0px !important;
}
form input[type="submit"] {
display: block;
margin: 0 auto;
}
legend {
background-color: var(--table-background);
padding: 5px;
border: 1px solid var(--border);
border-radius: 5px;
font-size: 10pt;
font-weight: bold;
}
p {
line-height: 1em;
margin: 0px;
@ -91,6 +109,17 @@ textarea {
margin: 0px auto;
}
#content .title {
font-size: 24px;
text-align: center;
font-weight: bold;
margin-bottom: 10px;
}
#content .title:not(:first-child) {
margin-top: 10px;
}
#header {
display: grid;
grid-template-columns: 50px auto 50px;
@ -175,6 +204,37 @@ textarea {
text-align: center;
}
#notifications {
position: fixed;
top: 40px;
left: 50%;
transform: translateX(-50%);
}
#notifications li {
position: relative;
overflow: hidden;
list-style: none;
border-radius: 5px;
padding: 5px;;
margin-bottom: var(--spacing);
animation: show_toast 0.3s ease forwards;
display: grid;
grid-template-columns: auto max-content;
grid-gap: 5px;
align-items: center;
}
#notifications a {
font-size: 1.5em;
line-height: 1em;
text-decoration: none;
}
#notifications li.hide {
animation: hide_toast 0.3s ease forwards;
}
#footer {
display: grid;
grid-template-columns: auto auto;
@ -193,15 +253,6 @@ textarea {
align-items: center;
}
#data-table td:first-child {
width: 100%;
}
#data-table .date {
width: max-content;
text-align: right;
}
.button {
background-color: var(--primary);
border: 1px solid var(--primary);
@ -220,6 +271,15 @@ textarea {
grid-template-columns: max-content auto;
}
.data-table td:first-child {
width: 100%;
}
.data-table .date {
width: max-content;
text-align: right;
}
.error, .message {
text-align: center;
}
@ -267,6 +327,44 @@ textarea {
}
@keyframes show_toast {
0% {
transform: translateX(100%);
}
40% {
transform: translateX(-5%);
}
80% {
transform: translateX(0%);
}
100% {
transform: translateX(-10px);
}
}
@keyframes hide_toast {
0% {
transform: translateX(-10px);
}
40% {
transform: translateX(0%);
}
80% {
transform: translateX(-5%);
}
100% {
transform: translateX(calc(100% + 20px));
}
}
@media (max-width: 1026px) {
body {
margin: 0px;

View file

@ -0,0 +1,85 @@
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_user(row.id);
});
}
async function add_user() {
var elems = {
username: document.getElementById("new-username"),
password: document.getElementById("new-password"),
password2: document.getElementById("new-password2"),
handle: document.getElementById("new-handle")
}
var values = {
username: elems.username.value.trim(),
password: elems.password.value.trim(),
password2: elems.password2.value.trim(),
handle: elems.handle.value.trim()
}
if (values.username === "" | values.password === "" | values.password2 === "") {
toast("Username, password, and password2 are required");
return;
}
if (values.password !== values.password2) {
toast("Passwords do not match");
return;
}
try {
var user = await request("POST", "v1/user", values);
} catch (err) {
toast(err);
return
}
var row = append_table_row(document.querySelector("fieldset.section table"), user.username, {
domain: user.username,
handle: user.handle ? self.handle : "n/a",
date: get_date_string(user.created),
remove: `<a href="#" title="Delete User">&#10006;</a>`
});
add_row_listeners(row);
elems.username.value = null;
elems.password.value = null;
elems.password2.value = null;
elems.handle.value = null;
document.querySelector("details.section").open = false;
toast("Created user", "message");
}
async function del_user(username) {
try {
await request("DELETE", "v1/user", {"username": username});
} catch (error) {
toast(error);
return;
}
document.getElementById(username).remove();
toast("Deleted user", "message");
}
document.querySelector("#new-user").addEventListener("click", async (event) => {
await add_user();
});
for (var row of document.querySelector("#users").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}

View file

@ -0,0 +1,64 @@
function add_row_listeners(row) {
row.querySelector(".remove a").addEventListener("click", async (event) => {
event.preventDefault();
await del_whitelist(row.id);
});
}
async function add_whitelist() {
var domain_elem = document.getElementById("new-domain");
var domain = domain_elem.value.trim();
if (domain === "") {
toast("Domain is required");
return;
}
try {
var item = await request("POST", "v1/whitelist", {"domain": domain});
} catch (err) {
toast(err);
return;
}
var row = append_table_row(document.getElementById("whitelist"), item.domain, {
domain: item.domain,
date: get_date_string(item.created),
remove: `<a href="#" title="Remove whitelisted domain">&#10006;</a>`
});
add_row_listeners(row);
domain_elem.value = null;
document.querySelector("details.section").open = false;
toast("Added domain", "message");
}
async function del_whitelist(domain) {
try {
await request("DELETE", "v1/whitelist", {"domain": domain});
} catch (error) {
toast(error);
return;
}
document.getElementById(domain).remove();
toast("Removed domain", "message");
}
document.querySelector("#new-item").addEventListener("click", async (event) => {
await add_whitelist();
});
for (var row of document.querySelector("fieldset.section table").rows) {
if (!row.querySelector(".remove a")) {
continue;
}
add_row_listeners(row);
}

View file

@ -7,7 +7,7 @@ import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils.objects import Nodeinfo, WellKnownNodeinfo
from aputils import AlgorithmType, JsonBase, Nodeinfo, ObjectType, WellKnownNodeinfo
from json.decoder import JSONDecodeError
from urllib.parse import urlparse
@ -17,12 +17,13 @@ from .misc import MIMETYPES, Message, get_app
if typing.TYPE_CHECKING:
from aputils import Signer
from tinysql import Row
from bsql import Row
from typing import Any
from .application import Application
from .cache import Cache
T = typing.TypeVar('T', bound = JsonBase)
HEADERS = {
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
'User-Agent': f'ActivityRelay/{__version__}'
@ -33,12 +34,12 @@ class HttpClient:
def __init__(self, limit: int = 100, timeout: int = 10):
self.limit = limit
self.timeout = timeout
self._conn = None
self._session = None
self._conn: TCPConnector | None = None
self._session: ClientSession | None = None
async def __aenter__(self) -> HttpClient:
await self.open()
self.open()
return self
@ -61,7 +62,7 @@ class HttpClient:
return self.app.signer
async def open(self) -> None:
def open(self) -> None:
if self._session:
return
@ -79,23 +80,19 @@ class HttpClient:
async def close(self) -> None:
if not self._session:
return
if self._session:
await self._session.close()
if self._conn:
await self._conn.close()
self._conn = None
self._session = None
async def get(self, # pylint: disable=too-many-branches
url: str,
sign_headers: bool = False,
loads: callable = json.loads,
force: bool = False) -> dict | None:
await self.open()
async def _get(self, url: str, sign_headers: bool, force: bool) -> dict[str, Any] | None:
if not self._session:
raise RuntimeError('Client not open')
try:
url, _ = url.split('#', 1)
@ -105,10 +102,8 @@ class HttpClient:
if not force:
try:
item = self.cache.get('request', url)
if not item.older_than(48):
return loads(item.value)
if not (item := self.cache.get('request', url)).older_than(48):
return json.loads(item.value)
except KeyError:
logging.verbose('No cached data for url: %s', url)
@ -116,38 +111,39 @@ class HttpClient:
headers = {}
if sign_headers:
self.signer.sign_headers('GET', url, algorithm = 'original')
headers = self.signer.sign_headers('GET', url, algorithm = AlgorithmType.HS2019)
try:
logging.debug('Fetching resource: %s', url)
async with self._session.get(url, headers = headers) as resp:
## Not expecting a response with 202s, so just return
# Not expecting a response with 202s, so just return
if resp.status == 202:
return None
data = await resp.read()
data = await resp.text()
if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(await resp.read())
logging.debug(data)
return None
message = loads(data)
self.cache.set('request', url, data.decode('utf-8'), 'str')
logging.debug('%s >> resp %s', url, json.dumps(message, indent = 4))
self.cache.set('request', url, data, 'str')
logging.debug('%s >> resp %s', url, json.dumps(json.loads(data), indent = 4))
return message
return json.loads(data)
except JSONDecodeError:
logging.verbose('Failed to parse JSON')
return None
except ClientSSLError:
except ClientSSLError as e:
logging.verbose('SSL error when connecting to %s', urlparse(url).netloc)
logging.warning(str(e))
except (AsyncTimeoutError, ClientConnectionError):
except (AsyncTimeoutError, ClientConnectionError) as e:
logging.verbose('Failed to connect to %s', urlparse(url).netloc)
logging.warning(str(e))
except Exception:
traceback.print_exc()
@ -155,39 +151,74 @@ class HttpClient:
return None
async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
await self.open()
async def get(self,
url: str,
sign_headers: bool,
cls: type[T],
force: bool = False) -> T | None:
## Using the old algo by default is probably a better idea right now
# pylint: disable=consider-ternary-expression
if not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "aputils.JsonBase"')
if (data := (await self._get(url, sign_headers, force))) is None:
return None
return cls.parse(data)
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
if not self._session:
raise RuntimeError('Client not open')
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance and instance['software'] in {'mastodon'}:
algorithm = 'hs2019'
algorithm = AlgorithmType.HS2019
else:
algorithm = 'original'
# pylint: enable=consider-ternary-expression
algorithm = AlgorithmType.RSASHA256
headers = {'Content-Type': 'application/activity+json'}
headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
body: bytes
message: Message
if isinstance(data, bytes):
body = data
message = Message.parse(data)
else:
body = data.to_json().encode("utf-8")
message = data
mtype = message.type.value if isinstance(message.type, ObjectType) else message.type
headers = self.signer.sign_headers(
'POST',
url,
body,
headers = {'Content-Type': 'application/activity+json'},
algorithm = algorithm
)
try:
logging.verbose('Sending "%s" to %s', message.type, url)
logging.verbose('Sending "%s" to %s', mtype, url)
async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
async with self._session.post(url, headers = headers, data = body) as resp:
# Not expecting a response, so just return
if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', message.type, url)
logging.verbose('Successfully sent "%s" to %s', mtype, url)
return
logging.verbose('Received error when pushing to %s: %i', url, resp.status)
logging.debug(await resp.read())
logging.debug("message: %s", body.decode("utf-8"))
logging.debug("headers: %s", json.dumps(headers, indent = 4))
return
except ClientSSLError:
except ClientSSLError as e:
logging.warning('SSL error when pushing to %s', urlparse(url).netloc)
logging.warning(str(e))
except (AsyncTimeoutError, ClientConnectionError):
except (AsyncTimeoutError, ClientConnectionError) as e:
logging.warning('Failed to connect to %s for message push', urlparse(url).netloc)
logging.warning(str(e))
# prevent workers from being brought down
except Exception:
@ -198,10 +229,11 @@ class HttpClient:
nodeinfo_url = None
wk_nodeinfo = await self.get(
f'https://{domain}/.well-known/nodeinfo',
loads = WellKnownNodeinfo.parse
False,
WellKnownNodeinfo
)
if not wk_nodeinfo:
if wk_nodeinfo is None:
logging.verbose('Failed to fetch well-known nodeinfo url for %s', domain)
return None
@ -212,14 +244,14 @@ class HttpClient:
except KeyError:
pass
if not nodeinfo_url:
if nodeinfo_url is None:
logging.verbose('Failed to fetch nodeinfo url for %s', domain)
return None
return await self.get(nodeinfo_url, loads = Nodeinfo.parse) or None
return await self.get(nodeinfo_url, False, Nodeinfo)
async def get(*args: Any, **kwargs: Any) -> Message | dict | None:
async def get(*args: Any, **kwargs: Any) -> Any:
async with HttpClient() as client:
return await client.get(*args, **kwargs)

View file

@ -11,6 +11,12 @@ if typing.TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
try:
from typing import Self
except ImportError:
from typing_extensions import Self
class LogLevel(IntEnum):
DEBUG = logging.DEBUG
@ -26,7 +32,13 @@ class LogLevel(IntEnum):
@classmethod
def parse(cls: type[IntEnum], data: object) -> IntEnum:
def parse(cls: type[Self], data: Any) -> Self:
try:
data = int(data)
except ValueError:
pass
if isinstance(data, cls):
return data
@ -57,10 +69,10 @@ def set_level(level: LogLevel | str) -> None:
def verbose(message: str, *args: Any, **kwargs: Any) -> None:
if not logging.root.isEnabledFor(LogLevel['VERBOSE']):
if not logging.root.isEnabledFor(LogLevel.VERBOSE):
return
logging.log(LogLevel['VERBOSE'], message, *args, **kwargs)
logging.log(LogLevel.VERBOSE, message, *args, **kwargs)
debug: Callable = logging.debug
@ -70,23 +82,27 @@ error: Callable = logging.error
critical: Callable = logging.critical
env_log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
try:
env_log_file = Path(os.environ['LOG_FILE']).expanduser().resolve()
env_log_file: Path | None = Path(os.environ['LOG_FILE']).expanduser().resolve()
except KeyError:
env_log_file = None
handlers = [logging.StreamHandler()]
handlers: list[Any] = [logging.StreamHandler()]
if env_log_file:
handlers.append(logging.FileHandler(env_log_file))
logging.addLevelName(LogLevel['VERBOSE'], 'VERBOSE')
if os.environ.get('IS_SYSTEMD'):
logging_format = '%(levelname)s: %(message)s'
else:
logging_format = '[%(asctime)s] %(levelname)s: %(message)s'
logging.addLevelName(LogLevel.VERBOSE, 'VERBOSE')
logging.basicConfig(
level = LogLevel.INFO,
format = '[%(asctime)s] %(levelname)s: %(message)s',
format = logging_format,
datefmt = '%Y-%m-%d %H:%M:%S',
handlers = handlers
)

View file

@ -1,11 +1,10 @@
from __future__ import annotations
import Crypto
import aputils
import asyncio
import click
import json
import os
import platform
import typing
from pathlib import Path
@ -21,19 +20,10 @@ from .database import RELAY_SOFTWARE, get_database
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
if typing.TYPE_CHECKING:
from tinysql import Row
from bsql import Row
from typing import Any
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation
CONFIG_IGNORE = (
'schema-version',
'private-key'
)
def check_alphanumeric(text: str) -> str:
if not text.isalnum():
raise click.BadParameter('String not alphanumeric')
@ -41,31 +31,44 @@ def check_alphanumeric(text: str) -> str:
return text
@click.group('cli', context_settings={'show_default': True}, invoke_without_command=True)
@click.option('--config', '-c', help='path to the relay\'s config')
@click.group('cli', context_settings = {'show_default': True})
@click.option('--config', '-c', type = Path, help = 'path to the relay\'s config')
@click.version_option(version = __version__, prog_name = 'ActivityRelay')
@click.pass_context
def cli(ctx: click.Context, config: str | None) -> None:
def cli(ctx: click.Context, config: Path | None) -> None:
if IS_DOCKER:
config = Path("/data/relay.yaml")
# The database was named "relay.jsonld" even though it's an sqlite file. Fix it.
db = Path('/data/relay.sqlite3')
wrongdb = Path('/data/relay.jsonld')
if wrongdb.exists() and not db.exists():
try:
with wrongdb.open('rb') as fd:
json.load(fd)
except json.JSONDecodeError:
wrongdb.rename(db)
ctx.obj = Application(config)
if not ctx.invoked_subcommand:
if ctx.obj.config.domain.endswith('example.com'):
cli_setup.callback()
else:
click.echo(
'[DEPRECATED] Running the relay without the "run" command will be removed in the ' +
'future.'
)
cli_run.callback()
@cli.command('setup')
@click.option('--skip-questions', '-s', is_flag = True, help = 'Just setup the database')
@click.pass_context
def cli_setup(ctx: click.Context) -> None:
def cli_setup(ctx: click.Context, skip_questions: bool) -> None:
'Generate a new config and create the database'
if ctx.obj.signer is not None:
if not click.prompt('The database is already setup. Are you sure you want to continue?'):
return
if skip_questions and ctx.obj.config.domain.endswith('example.com'):
click.echo('You cannot skip the questions if the relay is not configured yet')
return
if not skip_questions:
while True:
ctx.obj.config.domain = click.prompt(
'What domain will the relay be hosted on?',
@ -95,7 +98,7 @@ def cli_setup(ctx: click.Context) -> None:
type = click.Choice(['postgres', 'sqlite'], case_sensitive = False)
)
if ctx.obj.config.db_type == 'sqlite':
if ctx.obj.config.db_type == 'sqlite' and not IS_DOCKER:
ctx.obj.config.sq_path = click.prompt(
'Where should the database be stored?',
default = ctx.obj.config.sq_path
@ -183,8 +186,12 @@ def cli_setup(ctx: click.Context) -> None:
for key, value in config.items():
conn.put_config(key, value)
if not IS_DOCKER and click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback()
if IS_DOCKER:
click.echo("Relay all setup! Start the container to run the relay.")
return
if click.confirm('Relay all setup! Would you like to run it now?'):
cli_run.callback() # type: ignore
@cli.command('run')
@ -193,28 +200,13 @@ def cli_setup(ctx: click.Context) -> None:
def cli_run(ctx: click.Context, dev: bool = False) -> None:
'Run the relay'
if ctx.obj.config.domain.endswith('example.com') or not ctx.obj.signer:
click.echo(
'Relay is not set up. Please edit your relay config or run "activityrelay setup".'
)
if ctx.obj.config.domain.endswith('example.com') or ctx.obj.signer is None:
if not IS_DOCKER:
click.echo('Relay is not set up. Please run "activityrelay setup".')
return
vers_split = platform.python_version().split('.')
pip_command = 'pip3 uninstall pycrypto && pip3 install pycryptodome'
if Crypto.__version__ == '2.6.1':
if int(vers_split[1]) > 7:
click.echo(
'Error: PyCrypto is broken on Python 3.8+. Please replace it with pycryptodome ' +
'before running again. Exiting...'
)
click.echo(pip_command)
return
click.echo('Warning: PyCrypto is old and should be replaced with pycryptodome')
click.echo(pip_command)
cli_setup.callback() # type: ignore
return
ctx.obj['dev'] = dev
@ -257,7 +249,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
with click.progressbar(
with click.progressbar( # type: ignore
database['relay-list'].values(),
label = 'Inboxes'.ljust(15),
width = 0
@ -281,7 +273,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
software = inbox['software']
)
with click.progressbar(
with click.progressbar( # type: ignore
config['blocked_software'],
label = 'Banned software'.ljust(15),
width = 0
@ -293,7 +285,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
reason = 'relay' if software in RELAY_SOFTWARE else None
)
with click.progressbar(
with click.progressbar( # type: ignore
config['blocked_instances'],
label = 'Banned domains'.ljust(15),
width = 0
@ -302,7 +294,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
for domain in banned_software:
conn.put_domain_ban(domain)
with click.progressbar(
with click.progressbar( # type: ignore
config['whitelist'],
label = 'Whitelist'.ljust(15),
width = 0
@ -339,10 +331,17 @@ def cli_config_list(ctx: click.Context) -> None:
click.echo('Relay Config:')
with ctx.obj.database.session() as conn:
for key, value in conn.get_config_all().items():
if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20)
click.echo(f'- {key} {value}')
config = conn.get_config_all()
for key, value in config.to_dict().items():
if key in type(config).SYSTEM_KEYS():
continue
if key == 'log-level':
value = value.name
key_str = f'{key}:'.ljust(20)
click.echo(f'- {key_str} {repr(value)}')
@cli_config.command('set')
@ -477,7 +476,7 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:')
with ctx.obj.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes'):
for inbox in conn.get_inboxes():
click.echo(f'- {inbox["inbox"]}')
@ -520,7 +519,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)'
inbox_data: Row = None
inbox_data: Row | None = None
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor):
@ -540,6 +539,11 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
actor = f'https://{actor}/actor'
actor_data = asyncio.run(http.get(actor, sign_headers = True))
if not actor_data:
click.echo("Failed to fetch actor")
return
inbox = actor_data.shared_inbox
message = Message.new_unfollow(
host = ctx.obj.config.domain,
@ -618,6 +622,80 @@ def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
click.echo(f'Removed inbox from the database: {inbox}')
@cli.group('request')
def cli_request() -> None:
'Manage follow requests'
@cli_request.command('list')
@click.pass_context
def cli_request_list(ctx: click.Context) -> None:
'List all current follow requests'
click.echo('Follow requests:')
with ctx.obj.database.session() as conn:
for instance in conn.get_requests():
date = instance['created'].strftime('%Y-%m-%d')
click.echo(f'- [{date}] {instance["domain"]}')
@cli_request.command('accept')
@click.argument('domain')
@click.pass_context
def cli_request_accept(ctx: click.Context, domain: str) -> None:
'Accept a follow request'
try:
with ctx.obj.database.session() as conn:
instance = conn.put_request_response(domain, True)
except KeyError:
click.echo('Request not found')
return
message = Message.new_response(
host = ctx.obj.config.domain,
actor = instance['actor'],
followid = instance['followid'],
accept = True
)
asyncio.run(http.post(instance['inbox'], message, instance))
if instance['software'] != 'mastodon':
message = Message.new_follow(
host = ctx.obj.config.domain,
actor = instance['actor']
)
asyncio.run(http.post(instance['inbox'], message, instance))
@cli_request.command('deny')
@click.argument('domain')
@click.pass_context
def cli_request_deny(ctx: click.Context, domain: str) -> None:
'Accept a follow request'
try:
with ctx.obj.database.session() as conn:
instance = conn.put_request_response(domain, False)
except KeyError:
click.echo('Request not found')
return
response = Message.new_response(
host = ctx.obj.config.domain,
actor = instance['actor'],
followid = instance['followid'],
accept = False
)
asyncio.run(http.post(instance['inbox'], response, instance))
@cli.group('instance')
def cli_instance() -> None:
'Manage instance bans'
@ -893,7 +971,6 @@ def cli_whitelist_import(ctx: click.Context) -> None:
def main() -> None:
# pylint: disable=no-value-for-parameter
cli(prog_name='relay')

View file

@ -8,27 +8,44 @@ import typing
from aiohttp.web import Response as AiohttpResponse
from datetime import datetime
from pathlib import Path
from uuid import uuid4
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles
from importlib_resources import files as pkgfiles # type: ignore
if typing.TYPE_CHECKING:
from pathlib import Path
from typing import Any
from .application import Application
try:
from typing import Self
except ImportError:
from typing_extensions import Self
T = typing.TypeVar('T')
ResponseType = typing.TypedDict('ResponseType', {
'status': int,
'headers': dict[str, typing.Any] | None,
'content_type': str,
'body': bytes | None,
'text': str | None
})
IS_DOCKER = bool(os.environ.get('DOCKER_RUNNING'))
MIMETYPES = {
'activity': 'application/activity+json',
'css': 'text/css',
'html': 'text/html',
'json': 'application/json',
'text': 'text/plain'
'text': 'text/plain',
'webmanifest': 'application/manifest+json'
}
NODEINFO_NS = {
@ -92,7 +109,7 @@ def check_open_port(host: str, port: int) -> bool:
def get_app() -> Application:
from .application import Application # pylint: disable=import-outside-toplevel
from .application import Application
if not Application.DEFAULT:
raise ValueError('No default application set')
@ -101,7 +118,7 @@ def get_app() -> Application:
def get_resource(path: str) -> Path:
return pkgfiles('relay').joinpath(path)
return Path(str(pkgfiles('relay'))).joinpath(path)
class JsonEncoder(json.JSONEncoder):
@ -114,18 +131,18 @@ class JsonEncoder(json.JSONEncoder):
class Message(aputils.Message):
@classmethod
def new_actor(cls: type[Message], # pylint: disable=arguments-differ
def new_actor(cls: type[Self], # type: ignore
host: str,
pubkey: str,
description: str | None = None) -> Message:
description: str | None = None,
approves: bool = False) -> Self:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
return cls.new(aputils.ObjectType.APPLICATION, {
'id': f'https://{host}/actor',
'type': 'Application',
'preferredUsername': 'relay',
'name': 'ActivityRelay',
'summary': description or 'ActivityRelay bot',
'manuallyApprovesFollowers': approves,
'followers': f'https://{host}/followers',
'following': f'https://{host}/following',
'inbox': f'https://{host}/inbox',
@ -142,11 +159,9 @@ class Message(aputils.Message):
@classmethod
def new_announce(cls: type[Message], host: str, obj: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
def new_announce(cls: type[Self], host: str, obj: str | dict[str, Any]) -> Self:
return cls.new(aputils.ObjectType.ANNOUNCE, {
'id': f'https://{host}/activities/{uuid4()}',
'type': 'Announce',
'to': [f'https://{host}/followers'],
'actor': f'https://{host}/actor',
'object': obj
@ -154,23 +169,19 @@ class Message(aputils.Message):
@classmethod
def new_follow(cls: type[Message], host: str, actor: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
'type': 'Follow',
def new_follow(cls: type[Self], host: str, actor: str) -> Self:
return cls.new(aputils.ObjectType.FOLLOW, {
'id': f'https://{host}/activities/{uuid4()}',
'to': [actor],
'object': actor,
'id': f'https://{host}/activities/{uuid4()}',
'actor': f'https://{host}/actor'
})
@classmethod
def new_unfollow(cls: type[Message], host: str, actor: str, follow: str) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
def new_unfollow(cls: type[Self], host: str, actor: str, follow: dict[str, str]) -> Self:
return cls.new(aputils.ObjectType.UNDO, {
'id': f'https://{host}/activities/{uuid4()}',
'type': 'Undo',
'to': [actor],
'actor': f'https://{host}/actor',
'object': follow
@ -178,16 +189,9 @@ class Message(aputils.Message):
@classmethod
def new_response(cls: type[Message],
host: str,
actor: str,
followid: str,
accept: bool) -> Message:
return cls({
'@context': 'https://www.w3.org/ns/activitystreams',
def new_response(cls: type[Self], host: str, actor: str, followid: str, accept: bool) -> Self:
return cls.new(aputils.ObjectType.ACCEPT if accept else aputils.ObjectType.REJECT, {
'id': f'https://{host}/activities/{uuid4()}',
'type': 'Accept' if accept else 'Reject',
'to': [actor],
'actor': f'https://{host}/actor',
'object': {
@ -206,16 +210,18 @@ class Response(AiohttpResponse):
@classmethod
def new(cls: type[Response],
body: str | bytes | dict = '',
def new(cls: type[Self],
body: str | bytes | dict | tuple | list | set = '',
status: int = 200,
headers: dict[str, str] | None = None,
ctype: str = 'text') -> Response:
ctype: str = 'text') -> Self:
kwargs = {
kwargs: ResponseType = {
'status': status,
'headers': headers,
'content_type': MIMETYPES[ctype]
'content_type': MIMETYPES[ctype],
'body': None,
'text': None
}
if isinstance(body, bytes):
@ -231,10 +237,10 @@ class Response(AiohttpResponse):
@classmethod
def new_error(cls: type[Response],
def new_error(cls: type[Self],
status: int,
body: str | bytes | dict,
ctype: str = 'text') -> Response:
ctype: str = 'text') -> Self:
if ctype == 'json':
body = {'error': body}
@ -243,14 +249,14 @@ class Response(AiohttpResponse):
@classmethod
def new_redir(cls: type[Response], path: str) -> Response:
def new_redir(cls: type[Self], path: str) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>'
return cls.new(body, 302, {'Location': path})
@property
def location(self) -> str:
return self.headers.get('Location')
return self.headers.get('Location', '')
@location.setter

View file

@ -7,10 +7,10 @@ from .database import Connection
from .misc import Message
if typing.TYPE_CHECKING:
from .views import ActorView
from .views.activitypub import ActorView
def person_check(actor: str, software: str) -> bool:
def person_check(actor: Message, software: str | None) -> bool:
# pleroma and akkoma may use Person for the actor type for some reason
# akkoma changed this in 3.6.0
if software in {'akkoma', 'pleroma'} and actor.id == f'https://{actor.domain}/relay':
@ -35,8 +35,8 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
message = Message.new_announce(view.config.domain, view.message.object_id)
logging.debug('>> relay: %s', message)
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message, view.instance)
for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], message, instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
@ -53,8 +53,8 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
message = Message.new_announce(view.config.domain, view.message)
logging.debug('>> forward: %s', message)
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message, view.instance)
for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], await view.request.read(), instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str')
@ -62,9 +62,12 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
async def handle_follow(view: ActorView, conn: Connection) -> None:
nodeinfo = await view.client.fetch_nodeinfo(view.actor.domain)
software = nodeinfo.sw_name if nodeinfo else None
config = conn.get_config_all()
# reject if software used by actor is banned
if conn.get_software_ban(software):
if software and conn.get_software_ban(software):
logging.verbose('Rejected banned actor: %s', view.actor.id)
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
@ -72,7 +75,8 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
actor = view.actor.id,
followid = view.message.id,
accept = False
)
),
view.instance
)
logging.verbose(
@ -83,8 +87,10 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
return
## reject if the actor is not an instance actor
# reject if the actor is not an instance actor
if person_check(view.actor, software):
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
@ -92,23 +98,54 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
actor = view.actor.id,
followid = view.message.id,
accept = False
)
),
view.instance
)
return
if not conn.get_domain_whitelist(view.actor.domain):
# add request if approval-required is enabled
if config.approval_required:
logging.verbose('New follow request fromm actor: %s', view.actor.id)
with conn.transaction():
view.instance = conn.put_inbox(
domain = view.actor.domain,
inbox = view.actor.shared_inbox,
actor = view.actor.id,
followid = view.message.id,
software = software,
accepted = False
)
return
# reject if the actor isn't whitelisted while the whiltelist is enabled
if config.whitelist_enabled:
logging.verbose('Rejected actor for not being in the whitelist: %s', view.actor.id)
view.app.push_message(
view.actor.shared_inbox,
Message.new_response(
host = view.config.domain,
actor = view.actor.id,
followid = view.message.id,
accept = False
),
view.instance
)
logging.verbose('Non-application actor tried to follow: %s', view.actor.id)
return
with conn.transaction():
if conn.get_inbox(view.actor.shared_inbox):
view.instance = conn.update_inbox(view.actor.shared_inbox, followid = view.message.id)
else:
view.instance = conn.put_inbox(
view.actor.domain,
view.actor.shared_inbox,
view.actor.id,
view.message.id,
software
domain = view.actor.domain,
inbox = view.actor.shared_inbox,
actor = view.actor.id,
followid = view.message.id,
software = software,
accepted = True
)
view.app.push_message(
@ -136,7 +173,7 @@ async def handle_follow(view: ActorView, conn: Connection) -> None:
async def handle_undo(view: ActorView, conn: Connection) -> None:
## If the object is not a Follow, forward it
# If the object is not a Follow, forward it
if view.message.object['type'] != 'Follow':
await handle_forward(view, conn)
return
@ -150,7 +187,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
logging.verbose(
'Failed to delete "%s" with follow ID "%s"',
view.actor.id,
view.message.object['id']
view.message.object_id
)
view.app.push_message(
@ -189,15 +226,15 @@ async def run_processor(view: ActorView) -> None:
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
with conn.transaction():
view.instance = conn.update_inbox(
view.instance['inbox'],
view.instance = conn.put_inbox(
domain = view.instance['domain'],
software = nodeinfo.sw_name
)
if not view.instance['actor']:
with conn.transaction():
view.instance = conn.update_inbox(
view.instance['inbox'],
view.instance = conn.put_inbox(
domain = view.instance['domain'],
actor = view.actor.id
)

View file

@ -1,15 +1,22 @@
from __future__ import annotations
import textwrap
import typing
from hamlish_jinja.extension import HamlishExtension
from collections.abc import Callable
from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader
from jinja2.ext import Extension
from jinja2.nodes import CallBlock
from markdown import Markdown
from . import __version__
from .database.config import THEMES
from .misc import get_resource
if typing.TYPE_CHECKING:
from jinja2.nodes import Node
from jinja2.parser import Parser
from typing import Any
from .application import Application
from .views.base import View
@ -22,7 +29,8 @@ class Template(Environment):
trim_blocks = True,
lstrip_blocks = True,
extensions = [
HamlishExtension
HamlishExtension,
MarkdownExtension
],
loader = FileSystemLoader([
get_resource('frontend'),
@ -36,16 +44,52 @@ class Template(Environment):
def render(self, path: str, view: View | None = None, **context: Any) -> str:
with self.app.database.session(False) as s:
config = s.get_config_all()
with self.app.database.session(False) as conn:
config = conn.get_config_all()
new_context = {
'view': view,
'domain': self.app.config.domain,
'version': __version__,
'config': config,
'theme_name': config['theme'] or 'Default',
**(context or {})
}
return self.get_template(path).render(new_context)
def render_markdown(self, text: str) -> str:
return self._render_markdown(text) # type: ignore
class MarkdownExtension(Extension):
tags = {'markdown'}
extensions = (
'attr_list',
'smarty',
'tables'
)
def __init__(self, environment: Environment):
Extension.__init__(self, environment)
self._markdown = Markdown(extensions = MarkdownExtension.extensions)
environment.extend(
_render_markdown = self._render_markdown
)
def parse(self, parser: Parser) -> Node | list[Node]:
lineno = next(parser.stream).lineno
body = parser.parse_statements(
('name:endmarkdown',),
drop_needle = True
)
output = CallBlock(self.call_method('_render_markdown'), [], [], body)
return output.set_lineno(lineno)
def _render_markdown(self, caller: Callable[[], str] | str) -> str:
text = caller if isinstance(caller, str) else caller()
return self._markdown.convert(textwrap.dedent(text.strip('\n')))

View file

@ -1,4 +1,4 @@
from __future__ import annotations
from . import activitypub, api, frontend, misc
from .base import VIEWS
from .base import VIEWS, View

View file

@ -12,27 +12,31 @@ from ..processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from tinysql import Row
from bsql import Row
# pylint: disable=unused-argument
@register_route('/actor', '/inbox')
class ActorView(View):
signature: aputils.Signature
message: Message
actor: Message
instancce: Row
signer: aputils.Signer
def __init__(self, request: Request):
View.__init__(self, request)
self.signature: aputils.Signature = None
self.message: Message = None
self.actor: Message = None
self.instance: Row = None
self.signer: aputils.Signer = None
async def get(self, request: Request) -> Response:
with self.database.session(False) as conn:
config = conn.get_config_all()
data = Message.new_actor(
host = self.config.domain,
pubkey = self.app.signer.pubkey
pubkey = self.app.signer.pubkey,
description = self.app.template.render_markdown(config.note),
approves = config.approval_required
)
return Response.new(data, ctype='activity')
@ -44,19 +48,13 @@ class ActorView(View):
with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all()
## reject if the actor isn't whitelisted while the whiltelist is enabled
if config['whitelist-enabled'] and not conn.get_domain_whitelist(self.actor.domain):
logging.verbose('Rejected actor for not being in the whitelist: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if actor is banned
# reject if actor is banned
if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
## reject if activity type isn't 'Follow' and the actor isn't following
# reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance:
logging.verbose(
'Rejected actor for trying to post while not following: %s',
@ -73,35 +71,33 @@ class ActorView(View):
async def get_post_data(self) -> Response | None:
try:
self.signature = aputils.Signature.new_from_signature(self.request.headers['signature'])
self.signature = aputils.Signature.parse(self.request.headers['signature'])
except KeyError:
logging.verbose('Missing signature header')
return Response.new_error(400, 'missing signature header', 'json')
try:
self.message = await self.request.json(loads = Message.parse)
message: Message | None = await self.request.json(loads = Message.parse)
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
return Response.new_error(400, 'failed to parse message', 'json')
if self.message is None:
if message is None:
logging.verbose('empty message')
return Response.new_error(400, 'missing message', 'json')
self.message = message
if 'actor' not in self.message:
logging.verbose('actor not in message')
return Response.new_error(400, 'no actor in message', 'json')
self.actor = await self.client.get(
self.signature.keyid,
sign_headers = True,
loads = Message.parse
)
actor: Message | None = await self.client.get(self.signature.keyid, True, Message)
if not self.actor:
if actor is None:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
@ -110,6 +106,8 @@ class ActorView(View):
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
return Response.new_error(400, 'failed to fetch actor', 'json')
self.actor = actor
try:
self.signer = self.actor.signer
@ -118,42 +116,13 @@ class ActorView(View):
return Response.new_error(400, 'actor missing public key', 'json')
try:
self.validate_signature(await self.request.read())
await self.signer.validate_request_async(self.request)
except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
return Response.new_error(401, str(e), 'json')
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := aputils.Digest.new_from_digest(headers.get("digest"))):
if not body:
raise aputils.SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise aputils.SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == "hs2019":
if "(created)" not in self.signature.headers:
raise aputils.SignatureFailureError("'(created)' header not used")
current_timestamp = aputils.HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise aputils.SignatureFailureError("Creation date after current date")
if current_timestamp > self.signature.expires:
raise aputils.SignatureFailureError("Expiration date before current date")
headers["(created)"] = self.signature.created
headers["(expires)"] = self.signature.expires
# pylint: disable=protected-access
if not self.signer._validate_signature(headers, self.signature):
raise aputils.SignatureFailureError("Signature does not match")
return None
@register_route('/.well-known/webfinger')

View file

@ -9,23 +9,22 @@ from urllib.parse import urlparse
from .base import View, register_route
from .. import __version__
from .. import logger as logging
from ..database.config import CONFIG_DEFAULTS
from ..misc import Message, Response
from ..database import ConfigData
from ..misc import Message, Response, boolean, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Coroutine
from collections.abc import Callable, Sequence
from typing import Any
CONFIG_IGNORE = (
'schema-version',
'private-key'
)
ALLOWED_HEADERS = {
'accept',
'authorization',
'content-type'
}
CONFIG_VALID = {key for key in CONFIG_DEFAULTS if key not in CONFIG_IGNORE}
PUBLIC_API_PATHS: tuple[tuple[str, str]] = (
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
('GET', '/api/v1/relay'),
('GET', '/api/v1/instance'),
('POST', '/api/v1/token')
@ -40,28 +39,36 @@ def check_api_path(method: str, path: str) -> bool:
@web.middleware
async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Response:
async def handle_api_path(request: Request, handler: Callable) -> Response:
try:
if (token := request.cookies.get('user-token')):
request['token'] = token
else:
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
with request.app.database.session() as conn:
with get_app().database.session() as conn:
request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError):
request['token'] = None
request['user'] = None
if check_api_path(request.method, request.path):
if request.method != "OPTIONS" and check_api_path(request.method, request.path):
if not request['token']:
return Response.new_error(401, 'Missing token', 'json')
if not request['user']:
return Response.new_error(401, 'Invalid token', 'json')
return await handler(request)
response = await handler(request)
if request.path.startswith('/api'):
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Headers'] = ', '.join(ALLOWED_HEADERS)
return response
# pylint: disable=no-self-use,unused-argument
@register_route('/api/v1/token')
class Login(View):
@ -87,7 +94,19 @@ class Login(View):
token = conn.put_token(data['username'])
return Response.new({'token': token['code']}, ctype = 'json')
resp = Response.new({'token': token['code']}, ctype = 'json')
resp.set_cookie(
'user-token',
token['code'],
max_age = 60 * 60 * 24 * 365,
domain = self.config.domain,
path = '/',
secure = True,
httponly = False,
samesite = 'lax'
)
return resp
async def delete(self, request: Request) -> Response:
@ -102,14 +121,14 @@ class RelayInfo(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
config = conn.get_config_all()
inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')]
inboxes = [row['domain'] for row in conn.get_inboxes()]
data = {
'domain': self.config.domain,
'name': config['name'],
'description': config['note'],
'name': config.name,
'description': config.note,
'version': __version__,
'whitelist_enabled': config['whitelist-enabled'],
'whitelist_enabled': config.whitelist_enabled,
'email': None,
'admin': None,
'icon': None,
@ -122,12 +141,17 @@ class RelayInfo(View):
@register_route('/api/v1/config')
class Config(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
data = conn.get_config_all()
data['log-level'] = data['log-level'].name
data = {}
for key in CONFIG_IGNORE:
del data[key]
with self.database.session() as conn:
for key, value in conn.get_config_all().to_dict().items():
if key in ConfigData.SYSTEM_KEYS():
continue
if key == 'log-level':
value = value.name
data[key] = value
return Response.new(data, ctype = 'json')
@ -138,7 +162,9 @@ class Config(View):
if isinstance(data, Response):
return data
if data['key'] not in CONFIG_VALID:
data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn:
@ -153,11 +179,11 @@ class Config(View):
if isinstance(data, Response):
return data
if data['key'] not in CONFIG_VALID:
if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json')
with self.database.session() as conn:
conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1])
conn.put_config(data['key'], ConfigData.DEFAULT(data['key']))
return Response.new({'message': 'Updated config'}, ctype = 'json')
@ -166,7 +192,7 @@ class Config(View):
class Inbox(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
data = tuple(conn.execute('SELECT * FROM inboxes').all())
data = conn.get_inboxes()
return Response.new(data, ctype = 'json')
@ -184,18 +210,18 @@ class Inbox(View):
return Response.new_error(404, 'Instance already in database', 'json')
if not data.get('inbox'):
try:
actor_data = await self.client.get(
data['actor'],
sign_headers = True,
loads = Message.parse
)
actor_data: Message | None = await self.client.get(data['actor'], True, Message)
if actor_data is None:
return Response.new_error(500, 'Failed to fetch actor', 'json')
data['inbox'] = actor_data.shared_inbox
except Exception as e:
logging.error('Failed to fetch actor: %s', str(e))
return Response.new_error(500, 'Failed to fetch actor', 'json')
if not data.get('software'):
nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
if nodeinfo is not None:
data['software'] = nodeinfo.sw_name
row = conn.put_inbox(**data)
@ -212,12 +238,12 @@ class Inbox(View):
if not (instance := conn.get_inbox(data['domain'])):
return Response.new_error(404, 'Instance with domain not found', 'json')
instance = conn.update_inbox(instance['inbox'], **data)
instance = conn.put_inbox(instance['domain'], **data)
return Response.new(instance, ctype = 'json')
async def delete(self, request: Request, domain: str) -> Response:
async def delete(self, request: Request) -> Response:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
@ -232,6 +258,47 @@ class Inbox(View):
return Response.new({'message': 'Deleted instance'}, ctype = 'json')
@register_route('/api/v1/request')
class RequestView(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
instances = conn.get_requests()
return Response.new(instances, ctype = 'json')
async def post(self, request: Request) -> Response:
data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
data['accept'] = boolean(data['accept'])
try:
with self.database.session(True) as conn:
instance = conn.put_request_response(data['domain'], data['accept'])
except KeyError:
return Response.new_error(404, 'Request not found', 'json')
message = Message.new_response(
host = self.config.domain,
actor = instance['actor'],
followid = instance['followid'],
accept = data['accept']
)
self.app.push_message(instance['inbox'], message, instance)
if data['accept'] and instance['software'] != 'mastodon':
message = Message.new_follow(
host = self.config.domain,
actor = instance['actor']
)
self.app.push_message(instance['inbox'], message, instance)
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
return Response.new(resp_message, ctype = 'json')
@register_route('/api/v1/domain_ban')
class DomainBan(View):
async def get(self, request: Request) -> Response:
@ -269,7 +336,7 @@ class DomainBan(View):
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
ban = conn.update_domain_ban(data['domain'], **data)
ban = conn.update_domain_ban(**data)
return Response.new(ban, ctype = 'json')
@ -326,7 +393,7 @@ class SoftwareBan(View):
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
ban = conn.update_software_ban(data['name'], **data)
ban = conn.update_software_ban(**data)
return Response.new(ban, ctype = 'json')
@ -346,6 +413,63 @@ class SoftwareBan(View):
return Response.new({'message': 'Unbanned software'}, ctype = 'json')
@register_route('/api/v1/user')
class User(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
items = []
for row in conn.execute('SELECT * FROM users'):
del row['hash']
items.append(row)
return Response.new(items, ctype = 'json')
async def post(self, request: Request) -> Response:
data = await self.get_api_data(['username', 'password'], ['handle'])
if isinstance(data, Response):
return data
with self.database.session() as conn:
if conn.get_user(data['username']):
return Response.new_error(404, 'User already exists', 'json')
user = conn.put_user(**data)
del user['hash']
return Response.new(user, ctype = 'json')
async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', 'handle'])
if isinstance(data, Response):
return data
with self.database.session(True) as conn:
user = conn.put_user(**data)
del user['hash']
return Response.new(user, ctype = 'json')
async def delete(self, request: Request) -> Response:
data = await self.get_api_data(['username'], [])
if isinstance(data, Response):
return data
with self.database.session(True) as conn:
if not conn.get_user(data['username']):
return Response.new_error(404, 'User does not exist', 'json')
conn.del_user(data['username'])
return Response.new({'message': 'Deleted user'}, ctype = 'json')
@register_route('/api/v1/whitelist')
class Whitelist(View):
async def get(self, request: Request) -> Response:

View file

@ -2,40 +2,52 @@ from __future__ import annotations
import typing
from Crypto.Random import get_random_bytes
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import HTTPMethodNotAllowed
from base64 import b64encode
from functools import cached_property
from json.decoder import JSONDecodeError
from ..misc import Response
from ..misc import Response, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Callable, Coroutine, Generator
from collections.abc import Callable, Generator, Sequence, Mapping
from bsql import Database
from typing import Any, Self
from typing import Any
from ..application import Application
from ..cache import Cache
from ..config import Config
from ..http_client import HttpClient
from ..template import Template
try:
from typing import Self
VIEWS = []
except ImportError:
from typing_extensions import Self
VIEWS: list[tuple[str, type[View]]] = []
def convert_data(data: Mapping[str, Any]) -> dict[str, str]:
return {key: str(value) for key, value in data.items()}
def register_route(*paths: str) -> Callable:
def wrapper(view: View) -> View:
def wrapper(view: type[View]) -> type[View]:
for path in paths:
VIEWS.append([path, view])
VIEWS.append((path, view))
return view
return wrapper
class View(AbstractView):
def __await__(self) -> Generator[Response]:
def __await__(self) -> Generator[Any, None, Response]:
if self.request.method not in METHODS:
raise HTTPMethodNotAllowed(self.request.method, self.allowed_methods)
@ -46,22 +58,27 @@ class View(AbstractView):
@classmethod
async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Self:
async def run(cls: type[Self], method: str, request: Request, **kwargs: Any) -> Response:
view = cls(request)
return await view.handlers[method](request, **kwargs)
async def _run_handler(self, handler: Coroutine, **kwargs: Any) -> Response:
async def _run_handler(self, handler: Callable[..., Any], **kwargs: Any) -> Response:
self.request['hash'] = b64encode(get_random_bytes(16)).decode('ascii')
return await handler(self.request, **self.request.match_info, **kwargs)
async def options(self, request: Request) -> Response:
return Response.new()
@cached_property
def allowed_methods(self) -> tuple[str]:
def allowed_methods(self) -> Sequence[str]:
return tuple(self.handlers.keys())
@cached_property
def handlers(self) -> dict[str, Coroutine]:
def handlers(self) -> dict[str, Callable[..., Any]]:
data = {}
for method in METHODS:
@ -74,10 +91,9 @@ class View(AbstractView):
return data
# app components
@property
def app(self) -> Application:
return self.request.app
return get_app()
@property
@ -110,17 +126,17 @@ class View(AbstractView):
optional: list[str]) -> dict[str, str] | Response:
if self.request.content_type in {'x-www-form-urlencoded', 'multipart/form-data'}:
post_data = await self.request.post()
post_data = convert_data(await self.request.post())
elif self.request.content_type == 'application/json':
try:
post_data = await self.request.json()
post_data = convert_data(await self.request.json())
except JSONDecodeError:
return Response.new_error(400, 'Invalid JSON data', 'json')
else:
post_data = self.request.query
post_data = convert_data(self.request.query)
data = {}
@ -132,6 +148,6 @@ class View(AbstractView):
return Response.new_error(400, f'Missing {str(e)} pararmeter', 'json')
for key in optional:
data[key] = post_data.get(key)
data[key] = post_data.get(key, '')
return data

View file

@ -3,60 +3,59 @@ from __future__ import annotations
import typing
from aiohttp import web
from argon2.exceptions import VerifyMismatchError
from urllib.parse import urlparse
from .base import View, register_route
from ..database import CONFIG_DEFAULTS, THEMES
from ..database import THEMES
from ..logger import LogLevel
from ..misc import ACTOR_FORMATS, Message, Response
from ..misc import Response, get_app
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from collections.abc import Coroutine
from collections.abc import Callable
from typing import Any
# pylint: disable=no-self-use
UNAUTH_ROUTES = {
'/',
'/login'
}
CONFIG_IGNORE = (
'schema-version',
'private-key'
)
@web.middleware
async def handle_frontend_path(request: web.Request, handler: Coroutine) -> Response:
async def handle_frontend_path(request: web.Request, handler: Callable) -> Response:
app = get_app()
if request.path in UNAUTH_ROUTES or request.path.startswith('/admin'):
request['token'] = request.cookies.get('user-token')
request['user'] = None
if request['token']:
with request.app.database.session(False) as conn:
with app.database.session(False) as conn:
request['user'] = conn.get_user_by_token(request['token'])
if request['user'] and request.path == '/login':
return Response.new('', 302, {'Location': '/'})
if not request['user'] and request.path.startswith('/admin'):
return Response.new('', 302, {'Location': f'/login?redir={request.path}'})
response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
response.del_cookie('user-token')
return response
return await handler(request)
response = await handler(request)
if not request.path.startswith('/api') and not request['user'] and request['token']:
response.del_cookie('user-token')
return response
# pylint: disable=unused-argument
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
context = {
'instances': tuple(conn.execute('SELECT * FROM inboxes').all())
context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes())
}
data = self.template.render('page/home.haml', self, **context)
@ -70,47 +69,6 @@ class Login(View):
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
form = await request.post()
params = {}
with self.database.session(True) as conn:
if not (user := conn.get_user(form['username'])):
params = {
'username': form['username'],
'error': 'User not found'
}
else:
try:
conn.hasher.verify(user['hash'], form['password'])
except VerifyMismatchError:
params = {
'username': form['username'],
'error': 'Invalid password'
}
if params:
data = self.template.render('page/login.haml', self, **params)
return Response.new(data, ctype = 'html')
token = conn.put_token(user['username'])
resp = Response.new_redir(request.query.getone('redir', '/'))
resp.set_cookie(
'user-token',
token['code'],
max_age = 60 * 60 * 24 * 365,
domain = self.config.domain,
path = '/',
secure = True,
httponly = True,
samesite = 'Strict'
)
return resp
@register_route('/logout')
class Logout(View):
async def get(self, request: Request) -> Response:
@ -136,8 +94,9 @@ class AdminInstances(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'instances': tuple(conn.execute('SELECT * FROM inboxes').all())
context: dict[str, Any] = {
'instances': tuple(conn.get_inboxes()),
'requests': tuple(conn.get_requests())
}
if error:
@ -150,44 +109,6 @@ class AdminInstances(View):
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
if not data.get('actor') and not data.get('domain'):
return await self.get(request, error = 'Missing actor and/or domain')
if not data.get('domain'):
data['domain'] = urlparse(data['actor']).netloc
if not data.get('software'):
nodeinfo = await self.client.fetch_nodeinfo(data['domain'])
data['software'] = nodeinfo.sw_name
if not data.get('actor') and data['software'] in ACTOR_FORMATS:
data['actor'] = ACTOR_FORMATS[data['software']].format(domain = data['domain'])
if not data.get('inbox') and data['actor']:
actor = await self.client.get(data['actor'], sign_headers = True, loads = Message.parse)
data['inbox'] = actor.shared_inbox
with self.database.session(True) as conn:
conn.put_inbox(**data)
return await self.get(request, message = "Added new inbox")
@register_route('/admin/instances/delete/{domain}')
class AdminInstancesDelete(View):
async def get(self, request: Request, domain: str) -> Response:
with self.database.session() as conn:
if not conn.get_inbox(domain):
return await AdminInstances(request).get(request, message = 'Instance not found')
conn.del_inbox(domain)
return await AdminInstances(request).get(request, message = 'Removed instance')
@register_route('/admin/whitelist')
class AdminWhitelist(View):
async def get(self,
@ -196,8 +117,8 @@ class AdminWhitelist(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'whitelist': tuple(conn.execute('SELECT * FROM whitelist').all())
context: dict[str, Any] = {
'whitelist': tuple(conn.execute('SELECT * FROM whitelist ORDER BY domain ASC'))
}
if error:
@ -210,34 +131,6 @@ class AdminWhitelist(View):
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
if not data['domain']:
return await self.get(request, error = 'Missing domain')
with self.database.session(True) as conn:
if conn.get_domain_whitelist(data['domain']):
return await self.get(request, message = "Domain already in whitelist")
conn.put_domain_whitelist(data['domain'])
return await self.get(request, message = "Added/updated domain ban")
@register_route('/admin/whitelist/delete/{domain}')
class AdminWhitlistDelete(View):
async def get(self, request: Request, domain: str) -> Response:
with self.database.session() as conn:
if not conn.get_domain_whitelist(domain):
msg = 'Whitelisted domain not found'
return await AdminWhitelist.run("GET", request, message = msg)
conn.del_domain_whitelist(domain)
return await AdminWhitelist.run("GET", request, message = 'Removed domain from whitelist')
@register_route('/admin/domain_bans')
class AdminDomainBans(View):
async def get(self,
@ -246,8 +139,8 @@ class AdminDomainBans(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC').all())
context: dict[str, Any] = {
'bans': tuple(conn.execute('SELECT * FROM domain_bans ORDER BY domain ASC'))
}
if error:
@ -260,42 +153,6 @@ class AdminDomainBans(View):
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
if not data['domain']:
return await self.get(request, error = 'Missing domain')
with self.database.session(True) as conn:
if conn.get_domain_ban(data['domain']):
conn.update_domain_ban(
data['domain'],
data.get('reason'),
data.get('note')
)
else:
conn.put_domain_ban(
data['domain'],
data.get('reason'),
data.get('note')
)
return await self.get(request, message = "Added/updated domain ban")
@register_route('/admin/domain_bans/delete/{domain}')
class AdminDomainBansDelete(View):
async def get(self, request: Request, domain: str) -> Response:
with self.database.session() as conn:
if not conn.get_domain_ban(domain):
return await AdminDomainBans.run("GET", request, message = 'Domain ban not found')
conn.del_domain_ban(domain)
return await AdminDomainBans.run("GET", request, message = 'Unbanned domain')
@register_route('/admin/software_bans')
class AdminSoftwareBans(View):
async def get(self,
@ -304,8 +161,8 @@ class AdminSoftwareBans(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC').all())
context: dict[str, Any] = {
'bans': tuple(conn.execute('SELECT * FROM software_bans ORDER BY name ASC'))
}
if error:
@ -318,42 +175,6 @@ class AdminSoftwareBans(View):
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
if not data['name']:
return await self.get(request, error = 'Missing name')
with self.database.session(True) as conn:
if conn.get_software_ban(data['name']):
conn.update_software_ban(
data['name'],
data.get('reason'),
data.get('note')
)
else:
conn.put_software_ban(
data['name'],
data.get('reason'),
data.get('note')
)
return await self.get(request, message = "Added/updated software ban")
@register_route('/admin/software_bans/delete/{name}')
class AdminSoftwareBansDelete(View):
async def get(self, request: Request, name: str) -> Response:
with self.database.session() as conn:
if not conn.get_software_ban(name):
return await AdminSoftwareBans.run("GET", request, message = 'Software ban not found')
conn.del_software_ban(name)
return await AdminSoftwareBans.run("GET", request, message = 'Unbanned software')
@register_route('/admin/users')
class AdminUsers(View):
async def get(self,
@ -362,8 +183,8 @@ class AdminUsers(View):
message: str | None = None) -> Response:
with self.database.session() as conn:
context = {
'users': tuple(conn.execute('SELECT * FROM users').all())
context: dict[str, Any] = {
'users': tuple(conn.execute('SELECT * FROM users ORDER BY username ASC'))
}
if error:
@ -376,82 +197,47 @@ class AdminUsers(View):
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
data = await request.post()
required_fields = {'username', 'password', 'password2'}
if not all(data.get(field) for field in required_fields):
return await self.get(request, error = 'Missing username and/or password')
if data['password'] != data['password2']:
return await self.get(request, error = 'Passwords do not match')
with self.database.session(True) as conn:
if conn.get_user(data['username']):
return await self.get(request, message = "User already exists")
conn.put_user(data['username'], data['password'], data['handle'])
return await self.get(request, message = "Added user")
@register_route('/admin/users/delete/{name}')
class AdminUsersDelete(View):
async def get(self, request: Request, name: str) -> Response:
with self.database.session() as conn:
if not conn.get_user(name):
return await AdminUsers.run("GET", request, message = 'User not found')
conn.del_user(name)
return await AdminUsers.run("GET", request, message = 'User deleted')
@register_route('/admin/config')
class AdminConfig(View):
async def get(self, request: Request, message: str | None = None) -> Response:
context = {
context: dict[str, Any] = {
'themes': tuple(THEMES.keys()),
'LogLevel': LogLevel,
'levels': tuple(level.name for level in LogLevel),
'message': message
}
data = self.template.render('page/admin-config.haml', self, **context)
return Response.new(data, ctype = 'html')
async def post(self, request: Request) -> Response:
form = dict(await request.post())
with self.database.session(True) as conn:
for key in CONFIG_DEFAULTS:
value = form.get(key)
if key == 'whitelist-enabled':
value = bool(value)
elif key.lower() in CONFIG_IGNORE:
continue
if value is None:
continue
conn.put_config(key, value)
return await self.get(request, message = 'Updated config')
@register_route('/style.css')
class StyleCss(View):
@register_route('/manifest.json')
class ManifestJson(View):
async def get(self, request: Request) -> Response:
data = self.template.render('style.css', self)
return Response.new(data, ctype = 'css')
with self.database.session(False) as conn:
config = conn.get_config_all()
theme = THEMES[config.theme]
data = {
'background_color': theme['background'],
'categories': ['activitypub'],
'description': 'Message relay for the ActivityPub network',
'display': 'standalone',
'name': config['name'],
'orientation': 'portrait',
'scope': f"https://{self.config.domain}/",
'short_name': 'ActivityRelay',
'start_url': f"https://{self.config.domain}/",
'theme_color': theme['primary']
}
return Response.new(data, ctype = 'webmanifest')
@register_route('/theme/{theme}.css')
class ThemeCss(View):
async def get(self, request: Request, theme: str) -> Response:
try:
context = {
context: dict[str, Any] = {
'theme': THEMES[theme]
}

View file

@ -27,28 +27,26 @@ if Path(__file__).parent.parent.joinpath('.git').exists():
pass
# pylint: disable=unused-argument
@register_route('/nodeinfo/{niversion:\\d.\\d}.json', '/nodeinfo/{niversion:\\d.\\d}')
class NodeinfoView(View):
# pylint: disable=no-self-use
async def get(self, request: Request, niversion: str) -> Response:
with self.database.session() as conn:
inboxes = conn.execute('SELECT * FROM inboxes').all()
inboxes = conn.get_inboxes()
data = {
'name': 'activityrelay',
'version': VERSION,
'protocols': ['activitypub'],
'open_regs': not conn.get_config('whitelist-enabled'),
'users': 1,
'metadata': {'peers': [inbox['domain'] for inbox in inboxes]}
nodeinfo = aputils.Nodeinfo.new(
name = 'activityrelay',
version = VERSION,
protocols = ['activitypub'],
open_regs = not conn.get_config('whitelist-enabled'),
users = 1,
repo = 'https://git.pleroma.social/pleroma/relay' if niversion == '2.1' else None,
metadata = {
'approval_required': conn.get_config('approval-required'),
'peers': [inbox['domain'] for inbox in inboxes]
}
)
if niversion == '2.1':
data['repo'] = 'https://git.pleroma.social/pleroma/relay'
return Response.new(aputils.Nodeinfo.new(**data), ctype = 'json')
return Response.new(nodeinfo, ctype = 'json')
@register_route('/.well-known/nodeinfo')

View file

@ -1,13 +0,0 @@
aiohttp>=3.9.1
aiohttp-swagger[performance]==1.0.16
aputils@https://git.barkshark.xyz/barkshark/aputils/archive/0.1.7.tar.gz
argon2-cffi==23.1.0
barkshark-sql@https://git.barkshark.xyz/barkshark/bsql/archive/0.1.2.tar.gz
click>=8.1.2
hamlish-jinja@https://git.barkshark.xyz/barkshark/hamlish-jinja/archive/0.3.5.tar.gz
hiredis==2.3.2
platformdirs==4.2.0
pyyaml>=6.0
redis==5.0.1
importlib_resources==6.1.1;python_version<'3.9'

View file

@ -1,49 +0,0 @@
[metadata]
name = relay
version = attr: relay.__version__
description = Generic LitePub relay (works with all LitePub consumers and Mastodon)
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8
url = https://git.pleroma.social/pleroma/relay
license = AGPLv3
license_file = LICENSE
classifiers =
Environment :: Console
License :: OSI Approved :: AGPLv3 License
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
project_urls =
Source = https://git.pleroma.social/pleroma/relay
Tracker = https://git.pleroma.social/pleroma/relay/-/issues
[options]
zip_safe = False
packages =
relay
relay.database
relay.views
include_package_data = true
install_requires = file: requirements.txt
python_requires = >=3.8
[options.extras_require]
dev = file: dev-requirements.txt
[options.package_data]
relay =
data/*
frontend/*
frontend/page/*
[options.entry_points]
console_scripts =
activityrelay = relay.manage:main
[flake8]
select = F401
per-file-ignores =
__init__.py: F401

View file

@ -1,4 +0,0 @@
import setuptools
if __name__ == "__main__":
setuptools.setup()

6
tox.ini Normal file
View file

@ -0,0 +1,6 @@
[flake8]
extend-ignore = E128,E251,E261,E303,W191
max-line-length = 100
indent-size = 4
per-file-ignores =
__init__.py: F401