mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2025-04-19 17:16:42 +00:00
use async for cli commands that make http requests
This commit is contained in:
parent
c445b54a91
commit
2bc8633d54
5 changed files with 93 additions and 96 deletions
|
@ -1,6 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import click
|
||||
import inspect
|
||||
import json
|
||||
import multiprocessing
|
||||
|
||||
|
@ -42,6 +44,9 @@ def cli(config: File | None) -> None:
|
|||
|
||||
def pass_state(func: Callable[Concatenate[State, P], R]) -> Callable[P, R]:
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return asyncio.run(func(State.default(), *args, **kwargs)) # type: ignore[no-any-return]
|
||||
|
||||
return func(State.default(), *args, **kwargs)
|
||||
|
||||
return update_wrapper(wrapper, func)
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
import asyncio
|
||||
import click
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import cli, pass_state
|
||||
|
||||
from .. import http_client as http
|
||||
from ..database.schema import Instance
|
||||
from ..misc import ACTOR_FORMATS, Message
|
||||
from ..state import State
|
||||
|
@ -31,82 +29,86 @@ def cli_inbox_list(state: State) -> None:
|
|||
@cli_inbox.command("follow")
|
||||
@click.argument("actor")
|
||||
@pass_state
|
||||
def cli_inbox_follow(state: State, actor: str) -> None:
|
||||
async def cli_inbox_follow(state: State, actor: str) -> None:
|
||||
"Follow an actor (Relay must be running)"
|
||||
|
||||
instance: Instance | None = None
|
||||
|
||||
with state.database.session() as conn:
|
||||
if conn.get_domain_ban(actor):
|
||||
click.echo(f"Error: Refusing to follow banned actor: {actor}")
|
||||
return
|
||||
|
||||
if (instance := conn.get_inbox(actor)) is not None:
|
||||
inbox = instance.inbox
|
||||
|
||||
else:
|
||||
if not actor.startswith("http"):
|
||||
actor = f"https://{actor}/actor"
|
||||
|
||||
if (actor_data := asyncio.run(http.get(state, actor, sign_headers = True))) is None:
|
||||
click.echo(f"Failed to fetch actor: {actor}")
|
||||
async with state.client:
|
||||
with state.database.session() as conn:
|
||||
if conn.get_domain_ban(actor):
|
||||
click.echo(f"Error: Refusing to follow banned actor: {actor}")
|
||||
return
|
||||
|
||||
inbox = actor_data.shared_inbox
|
||||
if (instance := conn.get_inbox(actor)) is not None:
|
||||
inbox = instance.inbox
|
||||
|
||||
message = Message.new_follow(
|
||||
host = state.config.domain,
|
||||
actor = actor
|
||||
)
|
||||
else:
|
||||
if not actor.startswith("http"):
|
||||
actor = f"https://{actor}/actor"
|
||||
|
||||
asyncio.run(http.post(state, inbox, message, instance))
|
||||
click.echo(f"Sent follow message to actor: {actor}")
|
||||
actor_data = await state.client.get(actor, cls = Message, sign_headers = True)
|
||||
|
||||
if not actor_data:
|
||||
click.echo(f"Failed to fetch actor: {actor}")
|
||||
return
|
||||
|
||||
inbox = actor_data.shared_inbox
|
||||
|
||||
message = Message.new_follow(
|
||||
host = state.config.domain,
|
||||
actor = actor
|
||||
)
|
||||
|
||||
await state.client.post(inbox, message, instance)
|
||||
click.echo(f"Sent follow message to actor: {actor}")
|
||||
|
||||
|
||||
@cli_inbox.command("unfollow")
|
||||
@click.argument("actor")
|
||||
@pass_state
|
||||
def cli_inbox_unfollow(state: State, actor: str) -> None:
|
||||
async def cli_inbox_unfollow(state: State, actor: str) -> None:
|
||||
"Unfollow an actor (Relay must be running)"
|
||||
|
||||
instance: Instance | None = None
|
||||
|
||||
with state.database.session() as conn:
|
||||
if conn.get_domain_ban(actor):
|
||||
click.echo(f"Error: Refusing to follow banned actor: {actor}")
|
||||
return
|
||||
|
||||
if (instance := conn.get_inbox(actor)):
|
||||
inbox = instance.inbox
|
||||
message = Message.new_unfollow(
|
||||
host = state.config.domain,
|
||||
actor = actor,
|
||||
follow = instance.followid
|
||||
)
|
||||
|
||||
else:
|
||||
if not actor.startswith("http"):
|
||||
actor = f"https://{actor}/actor"
|
||||
|
||||
actor_data = asyncio.run(http.get(state, actor, sign_headers = True))
|
||||
|
||||
if not actor_data:
|
||||
click.echo("Failed to fetch actor")
|
||||
async with state.client:
|
||||
with state.database.session() as conn:
|
||||
if conn.get_domain_ban(actor):
|
||||
click.echo(f"Error: Refusing to follow banned actor: {actor}")
|
||||
return
|
||||
|
||||
inbox = actor_data.shared_inbox
|
||||
message = Message.new_unfollow(
|
||||
host = state.config.domain,
|
||||
actor = actor,
|
||||
follow = {
|
||||
"type": "Follow",
|
||||
"object": actor,
|
||||
"actor": f"https://{state.config.domain}/actor"
|
||||
}
|
||||
)
|
||||
if (instance := conn.get_inbox(actor)):
|
||||
inbox = instance.inbox
|
||||
message = Message.new_unfollow(
|
||||
host = state.config.domain,
|
||||
actor = actor,
|
||||
follow = instance.followid
|
||||
)
|
||||
|
||||
asyncio.run(http.post(state, inbox, message, instance))
|
||||
click.echo(f"Sent unfollow message to: {actor}")
|
||||
else:
|
||||
if not actor.startswith("http"):
|
||||
actor = f"https://{actor}/actor"
|
||||
|
||||
actor_data = await state.client.get(actor, cls = Message, sign_headers = True)
|
||||
|
||||
if not actor_data:
|
||||
click.echo("Failed to fetch actor")
|
||||
return
|
||||
|
||||
inbox = actor_data.shared_inbox
|
||||
message = Message.new_unfollow(
|
||||
host = state.config.domain,
|
||||
actor = actor,
|
||||
follow = {
|
||||
"type": "Follow",
|
||||
"object": actor,
|
||||
"actor": f"https://{state.config.domain}/actor"
|
||||
}
|
||||
)
|
||||
|
||||
await state.client.post(inbox, message, instance)
|
||||
click.echo(f"Sent unfollow message to: {actor}")
|
||||
|
||||
|
||||
@cli_inbox.command("add")
|
||||
|
@ -115,7 +117,7 @@ def cli_inbox_unfollow(state: State, actor: str) -> None:
|
|||
@click.option("--followid", "-f", help = "Url for the follow activity")
|
||||
@click.option("--software", "-s", help = "Nodeinfo software name of the instance")
|
||||
@pass_state
|
||||
def cli_inbox_add(
|
||||
async def cli_inbox_add(
|
||||
state: State,
|
||||
inbox: str,
|
||||
actor: str | None = None,
|
||||
|
@ -131,8 +133,9 @@ def cli_inbox_add(
|
|||
domain = urlparse(inbox).netloc
|
||||
|
||||
if not software:
|
||||
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, domain))):
|
||||
software = nodeinfo.sw_name
|
||||
async with state.client:
|
||||
if (nodeinfo := await state.client.fetch_nodeinfo(domain)):
|
||||
software = nodeinfo.sw_name
|
||||
|
||||
if not actor and software:
|
||||
try:
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import asyncio
|
||||
import click
|
||||
|
||||
from . import cli, pass_state
|
||||
|
||||
from .. import http_client as http
|
||||
from ..misc import Message
|
||||
from ..state import State
|
||||
|
||||
|
@ -29,7 +27,7 @@ def cli_request_list(state: State) -> None:
|
|||
@cli_request.command("accept")
|
||||
@click.argument("domain")
|
||||
@pass_state
|
||||
def cli_request_accept(state: State, domain: str) -> None:
|
||||
async def cli_request_accept(state: State, domain: str) -> None:
|
||||
"Accept a follow request"
|
||||
|
||||
try:
|
||||
|
@ -40,28 +38,29 @@ def cli_request_accept(state: State, domain: str) -> None:
|
|||
click.echo("Request not found")
|
||||
return
|
||||
|
||||
message = Message.new_response(
|
||||
response = Message.new_response(
|
||||
host = state.config.domain,
|
||||
actor = instance.actor,
|
||||
followid = instance.followid,
|
||||
accept = True
|
||||
)
|
||||
|
||||
asyncio.run(http.post(state, instance.inbox, message, instance))
|
||||
async with state.client:
|
||||
await state.client.post(instance.inbox, response, instance)
|
||||
|
||||
if instance.software != "mastodon":
|
||||
message = Message.new_follow(
|
||||
host = state.config.domain,
|
||||
actor = instance.actor
|
||||
)
|
||||
if instance.software != "mastodon":
|
||||
follow = Message.new_follow(
|
||||
host = state.config.domain,
|
||||
actor = instance.actor
|
||||
)
|
||||
|
||||
asyncio.run(http.post(state, instance.inbox, message, instance))
|
||||
await state.client.post(instance.inbox, follow, instance)
|
||||
|
||||
|
||||
@cli_request.command("deny")
|
||||
@click.argument("domain")
|
||||
@pass_state
|
||||
def cli_request_deny(state: State, domain: str) -> None:
|
||||
async def cli_request_deny(state: State, domain: str) -> None:
|
||||
"Accept a follow request"
|
||||
|
||||
try:
|
||||
|
@ -79,4 +78,5 @@ def cli_request_deny(state: State, domain: str) -> None:
|
|||
accept = False
|
||||
)
|
||||
|
||||
asyncio.run(http.post(state, instance.inbox, response, instance))
|
||||
async with state.client:
|
||||
await state.client.post(instance.inbox, response, instance)
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import asyncio
|
||||
import click
|
||||
|
||||
from . import cli, pass_state
|
||||
|
||||
from .. import http_client as http
|
||||
from ..misc import RELAY_SOFTWARE
|
||||
from ..state import State
|
||||
|
||||
|
@ -39,7 +37,7 @@ def cli_software_list(state: State) -> None:
|
|||
help = "Treat NAME like a domain and try to fetch the software name from nodeinfo"
|
||||
)
|
||||
@pass_state
|
||||
def cli_software_ban(state: State,
|
||||
async def cli_software_ban(state: State,
|
||||
name: str,
|
||||
reason: str,
|
||||
note: str,
|
||||
|
@ -59,7 +57,10 @@ def cli_software_ban(state: State,
|
|||
return
|
||||
|
||||
if fetch_nodeinfo:
|
||||
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, name))):
|
||||
async with state.client:
|
||||
nodeinfo = await state.client.fetch_nodeinfo(name)
|
||||
|
||||
if not nodeinfo:
|
||||
click.echo(f"Failed to fetch software name from domain: {name}")
|
||||
return
|
||||
|
||||
|
@ -86,7 +87,7 @@ def cli_software_ban(state: State,
|
|||
help = "Treat NAME like a domain and try to fetch the software name from nodeinfo"
|
||||
)
|
||||
@pass_state
|
||||
def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None:
|
||||
async def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None:
|
||||
"Ban software. Use RELAYS for NAME to unban relays"
|
||||
|
||||
with state.database.session() as conn:
|
||||
|
@ -99,7 +100,10 @@ def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None:
|
|||
return
|
||||
|
||||
if fetch_nodeinfo:
|
||||
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, name))):
|
||||
async with state.client:
|
||||
nodeinfo = await state.client.fetch_nodeinfo(name)
|
||||
|
||||
if not nodeinfo:
|
||||
click.echo(f"Failed to fetch software name from domain: {name}")
|
||||
return
|
||||
|
||||
|
|
|
@ -167,7 +167,7 @@ class HttpClient:
|
|||
|
||||
if cls is not None:
|
||||
if data is None:
|
||||
# this shouldn"t actually get raised, but keeping just in case
|
||||
# this shouldn't actually get raised, but keeping just in case
|
||||
raise EmptyBodyError(f"GET {url}")
|
||||
|
||||
return cls.parse(data)
|
||||
|
@ -237,18 +237,3 @@ class HttpClient:
|
|||
raise ValueError(f"Failed to fetch nodeinfo url for {domain}")
|
||||
|
||||
return await self.get(nodeinfo_url, False, Nodeinfo, force)
|
||||
|
||||
|
||||
async def get(state: State, *args: Any, **kwargs: Any) -> Any:
|
||||
async with HttpClient(state) as client:
|
||||
return await client.get(*args, **kwargs)
|
||||
|
||||
|
||||
async def post(state: State, *args: Any, **kwargs: Any) -> None:
|
||||
async with HttpClient(state) as client:
|
||||
return await client.post(*args, **kwargs)
|
||||
|
||||
|
||||
async def fetch_nodeinfo(state: State, *args: Any, **kwargs: Any) -> Nodeinfo | None:
|
||||
async with HttpClient(state) as client:
|
||||
return await client.fetch_nodeinfo(*args, **kwargs)
|
||||
|
|
Loading…
Reference in a new issue