Compare commits

..

3 commits

Author SHA1 Message Date
Izalia Mae a2ae1bdd21 return correct types for api 2024-02-20 19:22:18 -05:00
Izalia Mae e83f7e91af Merge branch 'replace_tinysql' into 'master'
Replace tinysql

See merge request pleroma/relay!56
2024-02-21 00:00:21 +00:00
Izalia Mae 93797b639e Replace tinysql 2024-02-21 00:00:21 +00:00
16 changed files with 169 additions and 168 deletions

View file

@ -190,7 +190,7 @@ class Application(web.Application):
self['proc'] = None
self.cache.close()
self.database.close()
self.database.disconnect()
class CacheCleanupThread(Thread):
@ -202,14 +202,10 @@ class CacheCleanupThread(Thread):
def run(self) -> None:
cache = get_cache(self.app)
while self.running.is_set():
time.sleep(3600)
logging.verbose("Removing old cache items")
cache.delete_old(14)
cache.close()
self.app.cache.delete_old(14)
def start(self) -> None:
@ -244,7 +240,7 @@ async def handle_access_log(request: web.Request, response: web.Response) -> Non
async def handle_cleanup(app: Application) -> None:
await app.client.close()
app.cache.close()
app.database.close()
app.database.disconnect()
async def main_gunicorn():

View file

@ -168,8 +168,8 @@ class SqlCache(Cache):
'key': key
}
with self._db.connection() as conn:
with conn.exec_statement('get-cache-item', params) as cur:
with self._db.session(False) as conn:
with conn.run('get-cache-item', params) as cur:
if not (row := cur.one()):
raise KeyError(f'{namespace}:{key}')
@ -178,14 +178,14 @@ class SqlCache(Cache):
def get_keys(self, namespace: str) -> Iterator[str]:
with self._db.connection() as conn:
for row in conn.exec_statement('get-cache-keys', {'namespace': namespace}):
with self._db.session(False) as conn:
for row in conn.run('get-cache-keys', {'namespace': namespace}):
yield row['key']
def get_namespaces(self) -> Iterator[str]:
with self._db.connection() as conn:
for row in conn.exec_statement('get-cache-namespaces', None):
with self._db.session(False) as conn:
for row in conn.run('get-cache-namespaces', None):
yield row['namespace']
@ -198,8 +198,8 @@ class SqlCache(Cache):
'date': datetime.now(tz = timezone.utc)
}
with self._db.connection() as conn:
with conn.exec_statement('set-cache-item', params) as conn:
with self._db.session(True) as conn:
with conn.run('set-cache-item', params) as conn:
row = conn.one()
row.pop('id', None)
return Item.from_data(*tuple(row.values()))
@ -211,8 +211,8 @@ class SqlCache(Cache):
'key': key
}
with self._db.connection() as conn:
with conn.exec_statement('del-cache-item', params):
with self._db.session(True) as conn:
with conn.run('del-cache-item', params):
pass
@ -220,25 +220,27 @@ class SqlCache(Cache):
limit = datetime.now(tz = timezone.utc) - timedelta(days = days)
params = {"limit": limit.timestamp()}
with self._db.connection() as conn:
with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache WHERE updated < :limit", params):
pass
def clear(self) -> None:
with self._db.connection() as conn:
with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache"):
pass
def setup(self) -> None:
with self._db.connection() as conn:
with conn.exec_statement(f'create-cache-table-{self._db.type.name.lower()}', None):
self._db.connect()
with self._db.session(True) as conn:
with conn.run(f'create-cache-table-{self._db.backend_type.value}', None):
pass
def close(self) -> None:
self._db.close()
self._db.disconnect()
self._db = None

View file

