Compare commits

..

No commits in common. "0709d8deb9145f65540c530d884c17b6114ade9b" and "6112734b2f62c86bd022ca5a49a740b6d4b0c62e" have entirely different histories.

8 changed files with 55 additions and 46 deletions

View file

@ -4,7 +4,6 @@ Description=ActivityPub Relay
[Service] [Service]
WorkingDirectory=/home/relay/relay WorkingDirectory=/home/relay/relay
ExecStart=/usr/bin/python3 -m relay run ExecStart=/usr/bin/python3 -m relay run
Environment="IS_SYSTEMD=1"
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target

View file

@ -65,7 +65,7 @@ class Application(web.Application):
Application.DEFAULT = self Application.DEFAULT = self
self['running'] = False self['running'] = None
self['signer'] = None self['signer'] = None
self['start_time'] = None self['start_time'] = None
self['cleanup_thread'] = None self['cleanup_thread'] = None
@ -142,7 +142,7 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message | bytes, instance: Row) -> None: def push_message(self, inbox: str, message: Message, instance: Row) -> None:
self['push_queue'].put((inbox, message, instance)) self['push_queue'].put((inbox, message, instance))

View file

@ -42,15 +42,15 @@ class Connection(SqlConnection):
return get_app() return get_app()
def distill_inboxes(self, message: Message) -> Iterator[Row]: def distill_inboxes(self, message: Message) -> Iterator[str]:
src_domains = { src_domains = {
message.domain, message.domain,
urlparse(message.object_id).netloc urlparse(message.object_id).netloc
} }
for instance in self.get_inboxes(): for inbox in self.get_inboxes():
if instance['domain'] not in src_domains: if inbox['domain'] not in src_domains:
yield instance yield inbox['inbox']
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:

View file

@ -7,7 +7,7 @@ import typing
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from aputils import AlgorithmType, JsonBase, Nodeinfo, ObjectType, WellKnownNodeinfo from aputils import JsonBase, Nodeinfo, WellKnownNodeinfo
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from urllib.parse import urlparse from urllib.parse import urlparse
@ -111,7 +111,7 @@ class HttpClient:
headers = {} headers = {}
if sign_headers: if sign_headers:
headers = self.signer.sign_headers('GET', url, algorithm = AlgorithmType.HS2019) headers = self.signer.sign_headers('GET', url, algorithm = 'original')
try: try:
logging.debug('Fetching resource: %s', url) logging.debug('Fetching resource: %s', url)
@ -164,50 +164,31 @@ class HttpClient:
return cls.parse(data) return cls.parse(data)
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None: async def post(self, url: str, message: Message, instance: Row | None = None) -> None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
# akkoma and pleroma do not support HS2019 and other software still needs to be tested # Using the old algo by default is probably a better idea right now
if instance and instance['software'] in {'mastodon'}: if instance and instance['software'] in {'mastodon'}:
algorithm = AlgorithmType.HS2019 algorithm = 'hs2019'
else: else:
algorithm = AlgorithmType.RSASHA256 algorithm = 'original'
body: bytes headers = {'Content-Type': 'application/activity+json'}
message: Message headers.update(get_app().signer.sign_headers('POST', url, message, algorithm=algorithm))
if isinstance(data, bytes):
body = data
message = Message.parse(data)
else:
body = data.to_json().encode("utf-8")
message = data
mtype = message.type.value if isinstance(message.type, ObjectType) else message.type
headers = self.signer.sign_headers(
'POST',
url,
body,
headers = {'Content-Type': 'application/activity+json'},
algorithm = algorithm
)
try: try:
logging.verbose('Sending "%s" to %s', mtype, url) logging.verbose('Sending "%s" to %s', message.type.value, url)
async with self._session.post(url, headers = headers, data = body) as resp: async with self._session.post(url, headers = headers, data = message.to_json()) as resp:
# Not expecting a response, so just return # Not expecting a response, so just return
if resp.status in {200, 202}: if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', mtype, url) logging.verbose('Successfully sent "%s" to %s', message.type.value, url)
return return
logging.verbose('Received error when pushing to %s: %i', url, resp.status) logging.verbose('Received error when pushing to %s: %i', url, resp.status)
logging.debug(await resp.read()) logging.debug(await resp.read())
logging.debug("message: %s", body.decode("utf-8"))
logging.debug("headers: %s", json.dumps(headers, indent = 4))
return return
except ClientSSLError: except ClientSSLError:

