update barkshark-sql to 0.2.0-rc1 and create row classes

This commit is contained in:
Izalia Mae 2024-06-25 16:02:49 -04:00
parent 45b0de26c7
commit bdc7d41d7a
14 changed files with 374 additions and 329 deletions

View file

@ -20,7 +20,7 @@ dependencies = [
"aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0",
"barkshark-lib >= 0.1.3-1",
"barkshark-lib >= 0.2.0-rc1",
"barkshark-sql == 0.1.4-1",
"click >= 8.1.2",
"hiredis == 2.3.2",
@ -104,7 +104,3 @@ implicit_reexport = true
[[tool.mypy.overrides]]
module = "blib"
implicit_reexport = true
[[tool.mypy.overrides]]
module = "bsql"
implicit_reexport = true

View file

@ -1 +1 @@
__version__ = '0.3.2'
__version__ = '0.3.3'

View file

@ -4,30 +4,26 @@ import asyncio
import multiprocessing
import signal
import time
import traceback
from aiohttp import web
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from aiohttp.web import StaticResource
from aiohttp_swagger import setup_swagger
from aputils.signer import Signer
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from bsql import Database, Row
from bsql import Database
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
from mimetypes import guess_type
from pathlib import Path
from queue import Empty
from threading import Event, Thread
from typing import Any
from urllib.parse import urlparse
from . import logger as logging, workers
from .cache import Cache, get_cache
from .config import Config
from .database import Connection, get_database
from .database.schema import Instance
from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource
from .misc import Message, Response, check_open_port, get_resource
from .template import Template
from .views import VIEWS
from .views.api import handle_api_path
@ -142,7 +138,7 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self['workers'].push_message(inbox, message, instance)
@ -286,67 +282,6 @@ class CacheCleanupThread(Thread):
self.running.clear()
class PushWorker(multiprocessing.Process):
def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None:
if Application.DEFAULT is None:
raise RuntimeError('Application not setup yet')
multiprocessing.Process.__init__(self)
self.queue = queue
self.shutdown = multiprocessing.Event()
self.path = Application.DEFAULT.config.path
def stop(self) -> None:
self.shutdown.set()
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
if IS_WINDOWS:
app = Application(self.path)
client = app.client
client.open()
app.database.connect()
app.cache.setup()
else:
client = HttpClient()
client.open()
while not self.shutdown.is_set():
try:
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
asyncio.create_task(client.post(inbox, message, instance))
except Empty:
await asyncio.sleep(0)
except ClientSSLError as e:
logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e))
except (AsyncTimeoutError, ClientConnectionError) as e:
logging.error(
'Failed to connect to %s for message push: %s',
urlparse(inbox).netloc, str(e)
)
# make sure an exception doesn't bring down the worker
except Exception:
traceback.print_exc()
if IS_WINDOWS:
app.database.disconnect()
app.cache.close()
await client.close()
@web.middleware
async def handle_response_headers(
request: web.Request,

View file

@ -4,7 +4,7 @@ import json
import os
from abc import ABC, abstractmethod
from bsql import Database
from bsql import Database, Row
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone
@ -172,7 +172,7 @@ class SqlCache(Cache):
with self._db.session(False) as conn:
with conn.run('get-cache-item', params) as cur:
if not (row := cur.one()):
if not (row := cur.one(Row)):
raise KeyError(f'{namespace}:{key}')
row.pop('id', None)
@ -211,9 +211,11 @@ class SqlCache(Cache):
with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as cur:
row = cur.one()
row.pop('id', None) # type: ignore[union-attr]
return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
if (row := cur.one(Row)) is None:
raise RuntimeError("Cache item not set")
row.pop('id', None)
return Item.from_data(*tuple(row.values()))
def delete(self, namespace: str, key: str) -> None:

View file

@ -16,6 +16,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
'tables': TABLES
}
db: Database[Connection]
if config.db_type == 'sqlite':
db = Database.sqlite(config.sqlite_path, **options)

View file

@ -2,12 +2,13 @@ from __future__ import annotations
from argon2 import PasswordHasher
from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator, Sequence
from collections.abc import Iterator
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from uuid import uuid4
from . import schema
from .config import (
THEMES,
ConfigData
@ -37,14 +38,14 @@ class Connection(SqlConnection):
return get_app()
def distill_inboxes(self, message: Message) -> Iterator[Row]:
def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]:
src_domains = {
message.domain,
urlparse(message.object_id).netloc
}
for instance in self.get_inboxes():
if instance['domain'] not in src_domains:
if instance.domain not in src_domains:
yield instance
@ -52,7 +53,7 @@ class Connection(SqlConnection):
key = key.replace('_', '-')
with self.run('get-config', {'key': key}) as cur:
if not (row := cur.one()):
if (row := cur.one(Row)) is None:
return ConfigData.DEFAULT(key)
data = ConfigData()
@ -61,8 +62,8 @@ class Connection(SqlConnection):
def get_config_all(self) -> ConfigData:
with self.run('get-config-all', None) as cur:
return ConfigData.from_rows(tuple(cur.all()))
rows = tuple(self.run('get-config-all', None).all(schema.Row))
return ConfigData.from_rows(rows)
def put_config(self, key: str, value: Any) -> Any:
@ -99,14 +100,13 @@ class Connection(SqlConnection):
return data.get(key)
def get_inbox(self, value: str) -> Row:
def get_inbox(self, value: str) -> schema.Instance | None:
with self.run('get-inbox', {'value': value}) as cur:
return cur.one() # type: ignore
return cur.one(schema.Instance)
def get_inboxes(self) -> Sequence[Row]:
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur:
return tuple(cur.all())
def get_inboxes(self) -> Iterator[schema.Instance]:
return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance)
def put_inbox(self,
@ -115,7 +115,7 @@ class Connection(SqlConnection):
actor: str | None = None,
followid: str | None = None,
software: str | None = None,
accepted: bool = True) -> Row:
accepted: bool = True) -> schema.Instance:
params: dict[str, Any] = {
'inbox': inbox,
@ -125,7 +125,7 @@ class Connection(SqlConnection):
'accepted': accepted
}
if not self.get_inbox(domain):
if self.get_inbox(domain) is None:
if not inbox:
raise ValueError("Missing inbox")
@ -133,14 +133,20 @@ class Connection(SqlConnection):
params['created'] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert instance: {domain}")
return row
for key, value in tuple(params.items()):
if value is None:
del params[key]
with self.update('inboxes', params, domain = domain) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to update instance: {domain}")
return row
def del_inbox(self, value: str) -> bool:
@ -151,24 +157,23 @@ class Connection(SqlConnection):
return cur.row_count == 1
def get_request(self, domain: str) -> Row:
def get_request(self, domain: str) -> schema.Instance | None:
with self.run('get-request', {'domain': domain}) as cur:
if not (row := cur.one()):
return cur.one(schema.Instance)
def get_requests(self) -> Iterator[schema.Instance]:
return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance)
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
if (instance := self.get_request(domain)) is None:
raise KeyError(domain)
return row
def get_requests(self) -> Sequence[Row]:
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
return tuple(cur.all())
def put_request_response(self, domain: str, accepted: bool) -> Row:
instance = self.get_request(domain)
if not accepted:
self.del_inbox(domain)
if not self.del_inbox(domain):
raise RuntimeError(f'Failed to delete request: {domain}')
return instance
params = {
@ -177,21 +182,28 @@ class Connection(SqlConnection):
}
with self.run('put-inbox-accept', params) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert response for domain: {domain}")
return row
def get_user(self, value: str) -> Row:
def get_user(self, value: str) -> schema.User | None:
with self.run('get-user', {'value': value}) as cur:
return cur.one() # type: ignore
return cur.one(schema.User)
def get_user_by_token(self, code: str) -> Row:
def get_user_by_token(self, code: str) -> schema.User | None:
with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one() # type: ignore
return cur.one(schema.User)
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row:
if self.get_user(username):
def get_users(self) -> Iterator[schema.User]:
return self.execute("SELECT * FROM users").all(schema.User)
def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User:
if self.get_user(username) is not None:
data: dict[str, str | datetime | None] = {}
if password:
@ -204,7 +216,10 @@ class Connection(SqlConnection):
stmt.set_where("username", username)
with self.query(stmt) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to update user: {username}")
return row
if password is None:
raise ValueError('Password cannot be empty')
@ -217,25 +232,36 @@ class Connection(SqlConnection):
}
with self.run('put-user', data) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to insert user: {username}")
return row
def del_user(self, username: str) -> None:
user = self.get_user(username)
if (user := self.get_user(username)) is None:
raise KeyError(username)
with self.run('del-user', {'value': user['username']}):
with self.run('del-user', {'value': user.username}):
pass
with self.run('del-token-user', {'username': user['username']}):
with self.run('del-token-user', {'username': user.username}):
pass
def get_token(self, code: str) -> Row:
def get_token(self, code: str) -> schema.Token | None:
with self.run('get-token', {'code': code}) as cur:
return cur.one() # type: ignore
return cur.one(schema.Token)
def put_token(self, username: str) -> Row:
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
if username is not None:
return self.select('tokens').all(schema.Token)
return self.select('tokens', username = username).all(schema.Token)
def put_token(self, username: str) -> schema.Token:
data = {
'code': uuid4().hex,
'user': username,
@ -243,7 +269,10 @@ class Connection(SqlConnection):
}
with self.run('put-token', data) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.Token)) is None:
raise RuntimeError(f"Failed to insert token for user: {username}")
return row
def del_token(self, code: str) -> None:
@ -251,18 +280,22 @@ class Connection(SqlConnection):
pass
def get_domain_ban(self, domain: str) -> Row:
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
if domain.startswith('http'):
domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur:
return cur.one() # type: ignore
return cur.one(schema.DomainBan)
def get_domain_bans(self) -> Iterator[schema.DomainBan]:
return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan)
def put_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> Row:
note: str | None = None) -> schema.DomainBan:
params = {
'domain': domain,
@ -272,13 +305,16 @@ class Connection(SqlConnection):
}
with self.run('put-domain-ban', params) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to insert domain ban: {domain}")
return row
def update_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> Row:
note: str | None = None) -> schema.DomainBan:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
@ -298,7 +334,10 @@ class Connection(SqlConnection):
if cur.row_count > 1:
raise ValueError('More than one row was modified')
return self.get_domain_ban(domain)
if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to update domain ban: {domain}")
return row
def del_domain_ban(self, domain: str) -> bool:
@ -309,15 +348,19 @@ class Connection(SqlConnection):
return cur.row_count == 1
def get_software_ban(self, name: str) -> Row:
def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
with self.run('get-software-ban', {'name': name}) as cur:
return cur.one() # type: ignore
return cur.one(schema.SoftwareBan)
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
def put_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> Row:
note: str | None = None) -> schema.SoftwareBan:
params = {
'name': name,
@ -327,13 +370,16 @@ class Connection(SqlConnection):
}
with self.run('put-software-ban', params) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to insert software ban: {name}')
return row
def update_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> Row:
note: str | None = None) -> schema.SoftwareBan:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
@ -353,7 +399,10 @@ class Connection(SqlConnection):
if cur.row_count > 1:
raise ValueError('More than one row was modified')
return self.get_software_ban(name)
if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to update software ban: {name}')
return row
def del_software_ban(self, name: str) -> bool:
@ -364,19 +413,26 @@ class Connection(SqlConnection):
return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> Row:
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one() # type: ignore
return cur.one()
def put_domain_whitelist(self, domain: str) -> Row:
def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]:
return self.execute("SELECT * FROM whitelist").all(schema.Whitelist)
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
params = {
'domain': domain,
'created': datetime.now(tz = timezone.utc)
}
with self.run('put-domain-whitelist', params) as cur:
return cur.one() # type: ignore
if (row := cur.one(schema.Whitelist)) is None:
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
return row
def del_domain_whitelist(self, domain: str) -> bool:

View file

@ -1,61 +1,88 @@
from bsql import Column, Table, Tables
from __future__ import annotations
import typing
from bsql import Column, Row, Tables
from collections.abc import Callable
from datetime import datetime
from .config import ConfigData
from .connection import Connection
if typing.TYPE_CHECKING:
from .connection import Connection
VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES: Tables = Tables(
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False),
Column('value', 'text'),
Column('type', 'text', default = 'str')
),
Table(
'inboxes',
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True),
Column('inbox', 'text', unique = True, nullable = False),
Column('followid', 'text'),
Column('software', 'text'),
Column('accepted', 'boolean'),
Column('created', 'timestamp', nullable = False)
),
Table(
'whitelist',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('created', 'timestamp')
),
Table(
'domain_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'software_bans',
Column('name', 'text', primary_key = True, unique = True, nullable = True),
Column('reason', 'text'),
Column('note', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'users',
Column('username', 'text', primary_key = True, unique = True, nullable = False),
Column('hash', 'text', nullable = False),
Column('handle', 'text'),
Column('created', 'timestamp', nullable = False)
),
Table(
'tokens',
Column('code', 'text', primary_key = True, unique = True, nullable = False),
Column('user', 'text', nullable = False),
Column('created', 'timestmap', nullable = False)
)
)
TABLES = Tables()
@TABLES.add_row
class Config(Row):
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
value: Column[str] = Column('value', 'text')
type: Column[str] = Column('type', 'text', default = 'str')
@TABLES.add_row
class Instance(Row):
table_name: str = 'inboxes'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = False)
actor: Column[str] = Column('actor', 'text', unique = True)
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
accepted: Column[datetime] = Column('accepted', 'boolean')
created: Column[datetime] = Column('created', 'timestamp', nullable = False)
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class DomainBan(Row):
table_name: str = 'domain_bans'
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class SoftwareBan(Row):
table_name: str = 'software_bans'
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class User(Row):
table_name: str = 'users'
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class Token(Row):
table_name: str = 'tokens'
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
user: Column[str] = Column('user', 'text', nullable = False)
created: Column[datetime] = Column('created', 'timestamp')
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:

View file

@ -5,11 +5,11 @@ import json
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from blib import JsonBase
from bsql import Row
from typing import TYPE_CHECKING, Any, TypeVar, overload
from . import __version__, logger as logging
from .cache import Cache
from .database.schema import Instance
from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING:
@ -184,12 +184,12 @@ class HttpClient:
return None
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
if not self._session:
raise RuntimeError('Client not open')
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance and instance['software'] in SUPPORTS_HS2019:
if instance is not None and instance.software in SUPPORTS_HS2019:
algorithm = AlgorithmType.HS2019
else:

View file

@ -6,7 +6,6 @@ import click
import json
import os
from bsql import Row
from pathlib import Path
from shutil import copyfile
from typing import Any
@ -17,7 +16,7 @@ from . import http_client as http
from . import logger as logging
from .application import Application
from .compat import RelayConfig, RelayDatabase
from .database import RELAY_SOFTWARE, get_database
from .database import RELAY_SOFTWARE, get_database, schema
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
@ -367,8 +366,8 @@ def cli_user_list(ctx: click.Context) -> None:
click.echo('Users:')
with ctx.obj.database.session() as conn:
for user in conn.execute('SELECT * FROM users'):
click.echo(f'- {user["username"]}')
for row in conn.get_users():
click.echo(f'- {row.username}')
@cli_user.command('create')
@ -379,7 +378,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
'Create a new local user'
with ctx.obj.database.session() as conn:
if conn.get_user(username):
if conn.get_user(username) is not None:
click.echo(f'User already exists: {username}')
return
@ -406,7 +405,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
'Delete a local user'
with ctx.obj.database.session() as conn:
if not conn.get_user(username):
if conn.get_user(username) is None:
click.echo(f'User does not exist: {username}')
return
@ -424,8 +423,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
click.echo(f'Tokens for "{username}":')
with ctx.obj.database.session() as conn:
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
click.echo(f'- {token["code"]}')
for row in conn.get_tokens(username):
click.echo(f'- {row.code}')
@cli_user.command('create-token')
@ -435,13 +434,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
'Create a new API token for a user'
with ctx.obj.database.session() as conn:
if not (user := conn.get_user(username)):
if (user := conn.get_user(username)) is None:
click.echo(f'User does not exist: {username}')
return
token = conn.put_token(user['username'])
token = conn.put_token(user.username)
click.echo(f'New token for "{username}": {token["code"]}')
click.echo(f'New token for "{username}": {token.code}')
@cli_user.command('delete-token')
@ -451,7 +450,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
'Delete an API token'
with ctx.obj.database.session() as conn:
if not conn.get_token(code):
if conn.get_token(code) is None:
click.echo('Token does not exist')
return
@ -473,8 +472,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:')
with ctx.obj.database.session() as conn:
for inbox in conn.get_inboxes():
click.echo(f'- {inbox["inbox"]}')
for row in conn.get_inboxes():
click.echo(f'- {row.inbox}')
@cli_inbox.command('follow')
@ -483,19 +482,21 @@ def cli_inbox_list(ctx: click.Context) -> None:
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)'
instance: schema.Instance | None = None
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
if (inbox_data := conn.get_inbox(actor)):
inbox = inbox_data['inbox']
if (instance := conn.get_inbox(actor)) is not None:
inbox = instance.inbox
else:
if not actor.startswith('http'):
actor = f'https://{actor}/actor'
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None:
click.echo(f'Failed to fetch actor: {actor}')
return
@ -506,7 +507,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
actor = actor
)
asyncio.run(http.post(inbox, message, inbox_data))
asyncio.run(http.post(inbox, message, instance))
click.echo(f'Sent follow message to actor: {actor}')
@ -516,19 +517,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)'
inbox_data: Row | None = None
instance: schema.Instance | None = None
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
if (inbox_data := conn.get_inbox(actor)):
inbox = inbox_data['inbox']
if (instance := conn.get_inbox(actor)):
inbox = instance.inbox
message = Message.new_unfollow(
host = ctx.obj.config.domain,
actor = actor,
follow = inbox_data['followid']
follow = instance.followid
)
else:
@ -552,7 +553,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
}
)
asyncio.run(http.post(inbox, message, inbox_data))
asyncio.run(http.post(inbox, message, instance))
click.echo(f'Sent unfollow message to: {actor}')
@ -632,9 +633,9 @@ def cli_request_list(ctx: click.Context) -> None:
click.echo('Follow requests:')
with ctx.obj.database.session() as conn:
for instance in conn.get_requests():
date = instance['created'].strftime('%Y-%m-%d')
click.echo(f'- [{date}] {instance["domain"]}')
for row in conn.get_requests():
date = row.created.strftime('%Y-%m-%d')
click.echo(f'- [{date}] {row.domain}')
@cli_request.command('accept')
@ -653,20 +654,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
message = Message.new_response(
host = ctx.obj.config.domain,
actor = instance['actor'],
followid = instance['followid'],
actor = instance.actor,
followid = instance.followid,
accept = True
)
asyncio.run(http.post(instance['inbox'], message, instance))
asyncio.run(http.post(instance.inbox, message, instance))
if instance['software'] != 'mastodon':
if instance.software != 'mastodon':
message = Message.new_follow(
host = ctx.obj.config.domain,
actor = instance['actor']
actor = instance.actor
)
asyncio.run(http.post(instance['inbox'], message, instance))
asyncio.run(http.post(instance.inbox, message, instance))
@cli_request.command('deny')
@ -685,12 +686,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
response = Message.new_response(
host = ctx.obj.config.domain,
actor = instance['actor'],
followid = instance['followid'],
actor = instance.actor,
followid = instance.followid,
accept = False
)
asyncio.run(http.post(instance['inbox'], response, instance))
asyncio.run(http.post(instance.inbox, response, instance))
@cli.group('instance')
@ -706,12 +707,12 @@ def cli_instance_list(ctx: click.Context) -> None:
click.echo('Banned domains:')
with ctx.obj.database.session() as conn:
for instance in conn.execute('SELECT * FROM domain_bans'):
if instance['reason']:
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
for row in conn.get_domain_bans():
if row.reason is not None:
click.echo(f'- {row.domain} ({row.reason})')
else:
click.echo(f'- {instance["domain"]}')
click.echo(f'- {row.domain}')
@cli_instance.command('ban')
@ -723,7 +724,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(domain):
if conn.get_domain_ban(domain) is not None:
click.echo(f'Domain already banned: {domain}')
return
@ -739,7 +740,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance'
with ctx.obj.database.session() as conn:
if not conn.del_domain_ban(domain):
if conn.del_domain_ban(domain) is None:
click.echo(f'Instance wasn\'t banned: {domain}')
return
@ -764,11 +765,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
click.echo(f'Updated domain ban: {domain}')
if row['reason']:
click.echo(f'- {row["domain"]} ({row["reason"]})')
if row.reason:
click.echo(f'- {row.domain} ({row.reason})')
else:
click.echo(f'- {row["domain"]}')
click.echo(f'- {row.domain}')
@cli.group('software')
@ -784,12 +785,12 @@ def cli_software_list(ctx: click.Context) -> None:
click.echo('Banned software:')
with ctx.obj.database.session() as conn:
for software in conn.execute('SELECT * FROM software_bans'):
if software['reason']:
click.echo(f'- {software["name"]} ({software["reason"]})')
for row in conn.get_software_bans():
if row.reason:
click.echo(f'- {row.name} ({row.reason})')
else:
click.echo(f'- {software["name"]}')
click.echo(f'- {row.name}')
@cli_software.command('ban')
@ -811,12 +812,12 @@ def cli_software_ban(ctx: click.Context,
with ctx.obj.database.session() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
if conn.get_software_ban(software):
click.echo(f'Relay already banned: {software}')
for item in RELAY_SOFTWARE:
if conn.get_software_ban(item):
click.echo(f'Relay already banned: {item}')
continue
conn.put_software_ban(software, reason or 'relay', note)
conn.put_software_ban(item, reason or 'relay', note)
click.echo('Banned all relay software')
return
@ -893,11 +894,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
click.echo(f'Updated software ban: {name}')
if row['reason']:
click.echo(f'- {row["name"]} ({row["reason"]})')
if row.reason:
click.echo(f'- {row.name} ({row.reason})')
else:
click.echo(f'- {row["name"]}')
click.echo(f'- {row.name}')
@cli.group('whitelist')
@ -913,8 +914,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
click.echo('Current whitelisted domains:')
with ctx.obj.database.session() as conn:
for domain in conn.execute('SELECT * FROM whitelist'):
click.echo(f'- {domain["domain"]}')
for row in conn.get_domain_whitelist():
click.echo(f'- {row.domain}')
@cli_whitelist.command('add')
@ -953,23 +954,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
@cli_whitelist.command('import')
@click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist'
'Add all current instances to the whitelist'
with ctx.obj.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes').all():
if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
for row in conn.get_inboxes():
if conn.get_domain_whitelist(row.domain) is not None:
click.echo(f'Domain already in whitelist: {row.domain}')
continue
conn.put_domain_whitelist(inbox['domain'])
conn.put_domain_whitelist(row.domain)
click.echo('Imported whitelist from inboxes')
def main() -> None:
cli(prog_name='relay')
if __name__ == '__main__':
click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')
cli(prog_name='activityrelay')

View file

@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
logging.debug('>> relay: %s', message)
for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], message, instance)
view.app.push_message(instance.inbox, message, instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
@ -52,7 +52,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
logging.debug('>> forward: %s', message)
for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], view.message, instance)
view.app.push_message(instance.inbox, view.message, instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str')
@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
return
# prevent past unfollows from removing an instance
if view.instance['followid'] and view.instance['followid'] != view.message.object_id:
if view.instance.followid and view.instance.followid != view.message.object_id:
return
with conn.transaction():
@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None:
with view.database.session() as conn:
if view.instance:
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
if not view.instance.software:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)):
with conn.transaction():
view.instance = conn.put_inbox(
domain = view.instance['domain'],
domain = view.instance.domain,
software = nodeinfo.sw_name
)
if not view.instance['actor']:
if not view.instance.actor:
with conn.transaction():
view.instance = conn.put_inbox(
domain = view.instance['domain'],
domain = view.instance.domain,
actor = view.actor.id
)

View file

@ -1,26 +1,22 @@
from __future__ import annotations
import aputils
import traceback
import typing
from aiohttp.web import Request
from .base import View, register_route
from .. import logger as logging
from ..database import schema
from ..misc import Message, Response
from ..processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from bsql import Row
@register_route('/actor', '/inbox')
class ActorView(View):
signature: aputils.Signature
message: Message
actor: Message
instancce: Row
instance: schema.Instance
signer: aputils.Signer
@ -47,7 +43,7 @@ class ActorView(View):
return response
with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox)
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
# reject if actor is banned
if conn.get_domain_ban(self.actor.domain):

View file

@ -90,10 +90,10 @@ class Login(View):
token = conn.put_token(data['username'])
resp = Response.new({'token': token['code']}, ctype = 'json')
resp = Response.new({'token': token.code}, ctype = 'json')
resp.set_cookie(
'user-token',
token['code'],
token.code,
max_age = 60 * 60 * 24 * 365,
domain = self.config.domain,
path = '/',
@ -117,7 +117,7 @@ class RelayInfo(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
config = conn.get_config_all()
inboxes = [row['domain'] for row in conn.get_inboxes()]
inboxes = [row.domain for row in conn.get_inboxes()]
data = {
'domain': self.config.domain,
@ -188,7 +188,7 @@ class Config(View):
class Inbox(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
data = conn.get_inboxes()
data = tuple(conn.get_inboxes())
return Response.new(data, ctype = 'json')
@ -202,7 +202,7 @@ class Inbox(View):
data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn:
if conn.get_inbox(data['domain']):
if conn.get_inbox(data['domain']) is not None:
return Response.new_error(404, 'Instance already in database', 'json')
data['domain'] = data['domain'].encode('idna').decode()
@ -225,7 +225,12 @@ class Inbox(View):
except Exception:
pass
row = conn.put_inbox(**data) # type: ignore[arg-type]
row = conn.put_inbox(
data['domain'],
actor = data.get('actor'),
software = data.get('software'),
followid = data.get('followid')
)
return Response.new(row, ctype = 'json')
@ -239,10 +244,15 @@ class Inbox(View):
data['domain'] = data['domain'].encode('idna').decode()
if not (instance := conn.get_inbox(data['domain'])):
if (instance := conn.get_inbox(data['domain'])) is None:
return Response.new_error(404, 'Instance with domain not found', 'json')
instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type]
instance = conn.put_inbox(
instance.domain,
actor = data.get('actor'),
software = data.get('software'),
followid = data.get('followid')
)
return Response.new(instance, ctype = 'json')
@ -268,7 +278,7 @@ class Inbox(View):
class RequestView(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
instances = conn.get_requests()
instances = tuple(conn.get_requests())
return Response.new(instances, ctype = 'json')
@ -291,20 +301,20 @@ class RequestView(View):
message = Message.new_response(
host = self.config.domain,
actor = instance['actor'],
followid = instance['followid'],
actor = instance.actor,
followid = instance.followid,
accept = data['accept']
)
self.app.push_message(instance['inbox'], message, instance)
self.app.push_message(instance.inbox, message, instance)
if data['accept'] and instance['software'] != 'mastodon':
if data['accept'] and instance.software != 'mastodon':
message = Message.new_follow(
host = self.config.domain,
actor = instance['actor']
actor = instance.actor
)
self.app.push_message(instance['inbox'], message, instance)
self.app.push_message(instance.inbox, message, instance)
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
return Response.new(resp_message, ctype = 'json')
@ -314,7 +324,7 @@ class RequestView(View):
class DomainBan(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
bans = tuple(conn.execute('SELECT * FROM domain_bans').all())
bans = tuple(conn.get_domain_bans())
return Response.new(bans, ctype = 'json')
@ -328,10 +338,14 @@ class DomainBan(View):
data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_ban(data['domain']):
if conn.get_domain_ban(data['domain']) is not None:
return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_domain_ban(**data)
ban = conn.put_domain_ban(
domain = data['domain'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json')
@ -343,15 +357,19 @@ class DomainBan(View):
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_domain_ban(data['domain']):
return Response.new_error(404, 'Domain not banned', 'json')
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
ban = conn.update_domain_ban(**data)
data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json')
ban = conn.update_domain_ban(
domain = data['domain'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json')
@ -365,7 +383,7 @@ class DomainBan(View):
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_domain_ban(data['domain']):
if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json')
conn.del_domain_ban(data['domain'])
@ -377,7 +395,7 @@ class DomainBan(View):
class SoftwareBan(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
bans = tuple(conn.execute('SELECT * FROM software_bans').all())
bans = tuple(conn.get_software_bans())
return Response.new(bans, ctype = 'json')
@ -389,10 +407,14 @@ class SoftwareBan(View):
return data
with self.database.session() as conn:
if conn.get_software_ban(data['name']):
if conn.get_software_ban(data['name']) is not None:
return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_software_ban(**data)
ban = conn.put_software_ban(
name = data['name'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json')
@ -403,14 +425,18 @@ class SoftwareBan(View):
if isinstance(data, Response):
return data
with self.database.session() as conn:
if not conn.get_software_ban(data['name']):
return Response.new_error(404, 'Software not banned', 'json')
if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
ban = conn.update_software_ban(**data)
with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json')
ban = conn.update_software_ban(
name = data['name'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json')
@ -422,7 +448,7 @@ class SoftwareBan(View):
return data
with self.database.session() as conn:
if not conn.get_software_ban(data['name']):
if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json')
conn.del_software_ban(data['name'])
@ -436,7 +462,7 @@ class User(View):
with self.database.session() as conn:
items = []
for row in conn.execute('SELECT * FROM users'):
for row in conn.get_users():
del row['hash']
items.append(row)
@ -450,12 +476,16 @@ class User(View):
return data
with self.database.session() as conn:
if conn.get_user(data['username']):
if conn.get_user(data['username']) is not None:
return Response.new_error(404, 'User already exists', 'json')
user = conn.put_user(**data)
del user['hash']
user = conn.put_user(
username = data['username'],
password = data['password'],
handle = data.get('handle')
)
del user['hash']
return Response.new(user, ctype = 'json')
@ -466,9 +496,13 @@ class User(View):
return data
with self.database.session(True) as conn:
user = conn.put_user(**data)
del user['hash']
user = conn.put_user(
username = data['username'],
password = data['password'],
handle = data.get('handle')
)
del user['hash']
return Response.new(user, ctype = 'json')
@ -479,7 +513,7 @@ class User(View):
return data
with self.database.session(True) as conn:
if not conn.get_user(data['username']):
if conn.get_user(data['username']) is None:
return Response.new_error(404, 'User does not exist', 'json')
conn.del_user(data['username'])
@ -491,7 +525,7 @@ class User(View):
class Whitelist(View):
async def get(self, request: Request) -> Response:
with self.database.session() as conn:
items = tuple(conn.execute('SELECT * FROM whitelist').all())
items = tuple(conn.get_domains_whitelist())
return Response.new(items, ctype = 'json')
@ -502,13 +536,13 @@ class Whitelist(View):
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if conn.get_domain_whitelist(data['domain']):
if conn.get_domain_whitelist(domain) is not None:
return Response.new_error(400, 'Domain already added to whitelist', 'json')
item = conn.put_domain_whitelist(**data)
item = conn.put_domain_whitelist(domain)
return Response.new(item, ctype = 'json')
@ -519,12 +553,12 @@ class Whitelist(View):
if isinstance(data, Response):
return data
data['domain'] = data['domain'].encode('idna').decode()
domain = data['domain'].encode('idna').decode()
with self.database.session() as conn:
if not conn.get_domain_whitelist(data['domain']):
if conn.get_domain_whitelist(domain) is None:
return Response.new_error(404, 'Domain not in whitelist', 'json')
conn.del_domain_whitelist(data['domain'])
conn.del_domain_whitelist(domain)
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')

View file

@ -202,7 +202,7 @@ class AdminConfig(View):
'message': message,
'desc': {
"name": "Name of the relay to be displayed in the header of the pages and in " +
"the actor endpoint.",
"the actor endpoint.", # noqa: E131
"note": "Description of the relay to be displayed on the front page and as the " +
"bio in the actor endpoint.",
"theme": "Color theme to use on the web pages.",

View file

@ -4,7 +4,6 @@ import typing
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from bsql import Row
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value
from multiprocessing.synchronize import Event as EventType
@ -13,6 +12,7 @@ from queue import Empty, Queue as QueueType
from urllib.parse import urlparse
from . import application, logger as logging
from .database.schema import Instance
from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, get_app
@ -29,7 +29,7 @@ class QueueItem:
class PostItem(QueueItem):
inbox: str
message: Message
instance: Row | None
instance: Instance | None
@property
def domain(self) -> str:
@ -122,7 +122,7 @@ class PushWorkers(list[PushWorker]):
self.queue.put(item)
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self.queue.put(PostItem(inbox, message, instance))