use async for cli commands that make http requests

This commit is contained in:
Izalia Mae 2025-02-19 08:53:36 -05:00
parent c445b54a91
commit 2bc8633d54
5 changed files with 93 additions and 96 deletions

View file

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import click import click
import inspect
import json import json
import multiprocessing import multiprocessing
@ -42,6 +44,9 @@ def cli(config: File | None) -> None:
def pass_state(func: Callable[Concatenate[State, P], R]) -> Callable[P, R]: def pass_state(func: Callable[Concatenate[State, P], R]) -> Callable[P, R]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if inspect.iscoroutinefunction(func):
return asyncio.run(func(State.default(), *args, **kwargs)) # type: ignore[no-any-return]
return func(State.default(), *args, **kwargs) return func(State.default(), *args, **kwargs)
return update_wrapper(wrapper, func) return update_wrapper(wrapper, func)

View file

@ -1,11 +1,9 @@
import asyncio
import click import click
from urllib.parse import urlparse from urllib.parse import urlparse
from . import cli, pass_state from . import cli, pass_state
from .. import http_client as http
from ..database.schema import Instance from ..database.schema import Instance
from ..misc import ACTOR_FORMATS, Message from ..misc import ACTOR_FORMATS, Message
from ..state import State from ..state import State
@ -31,11 +29,12 @@ def cli_inbox_list(state: State) -> None:
@cli_inbox.command("follow") @cli_inbox.command("follow")
@click.argument("actor") @click.argument("actor")
@pass_state @pass_state
def cli_inbox_follow(state: State, actor: str) -> None: async def cli_inbox_follow(state: State, actor: str) -> None:
"Follow an actor (Relay must be running)" "Follow an actor (Relay must be running)"
instance: Instance | None = None instance: Instance | None = None
async with state.client:
with state.database.session() as conn: with state.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f"Error: Refusing to follow banned actor: {actor}") click.echo(f"Error: Refusing to follow banned actor: {actor}")
@ -48,7 +47,9 @@ def cli_inbox_follow(state: State, actor: str) -> None:
if not actor.startswith("http"): if not actor.startswith("http"):
actor = f"https://{actor}/actor" actor = f"https://{actor}/actor"
if (actor_data := asyncio.run(http.get(state, actor, sign_headers = True))) is None: actor_data = await state.client.get(actor, cls = Message, sign_headers = True)
if not actor_data:
click.echo(f"Failed to fetch actor: {actor}") click.echo(f"Failed to fetch actor: {actor}")
return return
@ -59,18 +60,19 @@ def cli_inbox_follow(state: State, actor: str) -> None:
actor = actor actor = actor
) )
asyncio.run(http.post(state, inbox, message, instance)) await state.client.post(inbox, message, instance)
click.echo(f"Sent follow message to actor: {actor}") click.echo(f"Sent follow message to actor: {actor}")
@cli_inbox.command("unfollow") @cli_inbox.command("unfollow")
@click.argument("actor") @click.argument("actor")
@pass_state @pass_state
def cli_inbox_unfollow(state: State, actor: str) -> None: async def cli_inbox_unfollow(state: State, actor: str) -> None:
"Unfollow an actor (Relay must be running)" "Unfollow an actor (Relay must be running)"
instance: Instance | None = None instance: Instance | None = None
async with state.client:
with state.database.session() as conn: with state.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f"Error: Refusing to follow banned actor: {actor}") click.echo(f"Error: Refusing to follow banned actor: {actor}")
@ -88,7 +90,7 @@ def cli_inbox_unfollow(state: State, actor: str) -> None:
if not actor.startswith("http"): if not actor.startswith("http"):
actor = f"https://{actor}/actor" actor = f"https://{actor}/actor"
actor_data = asyncio.run(http.get(state, actor, sign_headers = True)) actor_data = await state.client.get(actor, cls = Message, sign_headers = True)
if not actor_data: if not actor_data:
click.echo("Failed to fetch actor") click.echo("Failed to fetch actor")
@ -105,7 +107,7 @@ def cli_inbox_unfollow(state: State, actor: str) -> None:
} }
) )
asyncio.run(http.post(state, inbox, message, instance)) await state.client.post(inbox, message, instance)
click.echo(f"Sent unfollow message to: {actor}") click.echo(f"Sent unfollow message to: {actor}")
@ -115,7 +117,7 @@ def cli_inbox_unfollow(state: State, actor: str) -> None:
@click.option("--followid", "-f", help = "Url for the follow activity") @click.option("--followid", "-f", help = "Url for the follow activity")
@click.option("--software", "-s", help = "Nodeinfo software name of the instance") @click.option("--software", "-s", help = "Nodeinfo software name of the instance")
@pass_state @pass_state
def cli_inbox_add( async def cli_inbox_add(
state: State, state: State,
inbox: str, inbox: str,
actor: str | None = None, actor: str | None = None,
@ -131,7 +133,8 @@ def cli_inbox_add(
domain = urlparse(inbox).netloc domain = urlparse(inbox).netloc
if not software: if not software:
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, domain))): async with state.client:
if (nodeinfo := await state.client.fetch_nodeinfo(domain)):
software = nodeinfo.sw_name software = nodeinfo.sw_name
if not actor and software: if not actor and software:

