mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2025-04-20 01:26:43 +00:00
use async for cli commands that make http requests
This commit is contained in:
parent
c445b54a91
commit
2bc8633d54
5 changed files with 93 additions and 96 deletions
|
@ -1,6 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import click
|
import click
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
|
@ -42,6 +44,9 @@ def cli(config: File | None) -> None:
|
||||||
|
|
||||||
def pass_state(func: Callable[Concatenate[State, P], R]) -> Callable[P, R]:
|
def pass_state(func: Callable[Concatenate[State, P], R]) -> Callable[P, R]:
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
if inspect.iscoroutinefunction(func):
|
||||||
|
return asyncio.run(func(State.default(), *args, **kwargs)) # type: ignore[no-any-return]
|
||||||
|
|
||||||
return func(State.default(), *args, **kwargs)
|
return func(State.default(), *args, **kwargs)
|
||||||
|
|
||||||
return update_wrapper(wrapper, func)
|
return update_wrapper(wrapper, func)
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
import asyncio
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from . import cli, pass_state
|
from . import cli, pass_state
|
||||||
|
|
||||||
from .. import http_client as http
|
|
||||||
from ..database.schema import Instance
|
from ..database.schema import Instance
|
||||||
from ..misc import ACTOR_FORMATS, Message
|
from ..misc import ACTOR_FORMATS, Message
|
||||||
from ..state import State
|
from ..state import State
|
||||||
|
@ -31,82 +29,86 @@ def cli_inbox_list(state: State) -> None:
|
||||||
@cli_inbox.command("follow")
|
@cli_inbox.command("follow")
|
||||||
@click.argument("actor")
|
@click.argument("actor")
|
||||||
@pass_state
|
@pass_state
|
||||||
def cli_inbox_follow(state: State, actor: str) -> None:
|
async def cli_inbox_follow(state: State, actor: str) -> None:
|
||||||
"Follow an actor (Relay must be running)"
|
"Follow an actor (Relay must be running)"
|
||||||
|
|
||||||
instance: Instance | None = None
|
instance: Instance | None = None
|
||||||
|
|
||||||
with state.database.session() as conn:
|
async with state.client:
|
||||||
if conn.get_domain_ban(actor):
|
with state.database.session() as conn:
|
||||||
click.echo(f"Error: Refusing to follow banned actor: {actor}")
|
if conn.get_domain_ban(actor):
|
||||||
return
|
click.echo(f"Error: Refusing to follow banned actor: {actor}")
|
||||||
|
|
||||||
if (instance := conn.get_inbox(actor)) is not None:
|
|
||||||
inbox = instance.inbox
|
|
||||||
|
|
||||||
else:
|
|
||||||
if not actor.startswith("http"):
|
|
||||||
actor = f"https://{actor}/actor"
|
|
||||||
|
|
||||||
if (actor_data := asyncio.run(http.get(state, actor, sign_headers = True))) is None:
|
|
||||||
click.echo(f"Failed to fetch actor: {actor}")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
inbox = actor_data.shared_inbox
|
if (instance := conn.get_inbox(actor)) is not None:
|
||||||
|
inbox = instance.inbox
|
||||||
|
|
||||||
message = Message.new_follow(
|
else:
|
||||||
host = state.config.domain,
|
if not actor.startswith("http"):
|
||||||
actor = actor
|
actor = f"https://{actor}/actor"
|
||||||
)
|
|
||||||
|
|
||||||
asyncio.run(http.post(state, inbox, message, instance))
|
actor_data = await state.client.get(actor, cls = Message, sign_headers = True)
|
||||||
click.echo(f"Sent follow message to actor: {actor}")
|
|
||||||
|
if not actor_data:
|
||||||
|
click.echo(f"Failed to fetch actor: {actor}")
|
||||||
|
return
|
||||||
|
|
||||||
|
inbox = actor_data.shared_inbox
|
||||||
|
|
||||||
|
message = Message.new_follow(
|
||||||
|
host = state.config.domain,
|
||||||
|
actor = actor
|
||||||
|
)
|
||||||
|
|
||||||
|
await state.client.post(inbox, message, instance)
|
||||||
|
click.echo(f"Sent follow message to actor: {actor}")
|
||||||
|
|
||||||
|
|
||||||
@cli_inbox.command("unfollow")
|
@cli_inbox.command("unfollow")
|
||||||
@click.argument("actor")
|
@click.argument("actor")
|
||||||
@pass_state
|
@pass_state
|
||||||
def cli_inbox_unfollow(state: State, actor: str) -> None:
|
async def cli_inbox_unfollow(state: State, actor: str) -> None:
|
||||||
"Unfollow an actor (Relay must be running)"
|
"Unfollow an actor (Relay must be running)"
|
||||||
|
|
||||||
instance: Instance | None = None
|
instance: Instance | None = None
|
||||||
|
|
||||||
with state.database.session() as conn:
|
async with state.client:
|
||||||
if conn.get_domain_ban(actor):
|
with state.database.session() as conn:
|
||||||
click.echo(f"Error: Refusing to follow banned actor: {actor}")
|
if conn.get_domain_ban(actor):
|
||||||
return
|
click.echo(f"Error: Refusing to follow banned actor: {actor}")
|
||||||
|
|
||||||
if (instance := conn.get_inbox(actor)):
|
|
||||||
inbox = instance.inbox
|
|
||||||
message = Message.new_unfollow(
|
|
||||||
host = state.config.domain,
|
|
||||||
actor = actor,
|
|
||||||
follow = instance.followid
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if not actor.startswith("http"):
|
|
||||||
actor = f"https://{actor}/actor"
|
|
||||||
|
|
||||||
actor_data = asyncio.run(http.get(state, actor, sign_headers = True))
|
|
||||||
|
|
||||||
if not actor_data:
|
|
||||||
click.echo("Failed to fetch actor")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
inbox = actor_data.shared_inbox
|
if (instance := conn.get_inbox(actor)):
|
||||||
message = Message.new_unfollow(
|
inbox = instance.inbox
|
||||||
host = state.config.domain,
|
message = Message.new_unfollow(
|
||||||
actor = actor,
|
host = state.config.domain,
|
||||||
follow = {
|
actor = actor,
|
||||||
"type": "Follow",
|
follow = instance.followid
|
||||||
"object": actor,
|
)
|
||||||
"actor": f"https://{state.config.domain}/actor"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
asyncio.run(http.post(state, inbox, message, instance))
|
else:
|
||||||
click.echo(f"Sent unfollow message to: {actor}")
|
if not actor.startswith("http"):
|
||||||
|
actor = f"https://{actor}/actor"
|
||||||
|
|
||||||
|
actor_data = await state.client.get(actor, cls = Message, sign_headers = True)
|
||||||
|
|
||||||
|
if not actor_data:
|
||||||
|
click.echo("Failed to fetch actor")
|
||||||
|
return
|
||||||
|
|
||||||
|
inbox = actor_data.shared_inbox
|
||||||
|
message = Message.new_unfollow(
|
||||||
|
host = state.config.domain,
|
||||||
|
actor = actor,
|
||||||
|
follow = {
|
||||||
|
"type": "Follow",
|
||||||
|
"object": actor,
|
||||||
|
"actor": f"https://{state.config.domain}/actor"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await state.client.post(inbox, message, instance)
|
||||||
|
click.echo(f"Sent unfollow message to: {actor}")
|
||||||
|
|
||||||
|
|
||||||
@cli_inbox.command("add")
|
@cli_inbox.command("add")
|
||||||
|
@ -115,7 +117,7 @@ def cli_inbox_unfollow(state: State, actor: str) -> None:
|
||||||
@click.option("--followid", "-f", help = "Url for the follow activity")
|
@click.option("--followid", "-f", help = "Url for the follow activity")
|
||||||
@click.option("--software", "-s", help = "Nodeinfo software name of the instance")
|
@click.option("--software", "-s", help = "Nodeinfo software name of the instance")
|
||||||
@pass_state
|
@pass_state
|
||||||
def cli_inbox_add(
|
async def cli_inbox_add(
|
||||||
state: State,
|
state: State,
|
||||||
inbox: str,
|
inbox: str,
|
||||||
actor: str | None = None,
|
actor: str | None = None,
|
||||||
|
@ -131,8 +133,9 @@ def cli_inbox_add(
|
||||||
domain = urlparse(inbox).netloc
|
domain = urlparse(inbox).netloc
|
||||||
|
|
||||||
if not software:
|
if not software:
|
||||||
if (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, domain))):
|
async with state.client:
|
||||||
software = nodeinfo.sw_name
|
if (nodeinfo := await state.client.fetch_nodeinfo(domain)):
|
||||||
|
software = nodeinfo.sw_name
|
||||||
|
|
||||||
if not actor and software:
|
if not actor and software:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
import asyncio
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from . import cli, pass_state
|
from . import cli, pass_state
|
||||||
|
|
||||||
from .. import http_client as http
|
|
||||||
from ..misc import Message
|
from ..misc import Message
|
||||||
from ..state import State
|
from ..state import State
|
||||||
|
|
||||||
|
@ -29,7 +27,7 @@ def cli_request_list(state: State) -> None:
|
||||||
@cli_request.command("accept")
|
@cli_request.command("accept")
|
||||||
@click.argument("domain")
|
@click.argument("domain")
|
||||||
@pass_state
|
@pass_state
|
||||||
def cli_request_accept(state: State, domain: str) -> None:
|
async def cli_request_accept(state: State, domain: str) -> None:
|
||||||
"Accept a follow request"
|
"Accept a follow request"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -40,28 +38,29 @@ def cli_request_accept(state: State, domain: str) -> None:
|
||||||
click.echo("Request not found")
|
click.echo("Request not found")
|
||||||
return
|
return
|
||||||
|
|
||||||
message = Message.new_response(
|
response = Message.new_response(
|
||||||
host = state.config.domain,
|
host = state.config.domain,
|
||||||
actor = instance.actor,
|
actor = instance.actor,
|
||||||
followid = instance.followid,
|
followid = instance.followid,
|
||||||
accept = True
|
accept = True
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(state, instance.inbox, message, instance))
|
async with state.client:
|
||||||
|
await state.client.post(instance.inbox, response, instance)
|
||||||
|
|
||||||
if instance.software != "mastodon":
|
if instance.software != "mastodon":
|
||||||
message = Message.new_follow(
|
follow = Message.new_follow(
|
||||||
host = state.config.domain,
|
host = state.config.domain,
|
||||||
actor = instance.actor
|
actor = instance.actor
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(state, instance.inbox, message, instance))
|
await state.client.post(instance.inbox, follow, instance)
|
||||||
|
|
||||||
|
|
||||||
@cli_request.command("deny")
|
@cli_request.command("deny")
|
||||||
@click.argument("domain")
|
@click.argument("domain")
|
||||||
@pass_state
|
@pass_state
|
||||||
def cli_request_deny(state: State, domain: str) -> None:
|
async def cli_request_deny(state: State, domain: str) -> None:
|
||||||
"Accept a follow request"
|
"Accept a follow request"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -79,4 +78,5 @@ def cli_request_deny(state: State, domain: str) -> None:
|
||||||
accept = False
|
accept = False
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(state, instance.inbox, response, instance))
|
async with state.client:
|
||||||
|
await state.client.post(instance.inbox, response, instance)
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
import asyncio
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from . import cli, pass_state
|
from . import cli, pass_state
|
||||||
|
|
||||||
from .. import http_client as http
|
|
||||||
from ..misc import RELAY_SOFTWARE
|
from ..misc import RELAY_SOFTWARE
|
||||||
from ..state import State
|
from ..state import State
|
||||||
|
|
||||||
|
@ -39,7 +37,7 @@ def cli_software_list(state: State) -> None:
|
||||||
help = "Treat NAME like a domain and try to fetch the software name from nodeinfo"
|
help = "Treat NAME like a domain and try to fetch the software name from nodeinfo"
|
||||||
)
|
)
|
||||||
@pass_state
|
@pass_state
|
||||||
def cli_software_ban(state: State,
|
async def cli_software_ban(state: State,
|
||||||
name: str,
|
name: str,
|
||||||
reason: str,
|
reason: str,
|
||||||
note: str,
|
note: str,
|
||||||
|
@ -59,7 +57,10 @@ def cli_software_ban(state: State,
|
||||||
return
|
return
|
||||||
|
|
||||||
if fetch_nodeinfo:
|
if fetch_nodeinfo:
|
||||||
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, name))):
|
async with state.client:
|
||||||
|
nodeinfo = await state.client.fetch_nodeinfo(name)
|
||||||
|
|
||||||
|
if not nodeinfo:
|
||||||
click.echo(f"Failed to fetch software name from domain: {name}")
|
click.echo(f"Failed to fetch software name from domain: {name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -86,7 +87,7 @@ def cli_software_ban(state: State,
|
||||||
help = "Treat NAME like a domain and try to fetch the software name from nodeinfo"
|
help = "Treat NAME like a domain and try to fetch the software name from nodeinfo"
|
||||||
)
|
)
|
||||||
@pass_state
|
@pass_state
|
||||||
def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None:
|
async def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None:
|
||||||
"Ban software. Use RELAYS for NAME to unban relays"
|
"Ban software. Use RELAYS for NAME to unban relays"
|
||||||
|
|
||||||
with state.database.session() as conn:
|
with state.database.session() as conn:
|
||||||
|
@ -99,7 +100,10 @@ def cli_software_unban(state: State, name: str, fetch_nodeinfo: bool) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if fetch_nodeinfo:
|
if fetch_nodeinfo:
|
||||||
if not (nodeinfo := asyncio.run(http.fetch_nodeinfo(state, name))):
|
async with state.client:
|
||||||
|
nodeinfo = await state.client.fetch_nodeinfo(name)
|
||||||
|
|
||||||
|
if not nodeinfo:
|
||||||
click.echo(f"Failed to fetch software name from domain: {name}")
|
click.echo(f"Failed to fetch software name from domain: {name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -167,7 +167,7 @@ class HttpClient:
|
||||||
|
|
||||||
if cls is not None:
|
if cls is not None:
|
||||||
if data is None:
|
if data is None:
|
||||||
# this shouldn"t actually get raised, but keeping just in case
|
# this shouldn't actually get raised, but keeping just in case
|
||||||
raise EmptyBodyError(f"GET {url}")
|
raise EmptyBodyError(f"GET {url}")
|
||||||
|
|
||||||
return cls.parse(data)
|
return cls.parse(data)
|
||||||
|
@ -237,18 +237,3 @@ class HttpClient:
|
||||||
raise ValueError(f"Failed to fetch nodeinfo url for {domain}")
|
raise ValueError(f"Failed to fetch nodeinfo url for {domain}")
|
||||||
|
|
||||||
return await self.get(nodeinfo_url, False, Nodeinfo, force)
|
return await self.get(nodeinfo_url, False, Nodeinfo, force)
|
||||||
|
|
||||||
|
|
||||||
async def get(state: State, *args: Any, **kwargs: Any) -> Any:
|
|
||||||
async with HttpClient(state) as client:
|
|
||||||
return await client.get(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
async def post(state: State, *args: Any, **kwargs: Any) -> None:
|
|
||||||
async with HttpClient(state) as client:
|
|
||||||
return await client.post(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_nodeinfo(state: State, *args: Any, **kwargs: Any) -> Nodeinfo | None:
|
|
||||||
async with HttpClient(state) as client:
|
|
||||||
return await client.fetch_nodeinfo(*args, **kwargs)
|
|
||||||
|
|
Loading…
Reference in a new issue