@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS cache (
UNIQUE(namespace, key)
);
-- name: create-cache-table-postgres
-- name: create-cache-table-postgresql
CREATE TABLE IF NOT EXISTS cache (
id SERIAL PRIMARY KEY,
namespace TEXT NOT NULL,

View file

@ -1,49 +1,51 @@
from __future__ import annotations
import tinysql
import bsql
import typing
from .config import get_default_value
from .connection import Connection
from .schema import VERSIONS, migrate_0
from .connection import RELAY_SOFTWARE, Connection
from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging
try:
from importlib.resources import files as pkgfiles
except ImportError:
except ImportError: # pylint: disable=duplicate-code
from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING:
from .config import Config
def get_database(config: Config, migrate: bool = True) -> tinysql.Database:
def get_database(config: Config, migrate: bool = True) -> bsql.Database:
options = {
"connection_class": Connection,
"pool_size": 5,
"tables": TABLES
}
if config.db_type == "sqlite":
db = tinysql.Database.sqlite(
config.sqlite_path,
connection_class = Connection,
min_connections = 2,
max_connections = 10
)
db = bsql.Database.sqlite(config.sqlite_path, **options)
elif config.db_type == "postgres":
db = tinysql.Database.postgres(
db = bsql.Database.postgresql(
config.pg_name,
config.pg_host,
config.pg_port,
config.pg_user,
config.pg_pass,
connection_class = Connection
**options
)
db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql"))
db.connect()
if not migrate:
return db
with db.connection() as conn:
with db.session(True) as conn:
if 'config' not in conn.get_tables():
logging.info("Creating database tables")
migrate_0(conn)

View file

@ -1,9 +1,9 @@
from __future__ import annotations
import tinysql
import typing
from argon2 import PasswordHasher
from bsql import Connection as SqlConnection, Update
from datetime import datetime, timezone
from urllib.parse import urlparse
from uuid import uuid4
@ -15,7 +15,7 @@ from ..misc import get_app
if typing.TYPE_CHECKING:
from collections.abc import Iterator
from tinysql import Cursor, Row
from bsql import Row
from typing import Any
from .application import Application
from ..misc import Message
@ -29,7 +29,7 @@ RELAY_SOFTWARE = [
]
class Connection(tinysql.Connection):
class Connection(SqlConnection):
hasher = PasswordHasher(
encoding = 'utf-8'
)
@ -50,15 +50,11 @@ class Connection(tinysql.Connection):
yield inbox['inbox']
def exec_statement(self, name: str, params: dict[str, Any] | None = None) -> Cursor:
return self.execute(self.database.prepared_statements[name], params)
def get_config(self, key: str) -> Any:
if key not in CONFIG_DEFAULTS:
raise KeyError(key)
with self.exec_statement('get-config', {'key': key}) as cur:
with self.run('get-config', {'key': key}) as cur:
if not (row := cur.one()):
return get_default_value(key)
@ -69,7 +65,7 @@ class Connection(tinysql.Connection):
def get_config_all(self) -> dict[str, Any]:
with self.exec_statement('get-config-all') as cur:
with self.run('get-config-all', None) as cur:
db_config = {row['key']: row['value'] for row in cur}
config = {}
@ -105,12 +101,12 @@ class Connection(tinysql.Connection):
'type': get_default_type(key)
}
with self.exec_statement('put-config', params):
with self.run('put-config', params):
return value
def get_inbox(self, value: str) -> Row:
with self.exec_statement('get-inbox', {'value': value}) as cur:
with self.run('get-inbox', {'value': value}) as cur:
return cur.one()
@ -130,7 +126,7 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-inbox', params) as cur:
with self.run('put-inbox', params) as cur:
return cur.one()
@ -154,27 +150,28 @@ class Connection(tinysql.Connection):
if software:
data['software'] = software
statement = tinysql.Update('inboxes', data, inbox = inbox)
statement = Update('inboxes', data)
statement.set_where("inbox", inbox)
with self.query(statement):
return self.get_inbox(inbox)
def del_inbox(self, value: str) -> bool:
with self.exec_statement('del-inbox', {'value': value}) as cur:
if cur.modified_row_count > 1:
with self.run('del-inbox', {'value': value}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
return cur.row_count == 1
def get_user(self, value: str) -> Row:
with self.exec_statement('get-user', {'value': value}) as cur:
with self.run('get-user', {'value': value}) as cur:
return cur.one()
def get_user_by_token(self, code: str) -> Row:
with self.exec_statement('get-user-by-token', {'code': code}) as cur:
with self.run('get-user-by-token', {'code': code}) as cur:
return cur.one()
@ -186,22 +183,22 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-user', data) as cur:
with self.run('put-user', data) as cur:
return cur.one()
def del_user(self, username: str) -> None:
user = self.get_user(username)
with self.exec_statement('del-user', {'value': user['username']}):
with self.run('del-user', {'value': user['username']}):
pass
with self.exec_statement('del-token-user', {'username': user['username']}):
with self.run('del-token-user', {'username': user['username']}):
pass
def get_token(self, code: str) -> Row:
with self.exec_statement('get-token', {'code': code}) as cur:
with self.run('get-token', {'code': code}) as cur:
return cur.one()
@ -212,12 +209,12 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-token', data) as cur:
with self.run('put-token', data) as cur:
return cur.one()
def del_token(self, code: str) -> None:
with self.exec_statement('del-token', {'code': code}):
with self.run('del-token', {'code': code}):
pass
@ -225,7 +222,7 @@ class Connection(tinysql.Connection):
if domain.startswith('http'):
domain = urlparse(domain).netloc
with self.exec_statement('get-domain-ban', {'domain': domain}) as cur:
with self.run('get-domain-ban', {'domain': domain}) as cur:
return cur.one()
@ -241,14 +238,14 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-ban', params) as cur:
with self.run('put-domain-ban', params) as cur:
return cur.one()
def update_domain_ban(self,
domain: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
note: str | None = None) -> Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
@ -261,25 +258,26 @@ class Connection(tinysql.Connection):
if note:
params['note'] = note
statement = tinysql.Update('domain_bans', params, domain = domain)
statement = Update('domain_bans', params)
statement.set_where("domain", domain)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
return self.get_domain_ban(domain)
def del_domain_ban(self, domain: str) -> bool:
with self.exec_statement('del-domain-ban', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
with self.run('del-domain-ban', {'domain': domain}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
return cur.row_count == 1
def get_software_ban(self, name: str) -> Row:
with self.exec_statement('get-software-ban', {'name': name}) as cur:
with self.run('get-software-ban', {'name': name}) as cur:
return cur.one()
@ -295,14 +293,14 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-software-ban', params) as cur:
with self.run('put-software-ban', params) as cur:
return cur.one()
def update_software_ban(self,
name: str,
reason: str | None = None,
note: str | None = None) -> tinysql.Row:
note: str | None = None) -> Row:
if not (reason or note):
raise ValueError('"reason" and/or "note" must be specified')
@ -315,25 +313,26 @@ class Connection(tinysql.Connection):
if note:
params['note'] = note
statement = tinysql.Update('software_bans', params, name = name)
statement = Update('software_bans', params)
statement.set_where("name", name)
with self.query(statement) as cur:
if cur.modified_row_count > 1:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
return self.get_software_ban(name)
def del_software_ban(self, name: str) -> bool:
with self.exec_statement('del-software-ban', {'name': name}) as cur:
if cur.modified_row_count > 1:
with self.run('del-software-ban', {'name': name}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
return cur.row_count == 1
def get_domain_whitelist(self, domain: str) -> Row:
with self.exec_statement('get-domain-whitelist', {'domain': domain}) as cur:
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
return cur.one()
@ -343,13 +342,13 @@ class Connection(tinysql.Connection):
'created': datetime.now(tz = timezone.utc)
}
with self.exec_statement('put-domain-whitelist', params) as cur:
with self.run('put-domain-whitelist', params) as cur:
return cur.one()
def del_domain_whitelist(self, domain: str) -> bool:
with self.exec_statement('del-domain-whitelist', {'domain': domain}) as cur:
if cur.modified_row_count > 1:
with self.run('del-domain-whitelist', {'domain': domain}) as cur:
if cur.row_count > 1:
raise ValueError('More than one row was modified')
return cur.modified_row_count == 1
return cur.row_count == 1

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import typing
from tinysql import Column, Connection, Table
from bsql import Column, Connection, Table, Tables
from .config import get_default_value
@ -11,7 +11,7 @@ if typing.TYPE_CHECKING:
VERSIONS: dict[int, Callable] = {}
TABLES: list[Table] = [
TABLES: Tables = Tables(
Table(
'config',
Column('key', 'text', primary_key = True, unique = True, nullable = False),
@ -59,7 +59,7 @@ TABLES: list[Table] = [
Column('user', 'text', nullable = False),
Column('created', 'timestmap', nullable = False)
)
]
)
def migration(func: Callable) -> Callable:
@ -69,10 +69,10 @@ def migration(func: Callable) -> Callable:
def migrate_0(conn: Connection) -> None:
conn.create_tables(TABLES)
conn.create_tables()
conn.put_config('schema-version', get_default_value('schema-version'))
@migration
def migrate_20240206(conn: Connection) -> None:
conn.create_tables(TABLES)
conn.create_tables()

View file

@ -111,7 +111,7 @@ class HttpClient:
return loads(item.value)
except KeyError:
logging.verbose('Failed to fetch cached data for url: %s', url)
logging.verbose('No cached data for url: %s', url)
headers = {}

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import Crypto
import asyncio
import click
import os
import platform
import subprocess
import sys
@ -19,8 +20,7 @@ from . import http_client as http
from . import logger as logging
from .application import Application
from .compat import RelayConfig, RelayDatabase
from .database import get_database
from .database.connection import RELAY_SOFTWARE
from .database import RELAY_SOFTWARE, get_database
from .misc import IS_DOCKER, Message
if typing.TYPE_CHECKING:
@ -199,7 +199,7 @@ def cli_setup(ctx: click.Context) -> None:
'private-key': Signer.new('n/a').export()
}
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for key, value in config.items():
conn.put_config(key, value)
@ -239,9 +239,12 @@ def cli_run(ctx: click.Context, dev: bool = False) -> None:
if getattr(sys, 'frozen', False):
subprocess.run([sys.executable, 'run-gunicorn'], check = False)
return
ctx.obj.run(dev)
else:
ctx.obj.run(dev)
# todo: figure out why the relay doesn't quit properly without this
os._exit(0)
@cli.command('run-gunicorn')
@ -279,7 +282,7 @@ def cli_convert(ctx: click.Context, old_config: str) -> None:
ctx.obj.config.save()
with get_database(ctx.obj.config) as db:
with db.connection() as conn:
with db.session(True) as conn:
conn.put_config('private-key', database['private-key'])
conn.put_config('note', config['note'])
conn.put_config('whitelist-enabled', config['whitelist_enabled'])
@ -365,7 +368,7 @@ def cli_config_list(ctx: click.Context) -> None:
click.echo('Relay Config:')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for key, value in conn.get_config_all().items():
if key not in CONFIG_IGNORE:
key = f'{key}:'.ljust(20)
@ -379,7 +382,7 @@ def cli_config_list(ctx: click.Context) -> None:
def cli_config_set(ctx: click.Context, key: str, value: Any) -> None:
'Set a config value'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
new_value = conn.put_config(key, value)
print(f'{key}: {repr(new_value)}')
@ -397,7 +400,7 @@ def cli_user_list(ctx: click.Context) -> None:
click.echo('Users:')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for user in conn.execute('SELECT * FROM users'):
click.echo(f'- {user["username"]}')
@ -409,7 +412,7 @@ def cli_user_list(ctx: click.Context) -> None:
def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
'Create a new local user'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if conn.get_user(username):
click.echo(f'User already exists: {username}')
return
@ -436,7 +439,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
def cli_user_delete(ctx: click.Context, username: str) -> None:
'Delete a local user'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if not conn.get_user(username):
click.echo(f'User does not exist: {username}')
return
@ -454,7 +457,7 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
click.echo(f'Tokens for "{username}":')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
click.echo(f'- {token["code"]}')
@ -465,7 +468,7 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
def cli_user_create_token(ctx: click.Context, username: str) -> None:
'Create a new API token for a user'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if not (user := conn.get_user(username)):
click.echo(f'User does not exist: {username}')
return
@ -481,7 +484,7 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
def cli_user_delete_token(ctx: click.Context, code: str) -> None:
'Delete an API token'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if not conn.get_token(code):
click.echo('Token does not exist')
return
@ -503,7 +506,7 @@ def cli_inbox_list(ctx: click.Context) -> None:
click.echo('Connected to the following instances or relays:')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes'):
click.echo(f'- {inbox["inbox"]}')
@ -514,7 +517,7 @@ def cli_inbox_list(ctx: click.Context) -> None:
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
'Follow an actor (Relay must be running)'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
@ -549,7 +552,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
inbox_data: Row = None
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(actor):
click.echo(f'Error: Refusing to follow banned actor: {actor}')
return
@ -617,7 +620,7 @@ def cli_inbox_add(
except KeyError:
pass
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Refusing to add banned inbox: {inbox}')
return
@ -637,7 +640,7 @@ def cli_inbox_add(
def cli_inbox_remove(ctx: click.Context, inbox: str) -> None:
'Remove an inbox from the database'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if not conn.del_inbox(inbox):
click.echo(f'Inbox not in database: {inbox}')
return
@ -657,7 +660,7 @@ def cli_instance_list(ctx: click.Context) -> None:
click.echo('Banned domains:')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for instance in conn.execute('SELECT * FROM domain_bans'):
if instance['reason']:
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
@ -674,7 +677,7 @@ def cli_instance_list(ctx: click.Context) -> None:
def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) -> None:
'Ban an instance and remove the associated inbox if it exists'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if conn.get_domain_ban(domain):
click.echo(f'Domain already banned: {domain}')
return
@ -690,7 +693,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
def cli_instance_unban(ctx: click.Context, domain: str) -> None:
'Unban an instance'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if not conn.del_domain_ban(domain):
click.echo(f'Instance wasn\'t banned: {domain}')
return
@ -709,7 +712,7 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if not (row := conn.update_domain_ban(domain, reason, note)):
click.echo(f'Failed to update domain ban: {domain}')
return
@ -735,7 +738,7 @@ def cli_software_list(ctx: click.Context) -> None:
click.echo('Banned software:')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for software in conn.execute('SELECT * FROM software_bans'):
if software['reason']:
click.echo(f'- {software["name"]} ({software["reason"]})')
@ -761,7 +764,7 @@ def cli_software_ban(ctx: click.Context,
fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to ban relays'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
if conn.get_software_ban(software):
@ -804,7 +807,7 @@ def cli_software_ban(ctx: click.Context,
def cli_software_unban(ctx: click.Context, name: str, fetch_nodeinfo: bool) -> None:
'Ban software. Use RELAYS for NAME to unban relays'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if name == 'RELAYS':
for software in RELAY_SOFTWARE:
if not conn.del_software_ban(software):
@ -838,7 +841,7 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
if not (reason or note):
ctx.fail('Must pass --reason or --note')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if not (row := conn.update_software_ban(name, reason, note)):
click.echo(f'Failed to update software ban: {name}')
return
@ -864,7 +867,7 @@ def cli_whitelist_list(ctx: click.Context) -> None:
click.echo('Current whitelisted domains:')
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for domain in conn.execute('SELECT * FROM whitelist'):
click.echo(f'- {domain["domain"]}')
@ -875,7 +878,7 @@ def cli_whitelist_list(ctx: click.Context) -> None:
def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
'Add a domain to the whitelist'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if conn.get_domain_whitelist(domain):
click.echo(f'Instance already in the whitelist: {domain}')
return
@ -890,7 +893,7 @@ def cli_whitelist_add(ctx: click.Context, domain: str) -> None:
def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
'Remove an instance from the whitelist'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
if not conn.del_domain_whitelist(domain):
click.echo(f'Domain not in the whitelist: {domain}')
return
@ -907,7 +910,7 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
def cli_whitelist_import(ctx: click.Context) -> None:
'Add all current inboxes to the whitelist'
with ctx.obj.database.connection() as conn:
with ctx.obj.database.session() as conn:
for inbox in conn.execute('SELECT * FROM inboxes').all():
if conn.get_domain_whitelist(inbox['domain']):
click.echo(f'Domain already in whitelist: {inbox["domain"]}')

View file

@ -7,6 +7,7 @@ import typing
from aiohttp.web import Response as AiohttpResponse
from aputils.message import Message as ApMessage
from datetime import datetime
from uuid import uuid4
if typing.TYPE_CHECKING:
@ -74,6 +75,14 @@ def get_app() -> Application:
return Application.DEFAULT
class JsonEncoder(json.JSONEncoder):
def default(self, obj: Any) -> str:
if isinstance(obj, datetime):
return obj.isoformat()
return JSONEncoder.default(self, obj)
class Message(ApMessage):
@classmethod
def new_actor(cls: type[Message], # pylint: disable=arguments-differ
@ -193,8 +202,8 @@ class Response(AiohttpResponse):
if isinstance(body, bytes):
kwargs['body'] = body
elif isinstance(body, (dict, list, tuple, set)) and ctype in {'json', 'activity'}:
kwargs['text'] = json.dumps(body)
elif isinstance(body, (dict, list, tuple, set)) or ctype in {'json', 'activity'}:
kwargs['text'] = json.dumps(body, cls = JsonEncoder)
else:
kwargs['text'] = body
@ -209,7 +218,7 @@ class Response(AiohttpResponse):
ctype: str = 'text') -> Response:
if ctype == 'json':
body = json.dumps({'error': body})
body = {'error': body}
return cls.new(body=body, status=status, ctype=ctype)

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import typing
from . import logger as logging
from .database.connection import Connection
from .database import Connection
from .misc import Message
if typing.TYPE_CHECKING:
@ -43,8 +43,8 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
async def handle_forward(view: ActorView, conn: Connection) -> None:
try:
view.cache.get('handle-relay', view.message.object_id)
logging.verbose('already forwarded %s', view.message.object_id)
view.cache.get('handle-relay', view.message.id)
logging.verbose('already forwarded %s', view.message.id)
return
except KeyError:
@ -56,7 +56,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
for inbox in conn.distill_inboxes(view.message):
view.app.push_message(inbox, message, view.instance)
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
view.cache.set('handle-relay', view.message.id, message.id, 'str')
async def handle_follow(view: ActorView, conn: Connection) -> None:
@ -184,7 +184,7 @@ async def run_processor(view: ActorView) -> None:
return
with view.database.connection(False) as conn:
with view.database.session() as conn:
if view.instance:
if not view.instance['software']:
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):

View file

@ -46,7 +46,7 @@ class ActorView(View):
if response := await self.get_post_data():
return response
with self.database.connection(False) as conn:
with self.database.session() as conn:
self.instance = conn.get_inbox(self.actor.shared_inbox)
config = conn.get_config_all()

View file

@ -45,7 +45,7 @@ async def handle_api_path(request: web.Request, handler: Coroutine) -> web.Respo
try:
request['token'] = request.headers['Authorization'].replace('Bearer', '').strip()
with request.app.database.connection() as conn:
with request.app.database.session() as conn:
request['user'] = conn.get_user_by_token(request['token'])
except (KeyError, ValueError):
@ -92,7 +92,7 @@ class Login(View):
async def delete(self, request: Request) -> Response:
with self.database.connection(True) as conn:
with self.database.session() as conn:
conn.del_token(request['token'])
return Response.new({'message': 'Token revoked'}, ctype = 'json')
@ -101,7 +101,7 @@ class Login(View):
@register_route('/api/v1/relay')
class RelayInfo(View):
async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn:
with self.database.session() as conn:
config = conn.get_config_all()
inboxes = [row['domain'] for row in conn.execute('SELECT * FROM inboxes')]
@ -123,7 +123,7 @@ class RelayInfo(View):
@register_route('/api/v1/config')
class Config(View):
async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn:
with self.database.session() as conn:
data = conn.get_config_all()
data['log-level'] = data['log-level'].name
@ -142,7 +142,7 @@ class Config(View):
if data['key'] not in CONFIG_VALID:
return Response.new_error(400, 'Invalid key', 'json')
with self.database.connection(True) as conn:
with self.database.session() as conn:
conn.put_config(data['key'], data['value'])
return Response.new({'message': 'Updated config'}, ctype = 'json')
@ -157,7 +157,7 @@ class Config(View):
if data['key'] not in CONFIG_VALID:
return Response.new_error(400, 'Invalid key', 'json')
with self.database.connection(True) as conn:
with self.database.session() as conn:
conn.put_config(data['key'], CONFIG_DEFAULTS[data['key']][1])
return Response.new({'message': 'Updated config'}, ctype = 'json')
@ -166,18 +166,8 @@ class Config(View):
@register_route('/api/v1/instance')
class Inbox(View):
async def get(self, request: Request) -> Response:
data = []
with self.database.connection(False) as conn:
for inbox in conn.execute('SELECT * FROM inboxes'):
try:
created = datetime.fromtimestamp(inbox['created'], tz = timezone.utc)
except TypeError:
created = datetime.fromisoformat(inbox['created'])
inbox['created'] = created.isoformat()
data.append(inbox)
with self.database.session() as conn:
data = tuple(conn.execute('SELECT * FROM inboxes').all())
return Response.new(data, ctype = 'json')
@ -190,7 +180,7 @@ class Inbox(View):
data['domain'] = urlparse(data["actor"]).netloc
with self.database.connection(True) as conn:
with self.database.session() as conn:
if conn.get_inbox(data['domain']):
return Response.new_error(404, 'Instance already in database', 'json')
@ -214,7 +204,7 @@ class Inbox(View):
async def patch(self, request: Request) -> Response:
with self.database.connection(True) as conn:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['actor', 'software', 'followid'])
if isinstance(data, Response):
@ -229,7 +219,7 @@ class Inbox(View):
async def delete(self, request: Request, domain: str) -> Response:
with self.database.connection(True) as conn:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
@ -246,8 +236,8 @@ class Inbox(View):
@register_route('/api/v1/domain_ban')
class DomainBan(View):
async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn:
bans = conn.execute('SELECT * FROM domain_bans').all()
with self.database.session() as conn:
bans = tuple(conn.execute('SELECT * FROM domain_bans').all())
return Response.new(bans, ctype = 'json')
@ -258,7 +248,7 @@ class DomainBan(View):
if isinstance(data, Response):
return data
with self.database.connection(True) as conn:
with self.database.session() as conn:
if conn.get_domain_ban(data['domain']):
return Response.new_error(400, 'Domain already banned', 'json')
@ -268,7 +258,7 @@ class DomainBan(View):
async def patch(self, request: Request) -> Response:
with self.database.connection(True) as conn:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], ['note', 'reason'])
if isinstance(data, Response):
@ -286,7 +276,7 @@ class DomainBan(View):
async def delete(self, request: Request) -> Response:
with self.database.connection(True) as conn:
with self.database.session() as conn:
data = await self.get_api_data(['domain'], [])
if isinstance(data, Response):
@ -303,8 +293,8 @@ class DomainBan(View):
@register_route('/api/v1/software_ban')
class SoftwareBan(View):
async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn:
bans = conn.execute('SELECT * FROM software_bans').all()
with self.database.session() as conn:
bans = tuple(conn.execute('SELECT * FROM software_bans').all())
return Response.new(bans, ctype = 'json')
@ -315,7 +305,7 @@ class SoftwareBan(View):
if isinstance(data, Response):
return data
with self.database.connection(True) as conn:
with self.database.session() as conn:
if conn.get_software_ban(data['name']):
return Response.new_error(400, 'Domain already banned', 'json')
@ -330,7 +320,7 @@ class SoftwareBan(View):
if isinstance(data, Response):
return data
with self.database.connection(True) as conn:
with self.database.session() as conn:
if not conn.get_software_ban(data['name']):
return Response.new_error(404, 'Software not banned', 'json')
@ -348,7 +338,7 @@ class SoftwareBan(View):
if isinstance(data, Response):
return data
with self.database.connection(True) as conn:
with self.database.session() as conn:
if not conn.get_software_ban(data['name']):
return Response.new_error(404, 'Software not banned', 'json')
@ -360,8 +350,8 @@ class SoftwareBan(View):
@register_route('/api/v1/whitelist')
class Whitelist(View):
async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn:
items = conn.execute('SELECT * FROM whitelist').all()
with self.database.session() as conn:
items = tuple(conn.execute('SELECT * FROM whitelist').all())
return Response.new(items, ctype = 'json')
@ -372,7 +362,7 @@ class Whitelist(View):
if isinstance(data, Response):
return data
with self.database.connection(True) as conn:
with self.database.session() as conn:
if conn.get_domain_whitelist(data['domain']):
return Response.new_error(400, 'Domain already added to whitelist', 'json')
@ -387,7 +377,7 @@ class Whitelist(View):
if isinstance(data, Response):
return data
with self.database.connection(False) as conn:
with self.database.session() as conn:
if not conn.get_domain_whitelist(data['domain']):
return Response.new_error(404, 'Domain not in whitelist', 'json')

View file

@ -44,9 +44,9 @@ HOME_TEMPLATE = """
@register_route('/')
class HomeView(View):
async def get(self, request: Request) -> Response:
with self.database.connection(False) as conn:
with self.database.session() as conn:
config = conn.get_config_all()
inboxes = conn.execute('SELECT * FROM inboxes').all()
inboxes = tuple(conn.execute('SELECT * FROM inboxes').all())
text = HOME_TEMPLATE.format(
host = self.config.domain,

View file

@ -33,7 +33,7 @@ if Path(__file__).parent.parent.joinpath('.git').exists():
class NodeinfoView(View):
# pylint: disable=no-self-use
async def get(self, request: Request, niversion: str) -> Response:
with self.database.connection(False) as conn:
with self.database.session() as conn:
inboxes = conn.execute('SELECT * FROM inboxes').all()
data = {

View file

@ -7,6 +7,6 @@ gunicorn==21.1.0
hiredis==2.3.2
pyyaml>=6.0
redis==5.0.1
tinysql[postgres]@https://git.barkshark.xyz/barkshark/tinysql/archive/main.tar.gz
barkshark-sql@https://git.barkshark.xyz/barkshark/bsql/archive/0.1.1.tar.gz
importlib_resources==6.1.1;python_version<'3.9'

View file

@ -45,4 +45,4 @@ console_scripts =
[flake8]
select = F401
per-file-ignores =
relay/views/__init__.py: F401
__init__.py: F401