From 2bc8633d5466cf24ab0c744217b663625face9ad Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Wed, 19 Feb 2025 08:53:36 -0500 Subject: [PATCH] use async for cli commands that make http requests --- relay/cli/__init__.py | 5 ++ relay/cli/inbox.py | 125 +++++++++++++++++++------------------- relay/cli/request.py | 26 ++++---- relay/cli/software_ban.py | 16 +++-- relay/http_client.py | 17 +----- 5 files changed, 93 insertions(+), 96 deletions(-) diff --git a/relay/cli/__init__.py b/relay/cli/__init__.py index 1340b01..b3bcab7 100644 --- a/relay/cli/__init__.py +++ b/relay/cli/__init__.py @@ -1,6 +1,8 @@ from __future__ import annotations +import asyncio import click +import inspect import json import multiprocessing @@ -42,6 +44,9 @@ def cli(config: File | None) -> None: def pass_state(func: Callable[Concatenate[State, P], R]) -> Callable[P, R]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if inspect.iscoroutinefunction(func): + return asyncio.run(func(State.default(), *args, **kwargs)) # type: ignore[no-any-return] + return func(State.default(), *args, **kwargs) return update_wrapper(wrapper, func) diff --git a/relay/cli/inbox.py b/relay/cli/inbox.py index 6398534..d5517f2 100644 --- a/relay/cli/inbox.py +++ b/relay/cli/inbox.py @@ -1,11 +1,9 @@ -import asyncio import click from urllib.parse import urlparse from . import cli, pass_state -from .. import http_client as http from ..database.schema import Instance from ..misc import ACTOR_FORMATS, Message from ..state import State @@ -31,82 +29,86 @@ def cli_inbox_list(state: State) -> None: @cli_inbox.command("follow") @click.argument("actor") @pass_state -def cli_inbox_follow(state: State, actor: str) -> None: +async def cli_inbox_follow(state: State, actor: str) -> None: "Follow an actor (Relay must be running)" instance: Instance | None = None - with state.database.session() as conn: - if conn.get_domain_ban(actor): - click.echo(f"Error: Refusing to follow banned actor: {actor}") - return - - if (instance := conn.get_inbox(actor)) is not None: - inbox = instance.inbox - - else: - if not actor.startswith("http"): - actor = f"https://{actor}/actor" - - if (actor_data := asyncio.run(http.get(state, actor, sign_headers = True))) is None: - click.echo(f"Failed to fetch actor: {actor}") + async with state.client: + with state.database.session() as conn: + if conn.get_domain_ban(actor): + click.echo(f"Error: Refusing to follow banned actor: {actor}") return - inbox = actor_data.shared_inbox + if (instance := conn.get_inbox(actor)) is not None: + inbox = instance.inbox - message = Message.new_follow( - host = state.config.domain, - actor = actor - ) + else: + if not actor.startswith("http"): + actor = f"https://{actor}/actor" - asyncio.run(http.post(state, inbox, message, instance)) - click.echo(f"Sent follow message to actor: {actor}") + actor_data = await state.client.get(actor, cls = Message, sign_headers = True) + + if not actor_data: + click.echo(f"Failed to fetch actor: {actor}") + return + + inbox = actor_data.shared_inbox + + message = Message.new_follow( + host = state.config.domain, + actor = actor + ) + + await state.client.post(inbox, message, instance) + click.echo(f"Sent follow message to actor: {actor}") @cli_inbox.command("unfollow") @click.argument("actor") @pass_state -def cli_inbox_unfollow(state: State, actor: str) -> None: +async def cli_inbox_unfollow(state: State, actor: str) -> None: "Unfollow an actor (Relay must be running)" instance: Instance | None = None - with state.database.session() as conn: - if conn.get_domain_ban(actor): - click.echo(f"Error: Refusing to follow banned actor: {actor}") - return - - if (instance := conn.get_inbox(actor)): - inbox = instance.inbox - message = Message.new_unfollow( - host = state.config.domain, - actor = actor, - follow = instance.followid - ) - - else: - if not actor.startswith("http"): - actor = f"https://{actor}/actor" - - actor_data = asyncio.run(http.get(state, actor, sign_headers = True)) - - if not actor_data: - click.echo("Failed to fetch actor") + async with state.client: + with state.database.session() as conn: + if conn.get_domain_ban(actor): + click.echo(f"Error: Refusing to follow banned actor: {actor}") return - inbox = actor_data.shared_inbox - message = Message.new_unfollow( - host = state.config.domain, - actor = actor, - follow = { - "type": "Follow", - "object": actor, - "actor": f"https://{state.config.domain}/actor" - } - ) + if (instance := conn.get_inbox(actor)): + inbox = instance.inbox + message = Message.new_unfollow( + host = state.config.domain, + actor = actor, + follow = instance.followid + ) - asyncio.run(http.post(state, inbox, message, instance)) - click.echo(f"Sent unfollow message to: {actor}") + else: + if not actor.startswith("http"): + actor = f"https://{actor}/actor" + + actor_data = await state.client.get(actor, cls = Message, sign_headers = True) + + if not actor_data: + click.echo("Failed to fetch actor") + return + + inbox = actor_data.shared_inbox + message = Message.new_unfollow( + host = state.config.domain, + actor = actor, + follow = { + "type": "Follow", + "object": actor, + "actor": f"https://{state.config.domain}/actor" + } + ) + + await state.client.post(inbox, message, instance) + click.echo(f"Sent unfollow message to: {actor}") @cli_inbox.command("add") @@ -115,7 +117,7 @@ def cli_inbox_unfollow(state: State, actor: str) -> None: @click.option("--followid", "-f", help = "Url for the follow activity") @click.option("--software", "-s", help = "Nodeinfo software name of the instance") @pass_state -def cli_inbox_add( +async def cli_inbox_add( state: State, inbox: str, actor: str | None = None, @@ -131,8 +133,9 @@ def cli_inbox_add( domain = urlparse(inbox).netloc if not software: - if (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, domain))): - software = nodeinfo.sw_name + async with state.client: + if (nodeinfo := await state.client.fetch_nodeinfo(domain)): + software = nodeinfo.sw_name if not actor and software: try: diff --git a/relay/cli/request.py b/relay/cli/request.py index ed1e6d5..54ab879 100644 --- a/relay/cli/request.py +++ b/relay/cli/request.py @@ -1,9 +1,7 @@ -import asyncio import click from . import cli, pass_state -from .. import http_client as http from ..misc import Message from ..state import State @@ -29,7 +27,7 @@ def cli_request_list(state: State) -> None: @cli_request.command("accept") @click.argument("domain") @pass_state -def cli_request_accept(state: State, domain: str) -> None: +async def cli_request_accept(state: State, domain: str) -> None: "Accept a follow request" try: @@ -40,28 +38,29 @@ def cli_request_accept(state: State, domain: str) -> None: click.echo("Request not found") return - message = Message.new_response( + response = Message.new_response( host = state.config.domain, actor = instance.actor, followid = instance.followid, accept = True ) - asyncio.run(http.post(state, instance.inbox, message, instance)) + async with state.client: + await state.client.post(instance.inbox, response, instance) - if instance.software != "mastodon": - message = Message.new_follow( - host = state.config.domain, - actor = instance.actor - ) + if instance.software != "mastodon": + follow = Message.new_follow( + host = state.config.domain, + actor = instance.actor + ) - asyncio.run(http.post(state, instance.inbox, message, instance)) + await state.client.post(instance.inbox, follow, instance) @cli_request.command("deny") @click.argument("domain") @pass_state -def cli_request_deny(state: State, domain: str) -> None: +async def cli_request_deny(state: State, domain: str) -> None: "Accept a follow request" try: @@ -79,4 +78,5 @@ def cli_request_deny(state: State, domain: str) -> None: accept = False ) - asyncio.run(http.post(state, instance.inbox, response, instance)) + async with state.client: + await state.client.post(instance.inbox, response, instance) diff --git a/relay/cli/software_ban.py b/relay/cli/software_ban.py index 62d1a0f..8fe3f86 100644 --- a/relay/cli/software_ban.py +++ b/relay/cli/software_ban.py @@ -1,9 +1,7 @@ -import asyncio import click from . import cli, pass_state -from .. import http_client as http from ..misc import RELAY_SOFTWARE from ..state import State @@ -39,7 +37,7 @@ def cli_software_list(state: State) -> None: help = "Treat NAME like a domain and try to fetch the software name from nodeinfo" ) @pass_state -def cli_software_ban(state: State, +async def cli_software_ban(state: State, name: str, reason: str, note: str, @@ -59,7 +57,10 @@ def cli_software_ban(state: State, return if fetch_nodeinfo: - if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, name))): + async with state.client: + nodeinfo = await state.client.fetch_nodeinfo(name) + + if not nodeinfo: click.echo(f"Failed to fetch software name from domain: {name}") return @@ -86,7 +87,7 @@ def cli_software_ban(state: State, help = "Treat NAME like a domain and try to fetch the software name from nodeinfo" ) @pass_state -def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None: +async def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None: "Ban software. Use RELAYS for NAME to unban relays" with state.database.session() as conn: @@ -99,7 +100,10 @@ def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None: return if fetch_nodeinfo: - if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, name))): + async with state.client: + nodeinfo = await state.client.fetch_nodeinfo(name) + + if not nodeinfo: click.echo(f"Failed to fetch software name from domain: {name}") return diff --git a/relay/http_client.py b/relay/http_client.py index b658cfb..37e0bd8 100644 --- a/relay/http_client.py +++ b/relay/http_client.py @@ -167,7 +167,7 @@ class HttpClient: if cls is not None: if data is None: - # this shouldn"t actually get raised, but keeping just in case + # this shouldn't actually get raised, but keeping just in case raise EmptyBodyError(f"GET {url}") return cls.parse(data) @@ -237,18 +237,3 @@ class HttpClient: raise ValueError(f"Failed to fetch nodeinfo url for {domain}") return await self.get(nodeinfo_url, False, Nodeinfo, force) - - -async def get(state: State, *args: Any, **kwargs: Any) -> Any: - async with HttpClient(state) as client: - return await client.get(*args, **kwargs) - - -async def post(state: State, *args: Any, **kwargs: Any) -> None: - async with HttpClient(state) as client: - return await client.post(*args, **kwargs) - - -async def fetch_nodeinfo(state: State, *args: Any, **kwargs: Any) -> Nodeinfo | None: - async with HttpClient(state) as client: - return await client.fetch_nodeinfo(*args, **kwargs)