diff --git a/relay/database.py b/relay/database.py index bc08414..c5993d4 100644 --- a/relay/database.py +++ b/relay/database.py @@ -138,7 +138,13 @@ class RelayDatabase(dict): def del_inbox(self, domain, followid=None, fail=False): - data = self.get_inbox(domain, fail=True) + data = self.get_inbox(domain, fail=False) + + if not data: + if fail: + raise KeyError(domain) + + return False if not data['followid'] or not followid or data['followid'] == followid: del self['relay-list'][data['domain']] diff --git a/relay/manage.py b/relay/manage.py index 838e2f9..f1c8fb4 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -4,6 +4,8 @@ import click import logging import platform +from urllib.parse import urlparse + from . import misc, __version__ from .application import Application from .config import relay_software_names @@ -54,20 +56,27 @@ def cli_inbox_follow(actor): return click.echo(f'Error: Refusing to follow banned actor: {actor}') if not actor.startswith('http'): + domain = actor actor = f'https://{actor}/actor' - if app.database.get_inbox(actor): - return click.echo(f'Error: Already following actor: {actor}') + else: + domain = urlparse(actor).hostname - actor_data = asyncio.run(misc.request(actor, sign_headers=True)) - if not actor_data: - return click.echo(f'Error: Failed to fetch actor: {actor}') + try: + inbox_data = app.database['relay-list'][domain] + inbox = inbox_data['inbox'] - app.database.add_inbox(actor_data.shared_inbox) - app.database.save() + except KeyError: + actor_data = asyncio.run(misc.request(actor)) + inbox = actor_data.shared_inbox - asyncio.run(misc.follow_remote_actor(actor)) + message = misc.Message.new_follow( + host = app.config.host, + actor = actor.id + ) + + asyncio.run(misc.request(inbox, message)) click.echo(f'Sent follow message to actor: {actor}') @@ -77,14 +86,36 @@ def cli_inbox_unfollow(actor): 'Unfollow an actor (Relay must be running)' if not actor.startswith('http'): + domain = actor actor = f'https://{actor}/actor' - if app.database.del_inbox(actor): - app.database.save() - asyncio.run(misc.unfollow_remote_actor(actor)) - return click.echo(f'Sent unfollow message to: {actor}') + else: + domain = urlparse(actor).hostname - return click.echo(f'Error: Not following actor: {actor}') + try: + inbox_data = app.database['relay-list'][domain] + inbox = inbox_data['inbox'] + message = misc.Message.new_unfollow( + host = app.config.host, + actor = actor, + follow = inbox_data['followid'] + ) + + except KeyError: + actor_data = asyncio.run(misc.request(actor)) + inbox = actor_data.shared_inbox + message = misc.Message.new_unfollow( + host = app.config.host, + actor = actor, + follow = { + 'type': 'Follow', + 'object': actor, + 'actor': f'https://{app.config.host}/actor' + } + ) + + asyncio.run(misc.request(inbox, message)) + click.echo(f'Sent unfollow message to: {actor}') @cli_inbox.command('add') diff --git a/relay/misc.py b/relay/misc.py index 68f36dc..5d2f849 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -59,10 +59,10 @@ def create_signature_header(headers): sigstring = build_signing_string(headers, used_headers) sig = { - 'keyId': app['config'].keyid, + 'keyId': app.config.keyid, 'algorithm': 'rsa-sha256', 'headers': ' '.join(used_headers), - 'signature': sign_signing_string(sigstring, app['database'].PRIVKEY) + 'signature': sign_signing_string(sigstring, app.database.PRIVKEY) } chunks = ['{}="{}"'.format(k, v) for k, v in sig.items()] @@ -70,22 +70,20 @@ def create_signature_header(headers): def distill_inboxes(actor, object_id): - database = app['database'] - - for inbox in database.inboxes: + for inbox in app.database.inboxes: if inbox != actor.shared_inbox and urlparse(inbox).hostname != urlparse(object_id).hostname: yield inbox def generate_body_digest(body): - bodyhash = app['cache'].digests.get(body) + bodyhash = app.cache.digests.get(body) if bodyhash: return bodyhash h = SHA256.new(body.encode('utf-8')) bodyhash = base64.b64encode(h.digest()).decode('utf-8') - app['cache'].digests[body] = bodyhash + app.cache.digests[body] = bodyhash return bodyhash @@ -155,61 +153,11 @@ async def fetch_nodeinfo(domain): return False -## todo: remove follow_remote_actor and unfollow_remote_actor -async def follow_remote_actor(actor_uri): - config = app['config'] - - actor = await request(actor_uri) - - if not actor: - logging.error(f'failed to fetch actor at: {actor_uri}') - return - - message = { - "@context": "https://www.w3.org/ns/activitystreams", - "type": "Follow", - "to": [actor['id']], - "object": actor['id'], - "id": f"https://{config.host}/activities/{uuid4()}", - "actor": f"https://{config.host}/actor" - } - - logging.verbose(f'sending follow request: {actor_uri}') - await request(actor.shared_inbox, message) - - -async def unfollow_remote_actor(actor_uri): - config = app['config'] - - actor = await request(actor_uri) - - if not actor: - logging.error(f'failed to fetch actor: {actor_uri}') - return - - message = { - "@context": "https://www.w3.org/ns/activitystreams", - "type": "Undo", - "to": [actor_uri], - "object": { - "type": "Follow", - "object": actor_uri, - "actor": actor_uri, - "id": f"https://{config.host}/activities/{uuid4()}" - }, - "id": f"https://{config.host}/activities/{uuid4()}", - "actor": f"https://{config.host}/actor" - } - - logging.verbose(f'sending unfollow request to inbox: {actor.shared_inbox}') - await request(actor.shared_inbox, message) - - async def request(uri, data=None, force=False, sign_headers=True, activity=True): ## If a get request and not force, try to use the cache first if not data and not force: try: - return app['cache'].json[uri] + return app.cache.json[uri] except KeyError: pass @@ -255,7 +203,7 @@ async def request(uri, data=None, force=False, sign_headers=True, activity=True) else: logging.verbose(f'Sending GET request to url: {uri}') - async with ClientSession(trace_configs=http_debug()) as session, app['semaphore']: + async with ClientSession(trace_configs=http_debug()) as session, app.semaphore: async with session.request(method, uri, headers=headers, data=data) as resp: ## aiohttp has been known to leak if the response hasn't been read, ## so we're just gonna read the request no matter what @@ -283,7 +231,7 @@ async def request(uri, data=None, force=False, sign_headers=True, activity=True) logging.debug(f'{uri} >> resp {resp_data}') - app['cache'].json[uri] = resp_data + app.cache.json[uri] = resp_data return resp_data except JSONDecodeError: