update barkshark-sql to 0.2.0-rc1 and create row classes

This commit is contained in:
Izalia Mae 2024-06-25 16:02:49 -04:00
parent 45b0de26c7
commit bdc7d41d7a
14 changed files with 374 additions and 329 deletions

View file

@ -20,7 +20,7 @@ dependencies = [
"aiohttp >= 3.9.5", "aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.1.3-1", "barkshark-lib >= 0.2.0-rc1",
"barkshark-sql == 0.1.4-1", "barkshark-sql == 0.1.4-1",
"click >= 8.1.2", "click >= 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",
@ -104,7 +104,3 @@ implicit_reexport = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "blib" module = "blib"
implicit_reexport = true implicit_reexport = true
[[tool.mypy.overrides]]
module = "bsql"
implicit_reexport = true

View file

@ -1 +1 @@
__version__ = '0.3.2' __version__ = '0.3.3'

View file

@ -4,30 +4,26 @@ import asyncio
import multiprocessing import multiprocessing
import signal import signal
import time import time
import traceback
from aiohttp import web from aiohttp import web
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from aiohttp.web import StaticResource from aiohttp.web import StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from bsql import Database
from bsql import Database, Row
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mimetypes import guess_type from mimetypes import guess_type
from pathlib import Path from pathlib import Path
from queue import Empty
from threading import Event, Thread from threading import Event, Thread
from typing import Any from typing import Any
from urllib.parse import urlparse
from . import logger as logging, workers from . import logger as logging, workers
from .cache import Cache, get_cache from .cache import Cache, get_cache
from .config import Config from .config import Config
from .database import Connection, get_database from .database import Connection, get_database
from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource from .misc import Message, Response, check_open_port, get_resource
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
@ -142,7 +138,7 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
def push_message(self, inbox: str, message: Message, instance: Row) -> None: def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self['workers'].push_message(inbox, message, instance) self['workers'].push_message(inbox, message, instance)
@ -286,67 +282,6 @@ class CacheCleanupThread(Thread):
self.running.clear() self.running.clear()
class PushWorker(multiprocessing.Process):
def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None:
if Application.DEFAULT is None:
raise RuntimeError('Application not setup yet')
multiprocessing.Process.__init__(self)
self.queue = queue
self.shutdown = multiprocessing.Event()
self.path = Application.DEFAULT.config.path
def stop(self) -> None:
self.shutdown.set()
def run(self) -> None:
asyncio.run(self.handle_queue())
async def handle_queue(self) -> None:
if IS_WINDOWS:
app = Application(self.path)
client = app.client
client.open()
app.database.connect()
app.cache.setup()
else:
client = HttpClient()
client.open()
while not self.shutdown.is_set():
try:
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
asyncio.create_task(client.post(inbox, message, instance))
except Empty:
await asyncio.sleep(0)
except ClientSSLError as e:
logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e))
except (AsyncTimeoutError, ClientConnectionError) as e:
logging.error(
'Failed to connect to %s for message push: %s',
urlparse(inbox).netloc, str(e)
)
# make sure an exception doesn't bring down the worker
except Exception:
traceback.print_exc()
if IS_WINDOWS:
app.database.disconnect()
app.cache.close()
await client.close()
@web.middleware @web.middleware
async def handle_response_headers( async def handle_response_headers(
request: web.Request, request: web.Request,

View file

@ -4,7 +4,7 @@ import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from bsql import Database from bsql import Database, Row
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@ -172,7 +172,7 @@ class SqlCache(Cache):
with self._db.session(False) as conn: with self._db.session(False) as conn:
with conn.run('get-cache-item', params) as cur: with conn.run('get-cache-item', params) as cur:
if not (row := cur.one()): if not (row := cur.one(Row)):
raise KeyError(f'{namespace}:{key}') raise KeyError(f'{namespace}:{key}')
row.pop('id', None) row.pop('id', None)
@ -211,9 +211,11 @@ class SqlCache(Cache):
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as cur: with conn.run('set-cache-item', params) as cur:
row = cur.one() if (row := cur.one(Row)) is None:
row.pop('id', None) # type: ignore[union-attr] raise RuntimeError("Cache item not set")
return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
row.pop('id', None)
return Item.from_data(*tuple(row.values()))
def delete(self, namespace: str, key: str) -> None: def delete(self, namespace: str, key: str) -> None:

View file

@ -16,6 +16,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
'tables': TABLES 'tables': TABLES
} }
db: Database[Connection]
if config.db_type == 'sqlite': if config.db_type == 'sqlite':
db = Database.sqlite(config.sqlite_path, **options) db = Database.sqlite(config.sqlite_path, **options)

View file

