replace tinysql with barkshark-sql

This commit is contained in:
Izalia Mae 2024-02-19 18:35:40 -05:00
parent 29732824ac
commit eff2f8b7fd
14 changed files with 141 additions and 139 deletions

View file

@ -190,7 +190,7 @@ class Application(web.Application):
self['proc'] = None self['proc'] = None
self.cache.close() self.cache.close()
self.database.close() self.database.disconnect()
class CacheCleanupThread(Thread): class CacheCleanupThread(Thread):
@ -244,7 +244,7 @@ async def handle_access_log(request: web.Request, response: web.Response) -> Non
async def handle_cleanup(app: Application) -> None: async def handle_cleanup(app: Application) -> None:
await app.client.close() await app.client.close()
app.cache.close() app.cache.close()
app.database.close() app.database.disconnect()
async def main_gunicorn(): async def main_gunicorn():

View file

@ -168,8 +168,8 @@ class SqlCache(Cache):
'key': key 'key': key
} }
with self._db.connection() as conn: with self._db.session(False) as conn:
with conn.exec_statement('get-cache-item', params) as cur: with conn.run('get-cache-item', params) as cur:
if not (row := cur.one()): if not (row := cur.one()):
raise KeyError(f'{namespace}:{key}') raise KeyError(f'{namespace}:{key}')
@ -178,14 +178,14 @@ class SqlCache(Cache):
def get_keys(self, namespace: str) -> Iterator[str]: def get_keys(self, namespace: str) -> Iterator[str]:
with self._db.connection() as conn: with self._db.session(False) as conn:
for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}): for row in conn.run('get-cache-keys', {'namespace': namespace}):
yield row['key'] yield row['key']
def get_namespaces(self) -> Iterator[str]: def get_namespaces(self) -> Iterator[str]:
with self._db.connection() as conn: with self._db.session(False) as conn:
for row in conn.exec_statement('get-cache-namespaces', None): for row in conn.run('get-cache-namespaces', None):
yield row['namespace'] yield row['namespace']
@ -198,8 +198,8 @@ class SqlCache(Cache):
'date': datetime.now(tz = timezone.utc) 'date': datetime.now(tz = timezone.utc)
} }
with self._db.connection() as conn: with self._db.session(True) as conn:
with conn.exec_statement('set-cache-item', params) as conn: with conn.run('set-cache-item', params) as conn:
row = conn.one() row = conn.one()
row.pop('id', None) row.pop('id', None)
return Item.from_data(*tuple(row.values())) return Item.from_data(*tuple(row.values()))
@ -211,8 +211,8 @@ class SqlCache(Cache):
'key': key 'key': key
} }
with self._db.connection() as conn: with self._db.session(True) as conn:
with conn.exec_statement('del-cache-item', params): with conn.run('del-cache-item', params):
pass pass
@ -220,25 +220,27 @@ class SqlCache(Cache):
limit = datetime.now(tz = timezone.utc) - timedelta(days = days) limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
params = {"limit": limit.timestamp()} params = {"limit": limit.timestamp()}
with self._db.connection() as conn: with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache WHERE updated < :limit", params): with conn.execute("DELETE FROM cache WHERE updated < :limit", params):
pass pass
def clear(self) -> None: def clear(self) -> None:
with self._db.connection() as conn: with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache"): with conn.execute("DELETE FROM cache"):
pass pass
def setup(self) -> None: def setup(self) -> None:
with self._db.connection() as conn: self._db.connect()
with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
with self._db.session(True) as conn:
with conn.run(f'create-cache-table-{self._db.backend_type.value}', None):
pass pass
def close(self) -> None: def close(self) -> None:
self._db.close() self._db.disconnect()
self._db = None self._db = None

View file

@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS cache (
UNIQUE(namespace, key) UNIQUE(namespace, key)
); );
-- name: create-cache-table-postgres -- name: create-cache-table-postgresql
CREATE TABLE IF NOT EXISTS cache ( CREATE TABLE IF NOT EXISTS cache (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
namespace TEXT NOT NULL, namespace TEXT NOT NULL,

View file

@ -1,49 +1,51 @@
from __future__ import annotations from __future__ import annotations
import tinysql import bsql
import typing import typing
from .config import get_default_value from .config import get_default_value
from .connection import Connection from .connection import RELAY_SOFTWARE, Connection
from .schema import VERSIONS, migrate_0 from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
try: try:
from importlib.resources import files as pkgfiles from importlib.resources import files as pkgfiles
except ImportError: except ImportError: # pylint: disable=duplicate-code
from importlib_resources import files as pkgfiles from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .config import Config from .config import Config
def get_database(config: Config, migrate: bool = True) -> tinysql.Database: def get_database(config: Config, migrate: bool = True) -> bsql.Database:
options = {
"connection_class": Connection,
"pool_size": 5,
"tables": TABLES
}
if config.db_type == "sqlite": if config.db_type == "sqlite":
db = tinysql.Database.sqlite( db = bsql.Database.sqlite(config.sqlite_path, **options)
config.sqlite_path,
connection_class = Connection,
min_connections = 2,
max_connections = 10
)
elif config.db_type == "postgres": elif config.db_type == "postgres":
db = tinysql.Database.postgres( db = bsql.Database.postgresql(
config.pg_name, config.pg_name,
config.pg_host, config.pg_host,
config.pg_port, config.pg_port,
config.pg_user, config.pg_user,
config.pg_pass, config.pg_pass,
connection_class = Connection **options
) )
db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql")) db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql"))
db.connect()
if not migrate: if not migrate:
return db return db
with db.connection() as conn: with db.session(True) as conn:
if 'config' not in conn.get_tables(): if 'config' not in conn.get_tables():
logging.info("Creating database tables") logging.info("Creating database tables")
migrate_0(conn) migrate_0(conn)