View file

@ -87,7 +87,7 @@ handlers: list[Any] = [logging.StreamHandler()]
if env_log_file: if env_log_file:
handlers.append(logging.FileHandler(env_log_file)) handlers.append(logging.FileHandler(env_log_file))
if os.environ.get('IS_SYSTEMD'): if os.environ.get('INVOCATION_ID'):
logging_format = '%(levelname)s: %(message)s' logging_format = '%(levelname)s: %(message)s'
else: else:

View file

@ -35,8 +35,8 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
message = Message.new_announce(view.config.domain, view.message.object_id) message = Message.new_announce(view.config.domain, view.message.object_id)
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
for instance in conn.distill_inboxes(view.message): for inbox in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], message, instance) view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str') view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
@ -53,8 +53,8 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
message = Message.new_announce(view.config.domain, view.message) message = Message.new_announce(view.config.domain, view.message)
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
for instance in conn.distill_inboxes(view.message): for inbox in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], await view.request.read(), instance) view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str') view.cache.set('handle-relay', view.message.id, message.id, 'str')

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import aputils import aputils
import traceback import traceback
import typing import typing
import json
from .base import View, register_route from .base import View, register_route
@ -72,7 +71,7 @@ class ActorView(View):
async def get_post_data(self) -> Response | None: async def get_post_data(self) -> Response | None:
try: try:
self.signature = aputils.Signature.parse(self.request.headers['signature']) self.signature = aputils.Signature.new_from_signature(self.request.headers['signature'])
except KeyError: except KeyError:
logging.verbose('Missing signature header') logging.verbose('Missing signature header')
@ -117,7 +116,7 @@ class ActorView(View):
return Response.new_error(400, 'actor missing public key', 'json') return Response.new_error(400, 'actor missing public key', 'json')
try: try:
await self.signer.validate_aiohttp_request(self.request) self.validate_signature(await self.request.read())
except aputils.SignatureFailureError as e: except aputils.SignatureFailureError as e:
logging.verbose('signature validation failed for "%s": %s', self.actor.id, e) logging.verbose('signature validation failed for "%s": %s', self.actor.id, e)
@ -126,6 +125,36 @@ class ActorView(View):
return None return None
def validate_signature(self, body: bytes) -> None:
headers = {key.lower(): value for key, value in self.request.headers.items()}
headers["(request-target)"] = " ".join([self.request.method.lower(), self.request.path])
if (digest := aputils.Digest.new_from_digest(headers.get("digest"))):
if not body:
raise aputils.SignatureFailureError("Missing body for digest verification")
if not digest.validate(body):
raise aputils.SignatureFailureError("Body digest does not match")
if self.signature.algorithm_type == aputils.AlgorithmType.HS2019:
if self.signature.created is None or self.signature.expires is None:
raise aputils.SignatureFailureError("Missing 'created' or 'expireds' parameter")
current_timestamp = aputils.HttpDate.new_utc().timestamp()
if self.signature.created > current_timestamp:
raise aputils.SignatureFailureError("Creation date after current date")
if self.signature.expires < current_timestamp:
raise aputils.SignatureFailureError("Signature has expired")
headers["(created)"] = str(self.signature.created)
headers["(expires)"] = str(self.signature.expires)
if not self.signer._validate_signature(headers, self.signature):
raise aputils.SignatureFailureError("Signature does not match")
@register_route('/.well-known/webfinger') @register_route('/.well-known/webfinger')
class WebfingerView(View): class WebfingerView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:

View file

@ -1,4 +1,4 @@
activitypub-utils == 0.2.0 activitypub-utils == 0.1.9
aiohttp >= 3.9.1 aiohttp >= 3.9.1
aiohttp-swagger[performance] == 1.0.16 aiohttp-swagger[performance] == 1.0.16
argon2-cffi == 23.1.0 argon2-cffi == 23.1.0