diff --git a/relay/database/__init__.py b/relay/database/__init__.py index 90d744d..b4dad31 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -30,18 +30,22 @@ def get_database(state: State, migrate: bool = True) -> Database[Connection]: db: Database[Connection] - if state.config.db_type == "sqlite": - db = Database.sqlite(state.config.sqlite_path, **options) + match state.config.db_type: + case "sqlite" | "sqlite3": + db = Database.sqlite(state.config.sqlite_path, **options) - elif state.config.db_type == "postgres": - db = Database.postgresql( - state.config.pg_name, - state.config.pg_host, - state.config.pg_port, - state.config.pg_user, - state.config.pg_pass, - **options - ) + case "postgres" | "postgresql": + db = Database.postgresql( + state.config.pg_name, + state.config.pg_host, + state.config.pg_port, + state.config.pg_user, + state.config.pg_pass, + **options + ) + + case _: + raise RuntimeError(f"Invalid database backend: {state.config.db_type}") db.load_prepared_statements(File.from_resource("relay", "data/statements.sql")) db.connect()