View file

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import tinysql
import typing import typing
from argon2 import PasswordHasher from argon2 import PasswordHasher
from bsql import Connection as SqlConnection, Update
from datetime import datetime, timezone from datetime import datetime, timezone
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
@ -15,7 +15,7 @@ from ..misc import get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from collections.abc import Iterator from collections.abc import Iterator
from tinysql import Cursor, Row from bsql import Row
from typing import Any from typing import Any
from .application import Application from .application import Application
from ..misc import Message from ..misc import Message
@ -29,7 +29,7 @@ RELAY_SOFTWARE = [
] ]
class Connection(tinysql.Connection): class Connection(SqlConnection):
hasher = PasswordHasher( hasher = PasswordHasher(
encoding = 'utf-8' encoding = 'utf-8'
) )
@ -50,15 +50,11 @@ class Connection(tinysql.Connection):
yield inbox['inbox'] yield inbox['inbox']
def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params)
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS: if key not in CONFIG_DEFAULTS:
raise KeyError(key) raise KeyError(key)
with self.exec_statement('get-config', {'key': key}) as cur: with self.run('get-config', {'key': key}) as cur:
if not (row := cur.one()): if not (row := cur.one()):
return get_default_value(key) return get_default_value(key)
@ -69,7 +65,7 @@ class Connection(tinysql.Connection):
def get_config_all(self) -> dict[str, Any]: def get_config_all(self) -> dict[str, Any]:
with self.exec_statement('get-config-all') as cur: with self.run('get-config-all', None) as cur:
db_config = {row['key']: row['value'] for row in cur} db_config = {row['key']: row['value'] for row in cur}
config = {} config = {}
@ -105,12 +101,12 @@ class Connection(tinysql.Connection):
'type': get_default_type(key) 'type': get_default_type(key)
} }
with self.exec_statement('put-config', params): with self.run('put-config', params):
return value return value
def get_inbox(self, value: str) -> Row: def get_inbox(self, value: str) -> Row:
with self.exec_statement('get-inbox', {'value': value}) as cur: with self.run('get-inbox', {'value': value}) as cur:
return cur.one() return cur.one()
@ -130,7 +126,7 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.exec_statement('put-inbox', params) as cur: with self.run('put-inbox', params) as cur:
return cur.one() return cur.one()
@ -154,27 +150,28 @@ class Connection(tinysql.Connection):
if software: if software:
data['software'] = software data['software'] = software
statement = tinysql.Update('inboxes', data, inbox = inbox) statement = Update('inboxes', data)
statement.set_where("inbox", inbox)
with self.query(statement): with self.query(statement):
return self.get_inbox(inbox) return self.get_inbox(inbox)
def del_inbox(self, value: str) -> bool: def del_inbox(self, value: str) -> bool:
with self.exec_statement('del-inbox', {'value': value}) as cur: with self.run('del-inbox', {'value': value}) as cur:
if cur.modified_row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return cur.modified_row_count == 1 return cur.row_count == 1
def get_user(self, value: str) -> Row: def get_user(self, value: str) -> Row:
with self.exec_statement('get-user', {'value': value}) as cur: with self.run('get-user', {'value': value}) as cur:
return cur.one() return cur.one()
def get_user_by_token(self, code: str) -> Row: def get_user_by_token(self, code: str) -> Row:
with self.exec_statement('get-user-by-token', {'code': code}) as cur: with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one() return cur.one()
@ -186,22 +183,22 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.exec_statement('put-user', data) as cur: with self.run('put-user', data) as cur:
return cur.one() return cur.one()
def del_user(self, username: str) -> None: def del_user(self, username: str) -> None:
user = self.get_user(username) user = self.get_user(username)
with self.exec_statement('del-user', {'value': user['username']}): with self.run('del-user', {'value': user['username']}):
pass pass
with self.exec_statement('del-token-user', {'username': user['username']}): with self.run('del-token-user', {'username': user['username']}):
pass pass
def get_token(self, code: str) -> Row: def get_token(self, code: str) -> Row:
with self.exec_statement('get-token', {'code': code}) as cur: with self.run('get-token', {'code': code}) as cur:
return cur.one() return cur.one()
@ -212,12 +209,12 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.exec_statement('put-token', data) as cur: with self.run('put-token', data) as cur:
return cur.one() return cur.one()
def del_token(self, code: str) -> None: def del_token(self, code: str) -> None:
with self.exec_statement('del-token', {'code': code}): with self.run('del-token', {'code': code}):
pass pass
@ -225,7 +222,7 @@ class Connection(tinysql.Connection):
if domain.startswith('http'): if domain.startswith('http'):
domain = urlparse(domain).netloc domain = urlparse(domain).netloc
with self.exec_statement('get-domain-ban', {'domain': domain}) as cur: with self.run('get-domain-ban', {'domain': domain}) as cur:
return cur.one() return cur.one()
@ -241,14 +238,14 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.exec_statement('put-domain-ban', params) as cur: with self.run('put-domain-ban', params) as cur:
return cur.one() return cur.one()
def update_domain_ban(self, def update_domain_ban(self,
domain: str, domain: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> tinysql.Row: note: str | None = None) -> Row:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -261,25 +258,26 @@ class Connection(tinysql.Connection):
if note: if note:
params['note'] = note params['note'] = note
statement = tinysql.Update('domain_bans', params, domain = domain) statement = Update('domain_bans', params)
statement.set_where("domain", domain)
with self.query(statement) as cur: with self.query(statement) as cur:
if cur.modified_row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return self.get_domain_ban(domain) return self.get_domain_ban(domain)
def del_domain_ban(self, domain: str) -> bool: def del_domain_ban(self, domain: str) -> bool:
with self.exec_statement('del-domain-ban', {'domain': domain}) as cur: with self.run('del-domain-ban', {'domain': domain}) as cur:
if cur.modified_row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return cur.modified_row_count == 1 return cur.row_count == 1
def get_software_ban(self, name: str) -> Row: def get_software_ban(self, name: str) -> Row:
with self.exec_statement('get-software-ban', {'name': name}) as cur: with self.run('get-software-ban', {'name': name}) as cur:
return cur.one() return cur.one()
@ -295,14 +293,14 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.exec_statement('put-software-ban', params) as cur: with self.run('put-software-ban', params) as cur:
return cur.one() return cur.one()
def update_software_ban(self, def update_software_ban(self,
name: str, name: str,
reason: str | None = None, reason: str | None = None,
note: str | None = None) -> tinysql.Row: note: str | None = None) -> Row:
if not (reason or note): if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified') raise ValueError('"reason" and/or "note" must be specified')
@ -315,25 +313,26 @@ class Connection(tinysql.Connection):
if note: if note:
params['note'] = note params['note'] = note
statement = tinysql.Update('software_bans', params, name = name) statement = Update('software_bans', params)
statement.set_where("name", name)
with self.query(statement) as cur: with self.query(statement) as cur:
if cur.modified_row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return self.get_software_ban(name) return self.get_software_ban(name)
def del_software_ban(self, name: str) -> bool: def del_software_ban(self, name: str) -> bool:
with self.exec_statement('del-software-ban', {'name': name}) as cur: with self.run('del-software-ban', {'name': name}) as cur:
if cur.modified_row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return cur.modified_row_count == 1 return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> Row: def get_domain_whitelist(self, domain: str) -> Row:
with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur: with self.run('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one() return cur.one()
@ -343,13 +342,13 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc) 'created': datetime.now(tz = timezone.utc)
} }
with self.exec_statement('put-domain-whitelist', params) as cur: with self.run('put-domain-whitelist', params) as cur:
return cur.one() return cur.one()
def del_domain_whitelist(self, domain: str) -> bool: def del_domain_whitelist(self, domain: str) -> bool:
with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur: with self.run('del-domain-whitelist', {'domain': domain}) as cur:
if cur.modified_row_count > 1: if cur.row_count > 1:
raise ValueError('More than one row was modified') raise ValueError('More than one row was modified')
return cur.modified_row_count == 1 return cur.row_count == 1

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import typing import typing
from tinysql import Column, Connection, Table from bsql import Column, Connection, Table, Tables
from .config import get_default_value from .config import get_default_value
@ -11,7 +11,7 @@ if typing.TYPE_CHECKING:
VERSIONS: dict[int, Callable] = {} VERSIONS: dict[int, Callable] = {}
TABLES: list[Table] = [ TABLES: Tables = Tables(
Table( Table(
'config', 'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False), Column('key', 'text', primary_key = True, unique = True, nullable = False),
@ -59,7 +59,7 @@ TABLES: list[Table] = [
Column('user', 'text', nullable = False), Column('user', 'text', nullable = False),
Column('created', 'timestmap', nullable = False) Column('created', 'timestmap', nullable = False)
) )
] )
def migration(func: Callable) -> Callable: def migration(func: Callable) -> Callable:
@ -69,10 +69,10 @@ def migration(func: Callable) -> Callable:
def migrate_0(conn: Connection) -> None: def migrate_0(conn: Connection) -> None:
conn.create_tables(TABLES) conn.create_tables()
conn.put_config('schema-version', get_default_value('schema-version')) conn.put_config('schema-version', get_default_value('schema-version'))
@migration @migration
def migrate_20240206(conn: Connection) -> None: def migrate_20240206(conn: Connection) -> None:
conn.create_tables(TABLES) conn.create_tables()

View file

@ -19,8 +19,7 @@ from . import http_client as http
from . import logger as logging from . import logger as logging
from .application import Application from .application import Application
from .compat import RelayConfig, RelayDatabase from .compat import RelayConfig, RelayDatabase
from .database import get_database from .database import RELAY_SOFTWARE, get_database
from .database.connection import RELAY_SOFTWARE
from .misc import IS_DOCKER, Message from .misc import IS_DOCKER, Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -199,7 +198,7 @@ def cli_setup(ctx: click.Context) -> None:
'private-key': Signer.new('n/a').export() 'private-key': Signer.new('n/a').export()
} }
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for key, value in config.items(): for key, value in config.items():
conn.put_config(key, value) conn.put_config(key, value)
@ -279,7 +278,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.save() ctx.obj.config.save()
with get_database(ctx.obj.config) as db: with get_database(ctx.obj.config) as db:
with db.connection() as conn: with db.session(True) as conn:
conn.put_config('private-key', database['private-key']) conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note']) conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled']) conn.put_config('whitelist-enabled', config['whitelist_enabled'])
@ -365,7 +364,7 @@ def cli_config_list(ctx: click.Context) -> None:
click.echo('Relay Config:') click.echo('Relay Config:')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for key, value in conn.get_config_all().items(): for key, value in conn.get_config_all().items():
if key not in CONFIG_IGNORE: if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20) key = f'{key}:'.ljust(20)
@ -379,7 +378,7 @@ def cli_config_list(ctx: click.Context) -> None:
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None: def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value' 'Set a config value'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
new_value = conn.put_config(key, value) new_value = conn.put_config(key, value)
print(f'{key}: {repr(new_value)}') print(f'{key}: {repr(new_value)}')
@ -397,7 +396,7 @@ def cli_user_list(ctx: click.Context) -> None:
click.echo('Users:') click.echo('Users:')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for user in conn.execute('SELECT * FROM users'): for user in conn.execute('SELECT * FROM users'):
click.echo(f'- {user["username"]}') click.echo(f'- {user["username"]}')
@ -409,7 +408,7 @@ def cli_user_list(ctx: click.Context) -> None:
def cli_user_create(ctx: click.Context, username: str, handle: str) -> None: def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
'Create a new local user' 'Create a new local user'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if conn.get_user(username): if conn.get_user(username):
click.echo(f'User already exists: {username}') click.echo(f'User already exists: {username}')
return return
@ -436,7 +435,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
def cli_user_delete(ctx: click.Context, username: str) -> None: def cli_user_delete(ctx: click.Context, username: str) -> None:
'Delete a local user' 'Delete a local user'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if not conn.get_user(username): if not conn.get_user(username):
click.echo(f'User does not exist: {username}') click.echo(f'User does not exist: {username}')
return return
@ -454,7 +453,7 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
click.echo(f'Tokens for "{username}":') click.echo(f'Tokens for "{username}":')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}): for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
click.echo(f'- {token["code"]}') click.echo(f'- {token["code"]}')
@ -465,7 +464,7 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
def cli_user_create_token(ctx: click.Context, username: str) -> None: def cli_user_create_token(ctx: click.Context, username: str) -> None:
'Create a new API token for a user' 'Create a new API token for a user'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if not (user := conn.get_user(username)): if not (user := conn.get_user(username)):
click.echo(f'User does not exist: {username}') click.echo(f'User does not exist: {username}')
return return
@ -481,7 +480,7 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
def cli_user_delete_token(ctx: click.Context, code: str) -> None: def cli_user_delete_token(ctx: click.Context, code: str) -> None:
'Delete an API token' 'Delete an API token'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if not conn.get_token(code): if not conn.get_token(code):
click.echo('Token does not exist') click.echo('Token does not exist')
return return
@ -503,7 +502,7 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:') click.echo('Connected to the following instances or relays:')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes'): for inbox in conn.execute('SELECT * FROM inboxes'):
click.echo(f'- {inbox["inbox"]}') click.echo(f'- {inbox["inbox"]}')
@ -514,7 +513,7 @@ def cli_inbox_list(ctx: click.Context) -> None:
def cli_inbox_follow(ctx: click.Context, actor: str) -> None: def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)' 'Follow an actor (Relay must be running)'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
@ -549,7 +548,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
inbox_data: Row = None inbox_data: Row = None
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor): if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}') click.echo(f'Error: Refusing to follow banned actor: {actor}')
return return
@ -617,7 +616,7 @@ def cli_inbox_add(
except KeyError: except KeyError:
pass pass
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(domain): if conn.get_domain_ban(domain):
click.echo(f'Refusing to add banned inbox: {inbox}') click.echo(f'Refusing to add banned inbox: {inbox}')
return return
@ -637,7 +636,7 @@ def cli_inbox_add(
def cli_inbox_remove(ctx: click.Context, inbox: str) -> None: def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
'Remove an inbox from the database' 'Remove an inbox from the database'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if not conn.del_inbox(inbox): if not conn.del_inbox(inbox):
click.echo(f'Inbox not in database: {inbox}') click.echo(f'Inbox not in database: {inbox}')
return return
@ -657,7 +656,7 @@ def cli_instance_list(ctx: click.Context) -> None:
click.echo('Banned domains:') click.echo('Banned domains:')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for instance in conn.execute('SELECT * FROM domain_bans'): for instance in conn.execute('SELECT * FROM domain_bans'):
if instance['reason']: if instance['reason']:
click.echo(f'- {instance["domain"]} ({instance["reason"]})') click.echo(f'- {instance["domain"]} ({instance["reason"]})')
@ -674,7 +673,7 @@ def cli_instance_list(ctx: click.Context) -> None:
def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None: def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Ban an instance and remove the associated inbox if it exists' 'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_ban(domain): if conn.get_domain_ban(domain):
click.echo(f'Domain already banned: {domain}') click.echo(f'Domain already banned: {domain}')
return return
@ -690,7 +689,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
def cli_instance_unban(ctx: click.Context, domain: str) -> None: def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance' 'Unban an instance'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if not conn.del_domain_ban(domain): if not conn.del_domain_ban(domain):
click.echo(f'Instance wasn\'t banned: {domain}') click.echo(f'Instance wasn\'t banned: {domain}')
return return
@ -709,7 +708,7 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
if not (reason or note): if not (reason or note):
ctx.fail('Must pass --reason or --note') ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if not (row := conn.update_domain_ban(domain, reason, note)): if not (row := conn.update_domain_ban(domain, reason, note)):
click.echo(f'Failed to update domain ban: {domain}') click.echo(f'Failed to update domain ban: {domain}')
return return
@ -735,7 +734,7 @@ def cli_software_list(ctx: click.Context) -> None:
click.echo('Banned software:') click.echo('Banned software:')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for software in conn.execute('SELECT * FROM software_bans'): for software in conn.execute('SELECT * FROM software_bans'):
if software['reason']: if software['reason']:
click.echo(f'- {software["name"]} ({software["reason"]})') click.echo(f'- {software["name"]} ({software["reason"]})')
@ -761,7 +760,7 @@ def cli_software_ban(ctx: click.Context,
fetch_nodeinfo: bool) -> None: fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays' 'Ban software. Use RELAYS for NAME to ban relays'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if name == 'RELAYS': if name == 'RELAYS':
for software in RELAY_SOFTWARE: for software in RELAY_SOFTWARE:
if conn.get_software_ban(software): if conn.get_software_ban(software):
@ -804,7 +803,7 @@ def cli_software_ban(ctx: click.Context,
def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None: def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays' 'Ban software. Use RELAYS for NAME to unban relays'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if name == 'RELAYS': if name == 'RELAYS':
for software in RELAY_SOFTWARE: for software in RELAY_SOFTWARE:
if not conn.del_software_ban(software): if not conn.del_software_ban(software):
@ -838,7 +837,7 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
if not (reason or note): if not (reason or note):
ctx.fail('Must pass --reason or --note') ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if not (row := conn.update_software_ban(name, reason, note)): if not (row := conn.update_software_ban(name, reason, note)):
click.echo(f'Failed to update software ban: {name}') click.echo(f'Failed to update software ban: {name}')
return return
@ -864,7 +863,7 @@ def cli_whitelist_list(ctx: click.Context) -> None:
click.echo('Current whitelisted domains:') click.echo('Current whitelisted domains:')
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for domain in conn.execute('SELECT * FROM whitelist'): for domain in conn.execute('SELECT * FROM whitelist'):
click.echo(f'- {domain["domain"]}') click.echo(f'- {domain["domain"]}')
@ -875,7 +874,7 @@ def cli_whitelist_list(ctx: click.Context) -> None:
def cli_whitelist_add(ctx: click.Context, domain: str) -> None: def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
'Add a domain to the whitelist' 'Add a domain to the whitelist'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if conn.get_domain_whitelist(domain): if conn.get_domain_whitelist(domain):
click.echo(f'Instance already in the whitelist: {domain}') click.echo(f'Instance already in the whitelist: {domain}')
return return
@ -890,7 +889,7 @@ def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
def cli_whitelist_remove(ctx: click.Context, domain: str) -> None: def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
'Remove an instance from the whitelist' 'Remove an instance from the whitelist'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
if not conn.del_domain_whitelist(domain): if not conn.del_domain_whitelist(domain):
click.echo(f'Domain not in the whitelist: {domain}') click.echo(f'Domain not in the whitelist: {domain}')
return return
@ -907,7 +906,7 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
def cli_whitelist_import(ctx: click.Context) -> None: def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist' 'Add all current inboxes to the whitelist'
with ctx.obj.database.connection() as conn: with ctx.obj.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes').all(): for inbox in conn.execute('SELECT * FROM inboxes').all():
if conn.get_domain_whitelist(inbox['domain']): if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {inbox["domain"]}') click.echo(f'Domain already in whitelist: {inbox["domain"]}')

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import typing import typing
from . import logger as logging from . import logger as logging
from .database.connection import Connection from .database import Connection
from .misc import Message from .misc import Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -184,7 +184,7 @@ async def run_processor(view: ActorView) -> None:
return return
with view.database.connection(False) as conn: with view.database.session() as conn:
if view.instance: if view.instance:
if not view.instance['software']: if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])): if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):