@ -2,12 +2,13 @@ from __future__ import annotations
from argon2 import PasswordHasher from argon2 import PasswordHasher
from bsql import Connection as SqlConnection, Row, Update from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator, Sequence from collections.abc import Iterator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from . import schema
from .config import ( from .config import (
THEMES, THEMES,
ConfigData ConfigData
@ -37,14 +38,14 @@ class Connection(SqlConnection):
return get_app() return get_app()
def distill_inboxes(self, message: Message) -> Iterator[Row]: def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]:
src_domains = { src_domains = {
message.domain, message.domain,
urlparse(message.object_id).netloc urlparse(message.object_id).netloc
} }
for instance in self.get_inboxes(): for instance in self.get_inboxes():
if instance['domain'] not in src_domains: if instance.domain not in src_domains:
yield instance yield instance
@ -52,7 +53,7 @@ class Connection(SqlConnection):
key = key.replace('_', '-') key = key.replace('_', '-')
with self.run('get-config', {'key': key}) as cur: with self.run('get-config', {'key': key}) as cur:
if not (row := cur.one()): if (row := cur.one(Row)) is None:
return ConfigData.DEFAULT(key) return ConfigData.DEFAULT(key)
data = ConfigData() data = ConfigData()
@ -61,8 +62,8 @@ class Connection(SqlConnection):
def get_config_all(self) -> ConfigData: def get_config_all(self) -> ConfigData:
with self.run('get-config-all', None) as cur: rows = tuple(self.run('get-config-all', None).all(schema.Row))
return ConfigData.from_rows(tuple(cur.all())) return ConfigData.from_rows(rows)
def put_config(self, key: str, value: Any) -> Any: def put_config(self, key: str, value: Any) -> Any:
@ -99,14 +100,13 @@ class Connection(SqlConnection):
return data.get(key) return data.get(key)
def get_inbox(self, value: str) -> Row: def get_inbox(self, value: str) -> schema.Instance | None:
with self.run('get-inbox', {'value': value}) as cur: with self.run('get-inbox', {'value': value}) as cur:
return cur.one() # type: ignore return cur.one(schema.Instance)
def get_inboxes(self) -> Sequence[Row]: def get_inboxes(self) -> Iterator[schema.Instance]:
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur: return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance)
return tuple(cur.all())
def put_inbox(self, def put_inbox(self,
@ -115,7 +115,7 @@ class Connection(SqlConnection):
actor: str | None = None, actor: str | None = None,
followid: str | None = None, followid: str | None = None,
software: str | None = None, software: str | None = None,
accepted: bool = True) -> Row: accepted: bool = True) -> schema.Instance:
params: dict[str, Any] = { params: dict[str, Any] = {
'inbox': inbox, 'inbox': inbox,
@ -125,7 +125,7 @@ class Connection(SqlConnection):
'accepted': accepted 'accepted': accepted
} }
if not self.get_inbox(domain): if self.get_inbox(domain) is None:
if not inbox: if not inbox:
raise ValueError("Missing inbox") raise ValueError("Missing inbox")
@ -133,14 +133,20 @@ class Connection(SqlConnection):
params['created'] = datetime.now(tz = timezone.utc) params['created'] = datetime.now(tz = timezone.utc)
with self.run('put-inbox', params) as cur: with self.run('put-inbox', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert instance: {domain}")
return row
for key, value in tuple(params.items()): for key, value in tuple(params.items()):
if value is None: if value is None:
del params[key] del params[key]
with self.update('inboxes', params, domain = domain) as cur: with self.update('inboxes', params, domain = domain) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to update instance: {domain}")
return row
def del_inbox(self, value: str) -> bool: def del_inbox(self, value: str) -> bool:
@ -151,24 +157,23 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_request(self, domain: str) -> Row: def get_request(self, domain: str) -> schema.Instance | None:
with self.run('get-request', {'domain': domain}) as cur: with self.run('get-request', {'domain': domain}) as cur:
if not (row := cur.one()): return cur.one(schema.Instance)
def get_requests(self) -> Iterator[schema.Instance]:
return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance)
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
if (instance := self.get_request(domain)) is None:
raise KeyError(domain) raise KeyError(domain)
return row
def get_requests(self) -> Sequence[Row]:
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
return tuple(cur.all())
def put_request_response(self, domain: str, accepted: bool) -> Row:
instance = self.get_request(domain)
if not accepted: if not accepted:
self.del_inbox(domain) if not self.del_inbox(domain):
raise RuntimeError(f'Failed to delete request: {domain}')
return instance return instance
params = { params = {
@ -177,21 +182,28 @@ class Connection(SqlConnection):
} }
with self.run('put-inbox-accept', params) as cur: with self.run('put-inbox-accept', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Instance)) is None:
raise RuntimeError(f"Failed to insert response for domain: {domain}")
return row
def get_user(self, value: str) -> Row: def get_user(self, value: str) -> schema.User | None:
with self.run('get-user', {'value': value}) as cur: with self.run('get-user', {'value': value}) as cur:
return cur.one() # type: ignore return cur.one(schema.User)
def get_user_by_token(self, code: str) -> Row: def get_user_by_token(self, code: str) -> schema.User | None:
with self.run('get-user-by-token', {'code': code}) as cur: with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one() # type: ignore return cur.one(schema.User)
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row: def get_users(self) -> Iterator[schema.User]:
if self.get_user(username): return self.execute("SELECT * FROM users").all(schema.User)
def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User:
if self.get_user(username) is not None:
data: dict[str, str | datetime | None] = {} data: dict[str, str | datetime | None] = {}
if password: if password:
@ -204,7 +216,10 @@ class Connection(SqlConnection):
stmt.set_where("username", username) stmt.set_where("username", username)
with self.query(stmt) as cur: with self.query(stmt) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to update user: {username}")
return row
if password is None: if password is None:
raise ValueError('Password cannot be empty') raise ValueError('Password cannot be empty')
@ -217,25 +232,36 @@ class Connection(SqlConnection):
} }
with self.run('put-user', data) as cur: with self.run('put-user', data) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.User)) is None:
raise RuntimeError(f"Failed to insert user: {username}")
return row
def del_user(self, username: str) -> None: def del_user(self, username: str) -> None:
user = self.get_user(username) if (user := self.get_user(username)) is None:
raise KeyError(username)
with self.run('del-user', {'value': user['username']}): with self.run('del-user', {'value': user.username}):
pass pass
with self.run('del-token-user', {'username': user['username']}): with self.run('del-token-user', {'username': user.username}):
pass pass
def get_token(self, code: str) -> Row: def get_token(self, code: str) -> schema.Token | None:
with self.run('get-token', {'code': code}) as cur: with self.run('get-token', {'code': code}) as cur:
return cur.one() # type: ignore return cur.one(schema.Token)
def put_token(self, username: str) -> Row: def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
if username is not None:
return self.select('tokens').all(schema.Token)
return self.select('tokens', username = username).all(schema.Token)
def put_token(self, username: str) -> schema.Token:
data = { data = {
'code': uuid4().hex, 'code': uuid4().hex,
'user': username, 'user': username,
@ -243,7 +269,10 @@ class Connection(SqlConnection):
} }
with self.run('put-token', data) as cur: with self.run('put-token', data) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Token)) is None:
raise RuntimeError(f"Failed to insert token for user: {username}")
return row
def del_token(self, code: str) -> None: def del_token(self, code: str) -> None:
@ -251,18 +280,22 @@ class Connection(SqlConnection):
pass pass
def get_domain_ban(self, domain: str) -> Row: def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).netloc domain = urlparse(domain).netloc
with self.run('get-domain-ban', {'domain': domain}) as cur: with self.run('get-domain-ban', {'domain': domain}) as cur:
return cur.one() # type: ignore return cur.one(schema.DomainBan)
def get_domain_bans(self) -> Iterator[schema.DomainBan]:
return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan)
def put_domain_ban(self, def put_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> Row: note: str | None = None) -> schema.DomainBan:
params = { params = {
'domain': domain, 'domain': domain,
@ -272,13 +305,16 @@ class Connection(SqlConnection):
} }
with self.run('put-domain-ban', params) as cur: with self.run('put-domain-ban', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to insert domain ban: {domain}")
return row
def update_domain_ban(self, def update_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> Row: note: str | None = None) -> schema.DomainBan:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -298,7 +334,10 @@ class Connection(SqlConnection):
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return self.get_domain_ban(domain) if (row := cur.one(schema.DomainBan)) is None:
raise RuntimeError(f"Failed to update domain ban: {domain}")
return row
def del_domain_ban(self, domain: str) -> bool: def del_domain_ban(self, domain: str) -> bool:
@ -309,15 +348,19 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_software_ban(self, name: str) -> Row: def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
with self.run('get-software-ban', {'name': name}) as cur: with self.run('get-software-ban', {'name': name}) as cur:
return cur.one() # type: ignore return cur.one(schema.SoftwareBan)
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
def put_software_ban(self, def put_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> Row: note: str | None = None) -> schema.SoftwareBan:
params = { params = {
'name': name, 'name': name,
@ -327,13 +370,16 @@ class Connection(SqlConnection):
} }
with self.run('put-software-ban', params) as cur: with self.run('put-software-ban', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to insert software ban: {name}')
return row
def update_software_ban(self, def update_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> Row: note: str | None = None) -> schema.SoftwareBan:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -353,7 +399,10 @@ class Connection(SqlConnection):
if cur.row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return self.get_software_ban(name) if (row := cur.one(schema.SoftwareBan)) is None:
raise RuntimeError(f'Failed to update software ban: {name}')
return row
def del_software_ban(self, name: str) -> bool: def del_software_ban(self, name: str) -> bool:
@ -364,19 +413,26 @@ class Connection(SqlConnection):
return cur.row_count == 1 return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> Row: def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
with self.run('get-domain-whitelist', {'domain': domain}) as cur: with self.run('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one() # type: ignore return cur.one()
def put_domain_whitelist(self, domain: str) -> Row: def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]:
return self.execute("SELECT * FROM whitelist").all(schema.Whitelist)
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
params = { params = {
'domain': domain, 'domain': domain,
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.run('put-domain-whitelist', params) as cur: with self.run('put-domain-whitelist', params) as cur:
return cur.one() # type: ignore if (row := cur.one(schema.Whitelist)) is None:
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
return row
def del_domain_whitelist(self, domain: str) -> bool: def del_domain_whitelist(self, domain: str) -> bool:

View file

@ -1,61 +1,88 @@
from bsql import Column, Table, Tables from __future__ import annotations
import typing
from bsql import Column, Row, Tables
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime
from .config import ConfigData from .config import ConfigData
from .connection import Connection
if typing.TYPE_CHECKING:
from .connection import Connection
VERSIONS: dict[int, Callable[[Connection], None]] = {} VERSIONS: dict[int, Callable[[Connection], None]] = {}
TABLES: Tables = Tables( TABLES = Tables()
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False), @TABLES.add_row
Column('value', 'text'), class Config(Row):
Column('type', 'text', default = 'str') key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
), value: Column[str] = Column('value', 'text')
Table( type: Column[str] = Column('type', 'text', default = 'str')
'inboxes',
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
Column('actor', 'text', unique = True), @TABLES.add_row
Column('inbox', 'text', unique = True, nullable = False), class Instance(Row):
Column('followid', 'text'), table_name: str = 'inboxes'
Column('software', 'text'),
Column('accepted', 'boolean'), domain: Column[str] = Column(
Column('created', 'timestamp', nullable = False) 'domain', 'text', primary_key = True, unique = True, nullable = False)
), actor: Column[str] = Column('actor', 'text', unique = True)
Table( inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
'whitelist', followid: Column[str] = Column('followid', 'text')
Column('domain', 'text', primary_key = True, unique = True, nullable = True), software: Column[str] = Column('software', 'text')
Column('created', 'timestamp') accepted: Column[datetime] = Column('accepted', 'boolean')
), created: Column[datetime] = Column('created', 'timestamp', nullable = False)
Table(
'domain_bans',
Column('domain', 'text', primary_key = True, unique = True, nullable = True), @TABLES.add_row
Column('reason', 'text'), class Whitelist(Row):
Column('note', 'text'), domain: Column[str] = Column(
Column('created', 'timestamp', nullable = False) 'domain', 'text', primary_key = True, unique = True, nullable = True)
), created: Column[datetime] = Column('created', 'timestamp')
Table(
'software_bans',
Column('name', 'text', primary_key = True, unique = True, nullable = True), @TABLES.add_row
Column('reason', 'text'), class DomainBan(Row):
Column('note', 'text'), table_name: str = 'domain_bans'
Column('created', 'timestamp', nullable = False)
), domain: Column[str] = Column(
Table( 'domain', 'text', primary_key = True, unique = True, nullable = True)
'users', reason: Column[str] = Column('reason', 'text')
Column('username', 'text', primary_key = True, unique = True, nullable = False), note: Column[str] = Column('note', 'text')
Column('hash', 'text', nullable = False), created: Column[datetime] = Column('created', 'timestamp')
Column('handle', 'text'),
Column('created', 'timestamp', nullable = False)
), @TABLES.add_row
Table( class SoftwareBan(Row):
'tokens', table_name: str = 'software_bans'
Column('code', 'text', primary_key = True, unique = True, nullable = False),
Column('user', 'text', nullable = False), name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
Column('created', 'timestmap', nullable = False) reason: Column[str] = Column('reason', 'text')
) note: Column[str] = Column('note', 'text')
) created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class User(Row):
table_name: str = 'users'
username: Column[str] = Column(
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
created: Column[datetime] = Column('created', 'timestamp')
@TABLES.add_row
class Token(Row):
table_name: str = 'tokens'
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
user: Column[str] = Column('user', 'text', nullable = False)
created: Column[datetime] = Column('created', 'timestamp')
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]: def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:

View file

@ -5,11 +5,11 @@ import json
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from blib import JsonBase from blib import JsonBase
from bsql import Row
from typing import TYPE_CHECKING, Any, TypeVar, overload from typing import TYPE_CHECKING, Any, TypeVar, overload
from . import __version__, logger as logging from . import __version__, logger as logging
from .cache import Cache from .cache import Cache
from .database.schema import Instance
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
@ -184,12 +184,12 @@ class HttpClient:
return None return None
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None: async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
# akkoma and pleroma do not support HS2019 and other software still needs to be tested # akkoma and pleroma do not support HS2019 and other software still needs to be tested
if instance and instance['software'] in SUPPORTS_HS2019: if instance is not None and instance.software in SUPPORTS_HS2019:
algorithm = AlgorithmType.HS2019 algorithm = AlgorithmType.HS2019
else: else:

View file

@ -6,7 +6,6 @@ import click
import json import json
import os import os
from bsql import Row
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any from typing import Any
@ -17,7 +16,7 @@ from . import http_client as http
from . import logger as logging from . import logger as logging
from .application import Application from .application import Application
from .compat import RelayConfig, RelayDatabase from .compat import RelayConfig, RelayDatabase
from .database import RELAY_SOFTWARE, get_database from .database import RELAY_SOFTWARE, get_database, schema
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
@ -367,8 +366,8 @@ def cli_user_list(ctx: click.Context) -> None:
click.echo('Users:') click.echo('Users:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for user in conn.execute('SELECT * FROM users'): for row in conn.get_users():
click.echo(f'- {user["username"]}') click.echo(f'- {row.username}')
@cli_user.command('create') @cli_user.command('create')
@ -379,7 +378,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
'Create a new local user' 'Create a new local user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_user(username): if conn.get_user(username) is not None:
click.echo(f'User already exists: {username}') click.echo(f'User already exists: {username}')
return return
@ -406,7 +405,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
'Delete a local user' 'Delete a local user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if not conn.get_user(username): if conn.get_user(username) is None:
click.echo(f'User does not exist: {username}') click.echo(f'User does not exist: {username}')
return return
@ -424,8 +423,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
click.echo(f'Tokens for "{username}":') click.echo(f'Tokens for "{username}":')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): for row in conn.get_tokens(username):
click.echo(f'- {token["code"]}') click.echo(f'- {row.code}')
@cli_user.command('create-token') @cli_user.command('create-token')
@ -435,13 +434,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
'Create a new API token for a user' 'Create a new API token for a user'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if not (user := conn.get_user(username)): if (user := conn.get_user(username)) is None:
click.echo(f'User does not exist: {username}') click.echo(f'User does not exist: {username}')
return return
token = conn.put_token(user['username']) token = conn.put_token(user.username)
click.echo(f'New token for "{username}": {token["code"]}') click.echo(f'New token for "{username}": {token.code}')
@cli_user.command('delete-token') @cli_user.command('delete-token')
@ -451,7 +450,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
'Delete an API token' 'Delete an API token'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if not conn.get_token(code): if conn.get_token(code) is None:
click.echo('Token does not exist') click.echo('Token does not exist')
return return
@ -473,8 +472,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:') click.echo('Connected to the following instances or relays:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for inbox in conn.get_inboxes(): for row in conn.get_inboxes():
click.echo(f'- {inbox["inbox"]}') click.echo(f'- {row.inbox}')
@cli_inbox.command('follow') @cli_inbox.command('follow')
@ -483,19 +482,21 @@ def cli_inbox_list(ctx: click.Context) -> None:
def cli_inbox_follow(ctx: click.Context, actor: str) -> None: def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)' 'Follow an actor (Relay must be running)'
instance: schema.Instance | None = None
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
if (inbox_data := conn.get_inbox(actor)): if (instance := conn.get_inbox(actor)) is not None:
inbox = inbox_data['inbox'] inbox = instance.inbox
else: else:
if not actor.startswith('http'): if not actor.startswith('http'):
actor = f'https://{actor}/actor' actor = f'https://{actor}/actor'
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))): if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None:
click.echo(f'Failed to fetch actor: {actor}') click.echo(f'Failed to fetch actor: {actor}')
return return
@ -506,7 +507,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
actor = actor actor = actor
) )
asyncio.run(http.post(inbox, message, inbox_data)) asyncio.run(http.post(inbox, message, instance))
click.echo(f'Sent follow message to actor: {actor}') click.echo(f'Sent follow message to actor: {actor}')
@ -516,19 +517,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None: def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
'Unfollow an actor (Relay must be running)' 'Unfollow an actor (Relay must be running)'
inbox_data: Row | None = None instance: schema.Instance | None = None
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
if (inbox_data := conn.get_inbox(actor)): if (instance := conn.get_inbox(actor)):
inbox = inbox_data['inbox'] inbox = instance.inbox
message = Message.new_unfollow( message = Message.new_unfollow(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = actor, actor = actor,
follow = inbox_data['followid'] follow = instance.followid
) )
else: else:
@ -552,7 +553,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
} }
) )
asyncio.run(http.post(inbox, message, inbox_data)) asyncio.run(http.post(inbox, message, instance))
click.echo(f'Sent unfollow message to: {actor}') click.echo(f'Sent unfollow message to: {actor}')
@ -632,9 +633,9 @@ def cli_request_list(ctx: click.Context) -> None:
click.echo('Follow requests:') click.echo('Follow requests:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for instance in conn.get_requests(): for row in conn.get_requests():
date = instance['created'].strftime('%Y-%m-%d') date = row.created.strftime('%Y-%m-%d')
click.echo(f'- [{date}] {instance["domain"]}') click.echo(f'- [{date}] {row.domain}')
@cli_request.command('accept') @cli_request.command('accept')
@ -653,20 +654,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
message = Message.new_response( message = Message.new_response(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance['actor'], actor = instance.actor,
followid = instance['followid'], followid = instance.followid,
accept = True accept = True
) )
asyncio.run(http.post(instance['inbox'], message, instance)) asyncio.run(http.post(instance.inbox, message, instance))
if instance['software'] != 'mastodon': if instance.software != 'mastodon':
message = Message.new_follow( message = Message.new_follow(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance['actor'] actor = instance.actor
) )
asyncio.run(http.post(instance['inbox'], message, instance)) asyncio.run(http.post(instance.inbox, message, instance))
@cli_request.command('deny') @cli_request.command('deny')
@ -685,12 +686,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
response = Message.new_response( response = Message.new_response(
host = ctx.obj.config.domain, host = ctx.obj.config.domain,
actor = instance['actor'], actor = instance.actor,
followid = instance['followid'], followid = instance.followid,
accept = False accept = False
) )
asyncio.run(http.post(instance['inbox'], response, instance)) asyncio.run(http.post(instance.inbox, response, instance))
@cli.group('instance') @cli.group('instance')
@ -706,12 +707,12 @@ def cli_instance_list(ctx: click.Context) -> None:
click.echo('Banned domains:') click.echo('Banned domains:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for instance in conn.execute('SELECT * FROM domain_bans'): for row in conn.get_domain_bans():
if instance['reason']: if row.reason is not None:
click.echo(f'- {instance["domain"]} ({instance["reason"]})') click.echo(f'- {row.domain} ({row.reason})')
else: else:
click.echo(f'- {instance["domain"]}') click.echo(f'- {row.domain}')
@cli_instance.command('ban') @cli_instance.command('ban')
@ -723,7 +724,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
'Ban an instance and remove the associated inbox if it exists' 'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(domain): if conn.get_domain_ban(domain) is not None:
click.echo(f'Domain already banned: {domain}') click.echo(f'Domain already banned: {domain}')
return return
@ -739,7 +740,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance' 'Unban an instance'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if not conn.del_domain_ban(domain): if conn.del_domain_ban(domain) is None:
click.echo(f'Instance wasn\'t banned: {domain}') click.echo(f'Instance wasn\'t banned: {domain}')
return return
@ -764,11 +765,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
click.echo(f'Updated domain ban: {domain}') click.echo(f'Updated domain ban: {domain}')
if row['reason']: if row.reason:
click.echo(f'- {row["domain"]} ({row["reason"]})') click.echo(f'- {row.domain} ({row.reason})')
else: else:
click.echo(f'- {row["domain"]}') click.echo(f'- {row.domain}')
@cli.group('software') @cli.group('software')
@ -784,12 +785,12 @@ def cli_software_list(ctx: click.Context) -> None:
click.echo('Banned software:') click.echo('Banned software:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for software in conn.execute('SELECT * FROM software_bans'): for row in conn.get_software_bans():
if software['reason']: if row.reason:
click.echo(f'- {software["name"]} ({software["reason"]})') click.echo(f'- {row.name} ({row.reason})')
else: else:
click.echo(f'- {software["name"]}') click.echo(f'- {row.name}')
@cli_software.command('ban') @cli_software.command('ban')
@ -811,12 +812,12 @@ def cli_software_ban(ctx: click.Context,
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
if name == 'RELAYS': if name == 'RELAYS':
for software in RELAY_SOFTWARE: for item in RELAY_SOFTWARE:
if conn.get_software_ban(software): if conn.get_software_ban(item):
click.echo(f'Relay already banned: {software}') click.echo(f'Relay already banned: {item}')
continue continue
conn.put_software_ban(software, reason or 'relay', note) conn.put_software_ban(item, reason or 'relay', note)
click.echo('Banned all relay software') click.echo('Banned all relay software')
return return
@ -893,11 +894,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
click.echo(f'Updated software ban: {name}') click.echo(f'Updated software ban: {name}')
if row['reason']: if row.reason:
click.echo(f'- {row["name"]} ({row["reason"]})') click.echo(f'- {row.name} ({row.reason})')
else: else:
click.echo(f'- {row["name"]}') click.echo(f'- {row.name}')
@cli.group('whitelist') @cli.group('whitelist')
@ -913,8 +914,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
click.echo('Current whitelisted domains:') click.echo('Current whitelisted domains:')
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for domain in conn.execute('SELECT * FROM whitelist'): for row in conn.get_domain_whitelist():
click.echo(f'- {domain["domain"]}') click.echo(f'- {row.domain}')
@cli_whitelist.command('add') @cli_whitelist.command('add')
@ -953,23 +954,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
@cli_whitelist.command('import') @cli_whitelist.command('import')
@click.pass_context @click.pass_context
def cli_whitelist_import(ctx: click.Context) -> None: def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist' 'Add all current instances to the whitelist'
with ctx.obj.database.session() as conn: with ctx.obj.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes').all(): for row in conn.get_inboxes():
if conn.get_domain_whitelist(inbox['domain']): if conn.get_domain_whitelist(row.domain) is not None:
click.echo(f'Domain already in whitelist: {inbox["domain"]}') click.echo(f'Domain already in whitelist: {row.domain}')
continue continue
conn.put_domain_whitelist(inbox['domain']) conn.put_domain_whitelist(row.domain)
click.echo('Imported whitelist from inboxes') click.echo('Imported whitelist from inboxes')
def main() -> None: def main() -> None:
cli(prog_name='relay') cli(prog_name='activityrelay')
if __name__ == '__main__':
click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')

View file

@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
logging.debug('>> relay: %s', message) logging.debug('>> relay: %s', message)
for instance in conn.distill_inboxes(view.message): for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], message, instance) view.app.push_message(instance.inbox, message, instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str') view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
@ -52,7 +52,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
logging.debug('>> forward: %s', message) logging.debug('>> forward: %s', message)
for instance in conn.distill_inboxes(view.message): for instance in conn.distill_inboxes(view.message):
view.app.push_message(instance["inbox"], view.message, instance) view.app.push_message(instance.inbox, view.message, instance)
view.cache.set('handle-relay', view.message.id, message.id, 'str') view.cache.set('handle-relay', view.message.id, message.id, 'str')
@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
return return
# prevent past unfollows from removing an instance # prevent past unfollows from removing an instance
if view.instance['followid'] and view.instance['followid'] != view.message.object_id: if view.instance.followid and view.instance.followid != view.message.object_id:
return return
with conn.transaction(): with conn.transaction():
@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None:
with view.database.session() as conn: with view.database.session() as conn:
if view.instance: if view.instance:
if not view.instance['software']: if not view.instance.software:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)):
with conn.transaction(): with conn.transaction():
view.instance = conn.put_inbox( view.instance = conn.put_inbox(
domain = view.instance['domain'], domain = view.instance.domain,
software = nodeinfo.sw_name software = nodeinfo.sw_name
) )
if not view.instance['actor']: if not view.instance.actor:
with conn.transaction(): with conn.transaction():
view.instance = conn.put_inbox( view.instance = conn.put_inbox(
domain = view.instance['domain'], domain = view.instance.domain,
actor = view.actor.id actor = view.actor.id
) )

View file

@ -1,26 +1,22 @@
from __future__ import annotations
import aputils import aputils
import traceback import traceback
import typing
from aiohttp.web import Request
from .base import View, register_route from .base import View, register_route
from .. import logger as logging from .. import logger as logging
from ..database import schema
from ..misc import Message, Response from ..misc import Message, Response
from ..processors import run_processor from ..processors import run_processor
if typing.TYPE_CHECKING:
from aiohttp.web import Request
from bsql import Row
@register_route('/actor', '/inbox') @register_route('/actor', '/inbox')
class ActorView(View): class ActorView(View):
signature: aputils.Signature signature: aputils.Signature
message: Message message: Message
actor: Message actor: Message
instancce: Row instance: schema.Instance
signer: aputils.Signer signer: aputils.Signer
@ -47,7 +43,7 @@ class ActorView(View):
return response return response
with self.database.session() as conn: with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
# reject if actor is banned # reject if actor is banned
if conn.get_domain_ban(self.actor.domain): if conn.get_domain_ban(self.actor.domain):

View file

@ -90,10 +90,10 @@ class Login(View):
token = conn.put_token(data['username']) token = conn.put_token(data['username'])
resp = Response.new({'token': token['code']}, ctype = 'json') resp = Response.new({'token': token.code}, ctype = 'json')
resp.set_cookie( resp.set_cookie(
'user-token', 'user-token',
token['code'], token.code,
max_age = 60 * 60 * 24 * 365, max_age = 60 * 60 * 24 * 365,
domain = self.config.domain, domain = self.config.domain,
path = '/', path = '/',
@ -117,7 +117,7 @@ class RelayInfo(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
config = conn.get_config_all() config = conn.get_config_all()
inboxes = [row['domain'] for row in conn.get_inboxes()] inboxes = [row.domain for row in conn.get_inboxes()]
data = { data = {
'domain': self.config.domain, 'domain': self.config.domain,
@ -188,7 +188,7 @@ class Config(View):
class Inbox(View): class Inbox(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
data = conn.get_inboxes() data = tuple(conn.get_inboxes())
return Response.new(data, ctype = 'json') return Response.new(data, ctype = 'json')
@ -202,7 +202,7 @@ class Inbox(View):
data['domain'] = urlparse(data["actor"]).netloc data['domain'] = urlparse(data["actor"]).netloc
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_inbox(data['domain']): if conn.get_inbox(data['domain']) is not None:
return Response.new_error(404, 'Instance already in database', 'json') return Response.new_error(404, 'Instance already in database', 'json')
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
@ -225,7 +225,12 @@ class Inbox(View):
except Exception: except Exception:
pass pass
row = conn.put_inbox(**data) # type: ignore[arg-type] row = conn.put_inbox(
data['domain'],
actor = data.get('actor'),
software = data.get('software'),
followid = data.get('followid')
)
return Response.new(row, ctype = 'json') return Response.new(row, ctype = 'json')
@ -239,10 +244,15 @@ class Inbox(View):
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not (instance := conn.get_inbox(data['domain'])): if (instance := conn.get_inbox(data['domain'])) is None:
return Response.new_error(404, 'Instance with domain not found', 'json') return Response.new_error(404, 'Instance with domain not found', 'json')
instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type] instance = conn.put_inbox(
instance.domain,
actor = data.get('actor'),
software = data.get('software'),
followid = data.get('followid')
)
return Response.new(instance, ctype = 'json') return Response.new(instance, ctype = 'json')
@ -268,7 +278,7 @@ class Inbox(View):
class RequestView(View): class RequestView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
instances = conn.get_requests() instances = tuple(conn.get_requests())
return Response.new(instances, ctype = 'json') return Response.new(instances, ctype = 'json')
@ -291,20 +301,20 @@ class RequestView(View):
message = Message.new_response( message = Message.new_response(
host = self.config.domain, host = self.config.domain,
actor = instance['actor'], actor = instance.actor,
followid = instance['followid'], followid = instance.followid,
accept = data['accept'] accept = data['accept']
) )
self.app.push_message(instance['inbox'], message, instance) self.app.push_message(instance.inbox, message, instance)
if data['accept'] and instance['software'] != 'mastodon': if data['accept'] and instance.software != 'mastodon':
message = Message.new_follow( message = Message.new_follow(
host = self.config.domain, host = self.config.domain,
actor = instance['actor'] actor = instance.actor
) )
self.app.push_message(instance['inbox'], message, instance) self.app.push_message(instance.inbox, message, instance)
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'} resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
return Response.new(resp_message, ctype = 'json') return Response.new(resp_message, ctype = 'json')
@ -314,7 +324,7 @@ class RequestView(View):
class DomainBan(View): class DomainBan(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
bans = tuple(conn.execute('SELECT * FROM domain_bans').all()) bans = tuple(conn.get_domain_bans())
return Response.new(bans, ctype = 'json') return Response.new(bans, ctype = 'json')
@ -328,10 +338,14 @@ class DomainBan(View):
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_ban(data['domain']): if conn.get_domain_ban(data['domain']) is not None:
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_domain_ban(**data) ban = conn.put_domain_ban(
domain = data['domain'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -343,15 +357,19 @@ class DomainBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_domain_ban(data['domain']):
return Response.new_error(404, 'Domain not banned', 'json')
if not any([data.get('note'), data.get('reason')]): if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json') return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
ban = conn.update_domain_ban(**data) data['domain'] = data['domain'].encode('idna').decode()
if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json')
ban = conn.update_domain_ban(
domain = data['domain'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -365,7 +383,7 @@ class DomainBan(View):
data['domain'] = data['domain'].encode('idna').decode() data['domain'] = data['domain'].encode('idna').decode()
if not conn.get_domain_ban(data['domain']): if conn.get_domain_ban(data['domain']) is None:
return Response.new_error(404, 'Domain not banned', 'json') return Response.new_error(404, 'Domain not banned', 'json')
conn.del_domain_ban(data['domain']) conn.del_domain_ban(data['domain'])
@ -377,7 +395,7 @@ class DomainBan(View):
class SoftwareBan(View): class SoftwareBan(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
bans = tuple(conn.execute('SELECT * FROM software_bans').all()) bans = tuple(conn.get_software_bans())
return Response.new(bans, ctype = 'json') return Response.new(bans, ctype = 'json')
@ -389,10 +407,14 @@ class SoftwareBan(View):
return data return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']): if conn.get_software_ban(data['name']) is not None:
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
ban = conn.put_software_ban(**data) ban = conn.put_software_ban(
name = data['name'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -403,14 +425,18 @@ class SoftwareBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with self.database.session() as conn:
if not conn.get_software_ban(data['name']):
return Response.new_error(404, 'Software not banned', 'json')
if not any([data.get('note'), data.get('reason')]): if not any([data.get('note'), data.get('reason')]):
return Response.new_error(400, 'Must include note and/or reason parameters', 'json') return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
ban = conn.update_software_ban(**data) with self.database.session() as conn:
if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json')
ban = conn.update_software_ban(
name = data['name'],
reason = data.get('reason'),
note = data.get('note')
)
return Response.new(ban, ctype = 'json') return Response.new(ban, ctype = 'json')
@ -422,7 +448,7 @@ class SoftwareBan(View):
return data return data
with self.database.session() as conn: with self.database.session() as conn:
if not conn.get_software_ban(data['name']): if conn.get_software_ban(data['name']) is None:
return Response.new_error(404, 'Software not banned', 'json') return Response.new_error(404, 'Software not banned', 'json')
conn.del_software_ban(data['name']) conn.del_software_ban(data['name'])
@ -436,7 +462,7 @@ class User(View):
with self.database.session() as conn: with self.database.session() as conn:
items = [] items = []
for row in conn.execute('SELECT * FROM users'): for row in conn.get_users():
del row['hash'] del row['hash']
items.append(row) items.append(row)
@ -450,12 +476,16 @@ class User(View):
return data return data
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_user(data['username']): if conn.get_user(data['username']) is not None:
return Response.new_error(404, 'User already exists', 'json') return Response.new_error(404, 'User already exists', 'json')
user = conn.put_user(**data) user = conn.put_user(
del user['hash'] username = data['username'],
password = data['password'],
handle = data.get('handle')
)
del user['hash']
return Response.new(user, ctype = 'json') return Response.new(user, ctype = 'json')
@ -466,9 +496,13 @@ class User(View):
return data return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
user = conn.put_user(**data) user = conn.put_user(
del user['hash'] username = data['username'],
password = data['password'],
handle = data.get('handle')
)
del user['hash']
return Response.new(user, ctype = 'json') return Response.new(user, ctype = 'json')
@ -479,7 +513,7 @@ class User(View):
return data return data
with self.database.session(True) as conn: with self.database.session(True) as conn:
if not conn.get_user(data['username']): if conn.get_user(data['username']) is None:
return Response.new_error(404, 'User does not exist', 'json') return Response.new_error(404, 'User does not exist', 'json')
conn.del_user(data['username']) conn.del_user(data['username'])
@ -491,7 +525,7 @@ class User(View):
class Whitelist(View): class Whitelist(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.session() as conn: with self.database.session() as conn:
items = tuple(conn.execute('SELECT * FROM whitelist').all()) items = tuple(conn.get_domains_whitelist())
return Response.new(items, ctype = 'json') return Response.new(items, ctype = 'json')
@ -502,13 +536,13 @@ class Whitelist(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode() domain = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(data['domain']): if conn.get_domain_whitelist(domain) is not None:
return Response.new_error(400, 'Domain already added to whitelist', 'json') return Response.new_error(400, 'Domain already added to whitelist', 'json')
item = conn.put_domain_whitelist(**data) item = conn.put_domain_whitelist(domain)
return Response.new(item, ctype = 'json') return Response.new(item, ctype = 'json')
@ -519,12 +553,12 @@ class Whitelist(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['domain'] = data['domain'].encode('idna').decode() domain = data['domain'].encode('idna').decode()
with self.database.session() as conn: with self.database.session() as conn:
if not conn.get_domain_whitelist(data['domain']): if conn.get_domain_whitelist(domain) is None:
return Response.new_error(404, 'Domain not in whitelist', 'json') return Response.new_error(404, 'Domain not in whitelist', 'json')
conn.del_domain_whitelist(data['domain']) conn.del_domain_whitelist(domain)
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json') return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')

View file

@ -202,7 +202,7 @@ class AdminConfig(View):
'message': message, 'message': message,
'desc': { 'desc': {
"name": "Name of the relay to be displayed in the header of the pages and in " + "name": "Name of the relay to be displayed in the header of the pages and in " +
"the actor endpoint.", "the actor endpoint.", # noqa: E131
"note": "Description of the relay to be displayed on the front page and as the " + "note": "Description of the relay to be displayed on the front page and as the " +
"bio in the actor endpoint.", "bio in the actor endpoint.",
"theme": "Color theme to use on the web pages.", "theme": "Color theme to use on the web pages.",

View file

@ -4,7 +4,6 @@ import typing
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from bsql import Row
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value from multiprocessing import Event, Process, Queue, Value
from multiprocessing.synchronize import Event as EventType from multiprocessing.synchronize import Event as EventType
@ -13,6 +12,7 @@ from queue import Empty, Queue as QueueType
from urllib.parse import urlparse from urllib.parse import urlparse
from . import application, logger as logging from . import application, logger as logging
from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import IS_WINDOWS, Message, get_app from .misc import IS_WINDOWS, Message, get_app
@ -29,7 +29,7 @@ class QueueItem:
class PostItem(QueueItem): class PostItem(QueueItem):
inbox: str inbox: str
message: Message message: Message
instance: Row | None instance: Instance | None
@property @property
def domain(self) -> str: def domain(self) -> str:
@ -122,7 +122,7 @@ class PushWorkers(list[PushWorker]):
self.queue.put(item) self.queue.put(item)
def push_message(self, inbox: str, message: Message, instance: Row) -> None: def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
self.queue.put(PostItem(inbox, message, instance)) self.queue.put(PostItem(inbox, message, instance))