diff --git a/relay/application.py b/relay/application.py index cf9611b..12c60b9 100644 --- a/relay/application.py +++ b/relay/application.py @@ -22,16 +22,10 @@ from .cache import get_cache from .config import Config from .database import get_database from .http_client import HttpClient -from .misc import check_open_port +from .misc import check_open_port, get_resource from .views import VIEWS from .views.api import handle_api_path -try: - from importlib.resources import files as pkgfiles - -except ImportError: - from importlib_resources import files as pkgfiles - if typing.TYPE_CHECKING: from tinysql import Database, Row from .cache import Cache @@ -75,7 +69,7 @@ class Application(web.Application): setup_swagger(self, ui_version = 3, - swagger_from_file = pkgfiles('relay').joinpath('data', 'swagger.yaml') + swagger_from_file = get_resource('data/swagger.yaml') ) diff --git a/relay/database/__init__.py b/relay/database/__init__.py index c7e9a1f..a55f2c8 100644 --- a/relay/database/__init__.py +++ b/relay/database/__init__.py @@ -8,12 +8,7 @@ from .connection import RELAY_SOFTWARE, Connection from .schema import TABLES, VERSIONS, migrate_0 from .. import logger as logging - -try: - from importlib.resources import files as pkgfiles - -except ImportError: # pylint: disable=duplicate-code - from importlib_resources import files as pkgfiles +from ..misc import get_resource if typing.TYPE_CHECKING: from .config import Config @@ -21,15 +16,15 @@ if typing.TYPE_CHECKING: def get_database(config: Config, migrate: bool = True) -> bsql.Database: options = { - "connection_class": Connection, - "pool_size": 5, - "tables": TABLES + 'connection_class': Connection, + 'pool_size': 5, + 'tables': TABLES } - if config.db_type == "sqlite": + if config.db_type == 'sqlite': db = bsql.Database.sqlite(config.sqlite_path, **options) - elif config.db_type == "postgres": + elif config.db_type == 'postgres': db = bsql.Database.postgresql( config.pg_name, config.pg_host, @@ -39,7 +34,7 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database: **options ) - db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql")) + db.load_prepared_statements(get_resource('data/statements.sql')) db.connect() if not migrate: diff --git a/relay/misc.py b/relay/misc.py index 25c0a1e..62d4643 100644 --- a/relay/misc.py +++ b/relay/misc.py @@ -10,7 +10,14 @@ from aputils.message import Message as ApMessage from datetime import datetime from uuid import uuid4 +try: + from importlib.resources import files as pkgfiles + +except ImportError: + from importlib_resources import files as pkgfiles + if typing.TYPE_CHECKING: + from pathlib import Path from typing import Any from .application import Application @@ -75,6 +82,10 @@ def get_app() -> Application: return Application.DEFAULT +def get_resource(path: str) -> Path: + return pkgfiles('relay').joinpath(path) + + class JsonEncoder(json.JSONEncoder): def default(self, obj: Any) -> str: if isinstance(obj, datetime):