View file

@ -46,7 +46,7 @@ class ActorView(View):
if response := await self.get_post_data(): if response := await self.get_post_data():
return response return response
with self.database.connection(False) as conn: with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox) self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all() config = conn.get_config_all()

View file

@ -45,7 +45,7 @@ async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Respo
try: try:
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip() request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
with request.app.database.connection() as conn: with request.app.database.session() as conn:
request['user'] = conn.get_user_by_token(request['token']) request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError): except (KeyError, ValueError):
@ -92,7 +92,7 @@ class Login(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
with self.database.connection(True) as conn: with self.database.session() as conn:
conn.del_token(request['token']) conn.del_token(request['token'])
return Response.new({'message': 'Token revoked'}, ctype = 'json') return Response.new({'message': 'Token revoked'}, ctype = 'json')
@ -101,7 +101,7 @@ class Login(View):
@register_route('/api/v1/relay') @register_route('/api/v1/relay')
class RelayInfo(View): class RelayInfo(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn: with self.database.session() as conn:
config = conn.get_config_all() config = conn.get_config_all()
inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')] inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')]
@ -123,7 +123,7 @@ class RelayInfo(View):
@register_route('/api/v1/config') @register_route('/api/v1/config')
class Config(View): class Config(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn: with self.database.session() as conn:
data = conn.get_config_all() data = conn.get_config_all()
data['log-level'] = data['log-level'].name data['log-level'] = data['log-level'].name
@ -142,7 +142,7 @@ class Config(View):
if data['key'] not in CONFIG_VALID: if data['key'] not in CONFIG_VALID:
return Response.new_error(400, 'Invalid key', 'json') return Response.new_error(400, 'Invalid key', 'json')
with self.database.connection(True) as conn: with self.database.session() as conn:
conn.put_config(data['key'], data['value']) conn.put_config(data['key'], data['value'])
return Response.new({'message': 'Updated config'}, ctype = 'json') return Response.new({'message': 'Updated config'}, ctype = 'json')
@ -157,7 +157,7 @@ class Config(View):
if data['key'] not in CONFIG_VALID: if data['key'] not in CONFIG_VALID:
return Response.new_error(400, 'Invalid key', 'json') return Response.new_error(400, 'Invalid key', 'json')
with self.database.connection(True) as conn: with self.database.session() as conn:
conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1]) conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1])
return Response.new({'message': 'Updated config'}, ctype = 'json') return Response.new({'message': 'Updated config'}, ctype = 'json')
@ -168,7 +168,7 @@ class Inbox(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
data = [] data = []
with self.database.connection(False) as conn: with self.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes'): for inbox in conn.execute('SELECT * FROM inboxes'):
try: try:
created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc) created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc)
@ -190,7 +190,7 @@ class Inbox(View):
data['domain'] = urlparse(data["actor"]).netloc data['domain'] = urlparse(data["actor"]).netloc
with self.database.connection(True) as conn: with self.database.session() as conn:
if conn.get_inbox(data['domain']): if conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance already in database', 'json') return Response.new_error(404, 'Instance already in database', 'json')
@ -214,7 +214,7 @@ class Inbox(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
with self.database.connection(True) as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid']) data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
if isinstance(data, Response): if isinstance(data, Response):
@ -229,7 +229,7 @@ class Inbox(View):
async def delete(self, request: Request, domain: str) -> Response: async def delete(self, request: Request, domain: str) -> Response:
with self.database.connection(True) as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response): if isinstance(data, Response):
@ -246,7 +246,7 @@ class Inbox(View):
@register_route('/api/v1/domain_ban') @register_route('/api/v1/domain_ban')
class DomainBan(View): class DomainBan(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn: with self.database.session() as conn:
bans = conn.execute('SELECT * FROM domain_bans').all() bans = conn.execute('SELECT * FROM domain_bans').all()
return Response.new(bans, ctype = 'json') return Response.new(bans, ctype = 'json')
@ -258,7 +258,7 @@ class DomainBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with self.database.connection(True) as conn: with self.database.session() as conn:
if conn.get_domain_ban(data['domain']): if conn.get_domain_ban(data['domain']):
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
@ -268,7 +268,7 @@ class DomainBan(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
with self.database.connection(True) as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason']) data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response): if isinstance(data, Response):
@ -286,7 +286,7 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response: async def delete(self, request: Request) -> Response:
with self.database.connection(True) as conn: with self.database.session() as conn:
data = await self.get_api_data(['domain'], []) data = await self.get_api_data(['domain'], [])
if isinstance(data, Response): if isinstance(data, Response):
@ -303,7 +303,7 @@ class DomainBan(View):
@register_route('/api/v1/software_ban') @register_route('/api/v1/software_ban')
class SoftwareBan(View): class SoftwareBan(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn: with self.database.session() as conn:
bans = conn.execute('SELECT * FROM software_bans').all() bans = conn.execute('SELECT * FROM software_bans').all()
return Response.new(bans, ctype = 'json') return Response.new(bans, ctype = 'json')
@ -315,7 +315,7 @@ class SoftwareBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with self.database.connection(True) as conn: with self.database.session() as conn:
if conn.get_software_ban(data['name']): if conn.get_software_ban(data['name']):
return Response.new_error(400, 'Domain already banned', 'json') return Response.new_error(400, 'Domain already banned', 'json')
@ -330,7 +330,7 @@ class SoftwareBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with self.database.connection(True) as conn: with self.database.session() as conn:
if not conn.get_software_ban(data['name']): if not conn.get_software_ban(data['name']):
return Response.new_error(404, 'Software not banned', 'json') return Response.new_error(404, 'Software not banned', 'json')
@ -348,7 +348,7 @@ class SoftwareBan(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with self.database.connection(True) as conn: with self.database.session() as conn:
if not conn.get_software_ban(data['name']): if not conn.get_software_ban(data['name']):
return Response.new_error(404, 'Software not banned', 'json') return Response.new_error(404, 'Software not banned', 'json')
@ -360,7 +360,7 @@ class SoftwareBan(View):
@register_route('/api/v1/whitelist') @register_route('/api/v1/whitelist')
class Whitelist(View): class Whitelist(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn: with self.database.session() as conn:
items = conn.execute('SELECT * FROM whitelist').all() items = conn.execute('SELECT * FROM whitelist').all()
return Response.new(items, ctype = 'json') return Response.new(items, ctype = 'json')
@ -372,7 +372,7 @@ class Whitelist(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with self.database.connection(True) as conn: with self.database.session() as conn:
if conn.get_domain_whitelist(data['domain']): if conn.get_domain_whitelist(data['domain']):
return Response.new_error(400, 'Domain already added to whitelist', 'json') return Response.new_error(400, 'Domain already added to whitelist', 'json')
@ -387,7 +387,7 @@ class Whitelist(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
with self.database.connection(False) as conn: with self.database.session() as conn:
if not conn.get_domain_whitelist(data['domain']): if not conn.get_domain_whitelist(data['domain']):
return Response.new_error(404, 'Domain not in whitelist', 'json') return Response.new_error(404, 'Domain not in whitelist', 'json')

View file

@ -44,7 +44,7 @@ HOME_TEMPLATE = """
@register_route('/') @register_route('/')
class HomeView(View): class HomeView(View):
async def get(self, request: Request) -> Response: async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn: with self.database.session() as conn:
config = conn.get_config_all() config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all() inboxes = conn.execute('SELECT * FROM inboxes').all()

View file

@ -33,7 +33,7 @@ if Path(__file__).parent.parent.joinpath('.git').exists():
class NodeinfoView(View): class NodeinfoView(View):
# pylint: disable=no-self-use # pylint: disable=no-self-use
async def get(self, request: Request, niversion: str) -> Response: async def get(self, request: Request, niversion: str) -> Response:
with self.database.connection(False) as conn: with self.database.session() as conn:
inboxes = conn.execute('SELECT * FROM inboxes').all() inboxes = conn.execute('SELECT * FROM inboxes').all()
data = { data = {

View file

@ -7,6 +7,6 @@ gunicorn==21.1.0
hiredis==2.3.2 hiredis==2.3.2
pyyaml>=6.0 pyyaml>=6.0
redis==5.0.1 redis==5.0.1
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/main.tar.gz barkshark-sql@https://git.barkshark.xyz/barkshark/blib/archive/main.tar.gz
importlib_resources==6.1.1;python_version<'3.9' importlib_resources==6.1.1;python_version<'3.9'

View file

@ -45,4 +45,4 @@ console_scripts =
[flake8] [flake8]
select = F401 select = F401
per-file-ignores = per-file-ignores =
relay/views/__init__.py: F401 __init__.py: F401