View file

@ -1,9 +1,7 @@
import asyncio
import click import click
from . import cli, pass_state from . import cli, pass_state
from .. import http_client as http
from ..misc import Message from ..misc import Message
from ..state import State from ..state import State
@ -29,7 +27,7 @@ def cli_request_list(state: State) -> None:
@cli_request.command("accept") @cli_request.command("accept")
@click.argument("domain") @click.argument("domain")
@pass_state @pass_state
def cli_request_accept(state: State, domain: str) -> None: async def cli_request_accept(state: State, domain: str) -> None:
"Accept a follow request" "Accept a follow request"
try: try:
@ -40,28 +38,29 @@ def cli_request_accept(state: State, domain: str) -> None:
click.echo("Request not found") click.echo("Request not found")
return return
message = Message.new_response( response = Message.new_response(
host = state.config.domain, host = state.config.domain,
actor = instance.actor, actor = instance.actor,
followid = instance.followid, followid = instance.followid,
accept = True accept = True
) )
asyncio.run(http.post(state, instance.inbox, message, instance)) async with state.client:
await state.client.post(instance.inbox, response, instance)
if instance.software != "mastodon": if instance.software != "mastodon":
message = Message.new_follow( follow = Message.new_follow(
host = state.config.domain, host = state.config.domain,
actor = instance.actor actor = instance.actor
) )
asyncio.run(http.post(state, instance.inbox, message, instance)) await state.client.post(instance.inbox, follow, instance)
@cli_request.command("deny") @cli_request.command("deny")
@click.argument("domain") @click.argument("domain")
@pass_state @pass_state
def cli_request_deny(state: State, domain: str) -> None: async def cli_request_deny(state: State, domain: str) -> None:
"Accept a follow request" "Accept a follow request"
try: try:
@ -79,4 +78,5 @@ def cli_request_deny(state: State, domain: str) -> None:
accept = False accept = False
) )
asyncio.run(http.post(state, instance.inbox, response, instance)) async with state.client:
await state.client.post(instance.inbox, response, instance)

View file

@ -1,9 +1,7 @@
import asyncio
import click import click
from . import cli, pass_state from . import cli, pass_state
from .. import http_client as http
from ..misc import RELAY_SOFTWARE from ..misc import RELAY_SOFTWARE
from ..state import State from ..state import State
@ -39,7 +37,7 @@ def cli_software_list(state: State) -> None:
help = "Treat NAME like a domain and try to fetch the software name from nodeinfo" help = "Treat NAME like a domain and try to fetch the software name from nodeinfo"
) )
@pass_state @pass_state
def cli_software_ban(state: State, async def cli_software_ban(state: State,
name: str, name: str,
reason: str, reason: str,
note: str, note: str,
@ -59,7 +57,10 @@ def cli_software_ban(state: State,
return return
if fetch_nodeinfo: if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, name))): async with state.client:
nodeinfo = await state.client.fetch_nodeinfo(name)
if not nodeinfo:
click.echo(f"Failed to fetch software name from domain: {name}") click.echo(f"Failed to fetch software name from domain: {name}")
return return
@ -86,7 +87,7 @@ def cli_software_ban(state: State,
help = "Treat NAME like a domain and try to fetch the software name from nodeinfo" help = "Treat NAME like a domain and try to fetch the software name from nodeinfo"
) )
@pass_state @pass_state
def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None: async def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None:
"Ban software. Use RELAYS for NAME to unban relays" "Ban software. Use RELAYS for NAME to unban relays"
with state.database.session() as conn: with state.database.session() as conn:
@ -99,7 +100,10 @@ def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None:
return return
if fetch_nodeinfo: if fetch_nodeinfo:
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, name))): async with state.client:
nodeinfo = await state.client.fetch_nodeinfo(name)
if not nodeinfo:
click.echo(f"Failed to fetch software name from domain: {name}") click.echo(f"Failed to fetch software name from domain: {name}")
return return

View file

@ -167,7 +167,7 @@ class HttpClient:
if cls is not None: if cls is not None:
if data is None: if data is None:
# this shouldn"t actually get raised, but keeping just in case # this shouldn't actually get raised, but keeping just in case
raise EmptyBodyError(f"GET {url}") raise EmptyBodyError(f"GET {url}")
return cls.parse(data) return cls.parse(data)
@ -237,18 +237,3 @@ class HttpClient:
raise ValueError(f"Failed to fetch nodeinfo url for {domain}") raise ValueError(f"Failed to fetch nodeinfo url for {domain}")
return await self.get(nodeinfo_url, False, Nodeinfo, force) return await self.get(nodeinfo_url, False, Nodeinfo, force)
async def get(state: State, *args: Any, **kwargs: Any) -> Any:
async with HttpClient(state) as client:
return await client.get(*args, **kwargs)
async def post(state: State, *args: Any, **kwargs: Any) -> None:
async with HttpClient(state) as client:
return await client.post(*args, **kwargs)
async def fetch_nodeinfo(state: State, *args: Any, **kwargs: Any) -> Nodeinfo | None:
async with HttpClient(state) as client:
return await client.fetch_nodeinfo(*args, **kwargs)