diff --git a/docs/commands.md b/docs/commands.md index e28acbe..3749b83 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -16,7 +16,8 @@ Run the relay. ## Setup -Run the setup wizard to configure your relay. +Run the setup wizard to configure your relay. For the PostgreSQL backend, the database has to be +created first. activityrelay setup @@ -29,6 +30,16 @@ not specified, the config will get backed up as `relay.backup.yaml` before conve activityrelay convert --old-config relaycfg.yaml +## Switch Backend + +Change the database backend from the current one to the other. The config will be updated after +running the command. + +Note: If switching to PostgreSQL, make sure the database exists first. + + activityrelay switch-backend + + ## Edit Config Open the config file in a text editor. If an editor is not specified with `--editor`, the default diff --git a/relay/database/connection.py b/relay/database/connection.py index d217c32..2d5de61 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -4,7 +4,7 @@ import secrets from argon2 import PasswordHasher from blib import Date, convert_to_boolean -from bsql import Connection as SqlConnection, Row, Update +from bsql import BackendType, Connection as SqlConnection, Row, Update from collections.abc import Iterator from datetime import datetime, timezone from typing import TYPE_CHECKING, Any @@ -51,6 +51,17 @@ class Connection(SqlConnection): yield instance + def drop_tables(self) -> None: + with self.cursor() as cur: + for table in self.get_tables(): + query = f"DROP TABLE IF EXISTS {table}" + + if self.database.backend.backend_type == BackendType.POSTGRESQL: + query += " CASCADE" + + cur.execute(query) + + def fix_timestamps(self) -> None: for app in self.select('apps').all(schema.App): data = {'created': app.created.timestamp(), 'accessed': app.accessed.timestamp()} diff --git a/relay/manage.py b/relay/manage.py index e1dfce5..eb8ba44 100644 --- a/relay/manage.py +++ b/relay/manage.py @@ -102,34 +102,7 @@ def cli_setup(ctx: click.Context, skip_questions: bool) -> None: ) elif ctx.obj.config.db_type == 'postgres': - ctx.obj.config.pg_name = click.prompt( - 'What is the name of the database?', - default = ctx.obj.config.pg_name - ) - - ctx.obj.config.pg_host = click.prompt( - 'What IP address, hostname, or unix socket does the server listen on?', - default = ctx.obj.config.pg_host, - type = int - ) - - ctx.obj.config.pg_port = click.prompt( - 'What port does the server listen on?', - default = ctx.obj.config.pg_port, - type = int - ) - - ctx.obj.config.pg_user = click.prompt( - 'Which user will authenticate with the server?', - default = ctx.obj.config.pg_user - ) - - ctx.obj.config.pg_pass = click.prompt( - 'User password', - hide_input = True, - show_default = False, - default = ctx.obj.config.pg_pass or "" - ) or None + config_postgresql(ctx.obj.config) ctx.obj.config.ca_type = click.prompt( 'Which caching backend?', @@ -344,9 +317,24 @@ def cli_switchbackend(ctx: click.Context) -> None: config = Config(ctx.obj.config.path, load = True) config.db_type = "sqlite" if config.db_type == "postgres" else "postgres" + + if config.db_type == "postgres": + if click.confirm("Setup PostgreSQL configuration?"): + config_postgresql(config) + + order = ("SQLite", "PostgreSQL") + click.pause("Make sure the database and user already exist before continuing") + + else: + order = ("PostgreSQL", "SQLite") + + click.echo(f"About to convert from {order[0]} to {order[1]}...") database = get_database(config, migrate = False) with database.session(True) as new, ctx.obj.database.session(False) as old: + if click.confirm("All tables in the destination database will be dropped. Continue?"): + new.drop_tables() + new.create_tables() for table in schema.TABLES.keys(): @@ -354,7 +342,7 @@ def cli_switchbackend(ctx: click.Context) -> None: new.insert(table, row).close() config.save() - click.echo(f"Converted database to {repr(config.db_type)}") + click.echo("Done!") @cli.group('config') @@ -1015,6 +1003,35 @@ def cli_whitelist_import(ctx: click.Context) -> None: click.echo('Imported whitelist from inboxes') +def config_postgresql(config: Config) -> None: + config.pg_name = click.prompt( + 'What is the name of the database?', + default = config.pg_name + ) + + config.pg_host = click.prompt( + 'What IP address, hostname, or unix socket does the server listen on?', + default = config.pg_host, + ) + + config.pg_port = click.prompt( + 'What port does the server listen on?', + default = config.pg_port, + type = int + ) + + config.pg_user = click.prompt( + 'Which user will authenticate with the server?', + default = config.pg_user + ) + + config.pg_pass = click.prompt( + 'User password', + hide_input = True, + show_default = False, + default = config.pg_pass or "" + ) or None + def main() -> None: cli(prog_name='activityrelay')