improve switch-backend command

This commit is contained in:
Izalia Mae 2024-09-28 09:41:25 -04:00
parent 85f062c8f3
commit f192f1c35c
3 changed files with 70 additions and 31 deletions

View file

@ -16,7 +16,8 @@ Run the relay.
## Setup
Run the setup wizard to configure your relay.
Run the setup wizard to configure your relay. For the PostgreSQL backend, the database has to be
created first.
activityrelay setup
@ -29,6 +30,16 @@ not specified, the config will get backed up as `relay.backup.yaml` before conve
activityrelay convert --old-config relaycfg.yaml
## Switch Backend
Change the database backend from the current one to the other. The config will be updated after
running the command.
Note: If switching to PostgreSQL, make sure the database exists first.
activityrelay switch-backend
## Edit Config
Open the config file in a text editor. If an editor is not specified with `--editor`, the default

View file

@ -4,7 +4,7 @@ import secrets
from argon2 import PasswordHasher
from blib import Date, convert_to_boolean
from bsql import Connection as SqlConnection, Row, Update
from bsql import BackendType, Connection as SqlConnection, Row, Update
from collections.abc import Iterator
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
@ -51,6 +51,17 @@ class Connection(SqlConnection):
yield instance
def drop_tables(self) -> None:
with self.cursor() as cur:
for table in self.get_tables():
query = f"DROP TABLE IF EXISTS {table}"
if self.database.backend.backend_type == BackendType.POSTGRESQL:
query += " CASCADE"
cur.execute(query)
def fix_timestamps(self) -> None:
for app in self.select('apps').all(schema.App):
data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()}

View file

@ -102,34 +102,7 @@ def cli_setup(ctx: click.Context, skip_questions: bool) -> None:
)
elif ctx.obj.config.db_type == 'postgres':
ctx.obj.config.pg_name = click.prompt(
'What is the name of the database?',
default = ctx.obj.config.pg_name
)
ctx.obj.config.pg_host = click.prompt(
'What IP address, hostname, or unix socket does the server listen on?',
default = ctx.obj.config.pg_host,
type = int
)
ctx.obj.config.pg_port = click.prompt(
'What port does the server listen on?',
default = ctx.obj.config.pg_port,
type = int
)
ctx.obj.config.pg_user = click.prompt(
'Which user will authenticate with the server?',
default = ctx.obj.config.pg_user
)
ctx.obj.config.pg_pass = click.prompt(
'User password',
hide_input = True,
show_default = False,
default = ctx.obj.config.pg_pass or ""
) or None
config_postgresql(ctx.obj.config)
ctx.obj.config.ca_type = click.prompt(
'Which caching backend?',
@ -344,9 +317,24 @@ def cli_switchbackend(ctx: click.Context) -> None:
config = Config(ctx.obj.config.path, load = True)
config.db_type = "sqlite" if config.db_type == "postgres" else "postgres"
if config.db_type == "postgres":
if click.confirm("Setup PostgreSQL configuration?"):
config_postgresql(config)
order = ("SQLite", "PostgreSQL")
click.pause("Make sure the database and user already exist before continuing")
else:
order = ("PostgreSQL", "SQLite")
click.echo(f"About to convert from {order[0]} to {order[1]}...")
database = get_database(config, migrate = False)
with database.session(True) as new, ctx.obj.database.session(False) as old:
if click.confirm("All tables in the destination database will be dropped. Continue?"):
new.drop_tables()
new.create_tables()
for table in schema.TABLES.keys():
@ -354,7 +342,7 @@ def cli_switchbackend(ctx: click.Context) -> None:
new.insert(table, row).close()
config.save()
click.echo(f"Converted database to {repr(config.db_type)}")
click.echo("Done!")
@cli.group('config')
@ -1015,6 +1003,35 @@ def cli_whitelist_import(ctx: click.Context) -> None:
click.echo('Imported whitelist from inboxes')
def config_postgresql(config: Config) -> None:
config.pg_name = click.prompt(
'What is the name of the database?',
default = config.pg_name
)
config.pg_host = click.prompt(
'What IP address, hostname, or unix socket does the server listen on?',
default = config.pg_host,
)
config.pg_port = click.prompt(
'What port does the server listen on?',
default = config.pg_port,
type = int
)
config.pg_user = click.prompt(
'Which user will authenticate with the server?',
default = config.pg_user
)
config.pg_pass = click.prompt(
'User password',
hide_input = True,
show_default = False,
default = config.pg_pass or ""
) or None
def main() -> None:
cli(prog_name='activityrelay')