Compare commits

..

5 commits

Author SHA1 Message Date
Izalia Mae 7329ef4fd9 fix queries for migrations 2024-09-14 06:06:03 -04:00
Izalia Mae c5acc5aa16 update barkshark-sql to 0.2.0 2024-09-14 05:57:49 -04:00
Izalia Mae 9a0400d84f add command to switch database backends 2024-09-14 05:56:47 -04:00
Izalia Mae c54aeabc90 fix postgres support 2024-09-14 05:56:27 -04:00
Izalia Mae 16fcea90f2 use correct key when loading postgres config 2024-09-14 05:09:48 -04:00
6 changed files with 53 additions and 42 deletions

View file

@ -19,7 +19,7 @@ dependencies = [
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.2.2.post2, < 0.3.0", "barkshark-lib >= 0.2.2.post2, < 0.3.0",
"barkshark-sql >= 0.2.0rc2, < 0.3.0", "barkshark-sql >= 0.2.0, < 0.3.0",
"click == 8.1.2", "click == 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",
"idna == 3.4", "idna == 3.4",

View file

@ -145,7 +145,7 @@ class Config:
if not config: if not config:
raise ValueError('Config is empty') raise ValueError('Config is empty')
pgcfg = config.get('postgresql', {}) pgcfg = config.get('postgres', {})
rdcfg = config.get('redis', {}) rdcfg = config.get('redis', {})
for key in type(self).KEYS(): for key in type(self).KEYS():

View file

@ -40,7 +40,7 @@ WHERE domain = :value or inbox = :value or actor = :value;
-- name: get-request -- name: get-request
SELECT * FROM inboxes WHERE accepted = 0 and domain = :domain; SELECT * FROM inboxes WHERE accepted = false and domain = :domain;
-- name: get-user -- name: get-user
@ -64,7 +64,7 @@ RETURNING *;
-- name: del-user -- name: del-user
DELETE FROM users DELETE FROM users
WHERE username = :value or handle = :value; WHERE username = :username or handle = :username;
-- name: get-app -- name: get-app
@ -91,6 +91,10 @@ DELETE FROM apps
WHERE client_id = :id and client_secret = :secret and token = :token; WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: del-token-user
DELETE FROM apps WHERE "user" = :username;
-- name: get-software-ban -- name: get-software-ban
SELECT * FROM software_bans WHERE name = :name; SELECT * FROM software_bans WHERE name = :name;

View file

@ -138,7 +138,7 @@ class Connection(SqlConnection):
def get_inboxes(self) -> Iterator[schema.Instance]: def get_inboxes(self) -> Iterator[schema.Instance]:
return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance) return self.execute("SELECT * FROM inboxes WHERE accepted = true").all(schema.Instance)
# todo: check if software is different than stored row # todo: check if software is different than stored row
@ -196,7 +196,7 @@ class Connection(SqlConnection):
def get_requests(self) -> Iterator[schema.Instance]: def get_requests(self) -> Iterator[schema.Instance]:
return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance) return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance)
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
@ -275,10 +275,10 @@ class Connection(SqlConnection):
if (user := self.get_user(username)) is None: if (user := self.get_user(username)) is None:
raise KeyError(username) raise KeyError(username)
with self.run('del-user', {'value': user.username}): with self.run('del-token-user', {'username': user.username}):
pass pass
with self.run('del-token-user', {'username': user.username}): with self.run('del-user', {'username': user.username}):
pass pass
@ -315,8 +315,8 @@ class Connection(SqlConnection):
'website': website, 'website': website,
'client_id': secrets.token_hex(20), 'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20), 'client_secret': secrets.token_hex(20),
'created': Date.new_utc().timestamp(), 'created': Date.new_utc(),
'accessed': Date.new_utc().timestamp() 'accessed': Date.new_utc()
} }
with self.insert('apps', params) as cur: with self.insert('apps', params) as cur:
@ -336,8 +336,8 @@ class Connection(SqlConnection):
'client_secret': secrets.token_hex(20), 'client_secret': secrets.token_hex(20),
'auth_code': None, 'auth_code': None,
'token': secrets.token_hex(20), 'token': secrets.token_hex(20),
'created': Date.new_utc().timestamp(), 'created': Date.new_utc(),
'accessed': Date.new_utc().timestamp() 'accessed': Date.new_utc()
} }
with self.insert('apps', params) as cur: with self.insert('apps', params) as cur:

View file

@ -45,20 +45,14 @@ class Instance(Row):
followid: Column[str] = Column('followid', 'text') followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text') software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean') accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class Whitelist(Row): class Whitelist(Row):
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
@ -70,10 +64,7 @@ class DomainBan(Row):
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
@ -84,10 +75,7 @@ class SoftwareBan(Row):
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
@ -99,10 +87,7 @@ class User(Row):
'username', 'text', primary_key = True, unique = True, nullable = False) 'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text') handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
@ -119,14 +104,8 @@ class App(Row):
token: Column[str | None] = Column('token', 'text') token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text') auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text') user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False, accessed: Column[Date] = Column('accessed', 'timestamp', nullable = False)
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]: def get_api_data(self, include_token: bool = False) -> dict[str, Any]:
@ -159,10 +138,10 @@ def migrate_20240206(conn: Connection) -> None:
@migration @migration
def migrate_20240310(conn: Connection) -> None: def migrate_20240310(conn: Connection) -> None:
conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close() conn.execute('ALTER TABLE "inboxes" ADD COLUMN "accepted" BOOLEAN').close()
conn.execute('UPDATE "inboxes" SET accepted = 1').close() conn.execute('UPDATE "inboxes" SET "accepted" = true').close()
@migration @migration
def migrate_20240625(conn: Connection) -> None: def migrate_20240625(conn: Connection) -> None:
conn.create_tables() conn.create_tables()
conn.execute('DROP TABLE tokens').close() conn.execute('DROP TABLE "tokens"').close()

View file

@ -16,6 +16,7 @@ from . import http_client as http
from . import logger as logging from . import logger as logging
from .application import Application from .application import Application
from .compat import RelayConfig, RelayDatabase from .compat import RelayConfig, RelayDatabase
from .config import Config
from .database import RELAY_SOFTWARE, get_database, schema from .database import RELAY_SOFTWARE, get_database, schema
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
@ -329,6 +330,33 @@ def cli_editconfig(ctx: click.Context, editor: str) -> None:
) )
@cli.command('switch-backend')
@click.pass_context
def cli_switchbackend(ctx: click.Context) -> None:
"""
Copy the database from one backend to the other
Be sure to set the database type to the backend you want to convert from. For instance, set
the database type to `sqlite`, fill out the connection details for postgresql, and the
data from the sqlite database will be copied to the postgresql database. This only works if
the database in postgresql already exists.
"""
config = Config(ctx.obj.config.path, load = True)
config.db_type = "sqlite" if config.db_type == "postgres" else "postgres"
database = get_database(config, migrate = False)
with database.session(True) as new, ctx.obj.database.session(False) as old:
new.create_tables()
for table in schema.TABLES.keys():
for row in old.execute(f"SELECT * FROM {table}"):
new.insert(table, row).close()
config.save()
click.echo(f"Converted database to {repr(config.db_type)}")
@cli.group('config') @cli.group('config')
def cli_config() -> None: def cli_config() -> None:
'Manage the relay settings stored in the database' 'Manage the relay settings stored in the database'