create get_resource method

This commit is contained in:
Izalia Mae 2024-02-22 17:54:15 -05:00
parent 097a53a539
commit 26c5c05320
3 changed files with 20 additions and 20 deletions

View file

@ -22,16 +22,10 @@ from .cache import get_cache
from .config import Config from .config import Config
from .database import get_database from .database import get_database
from .http_client import HttpClient from .http_client import HttpClient
from .misc import check_open_port from .misc import check_open_port, get_resource
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from tinysql import Database, Row from tinysql import Database, Row
from .cache import Cache from .cache import Cache
@ -75,7 +69,7 @@ class Application(web.Application):
setup_swagger(self, setup_swagger(self,
ui_version = 3, ui_version = 3,
swagger_from_file = pkgfiles('relay').joinpath('data', 'swagger.yaml') swagger_from_file = get_resource('data/swagger.yaml')
) )

View file

@ -8,12 +8,7 @@ from .connection import RELAY_SOFTWARE, Connection
from .schema import TABLES, VERSIONS, migrate_0 from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
from ..misc import get_resource
try:
from importlib.resources import files as pkgfiles
except ImportError: # pylint: disable=duplicate-code
from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .config import Config from .config import Config
@ -21,15 +16,15 @@ if typing.TYPE_CHECKING:
def get_database(config: Config, migrate: bool = True) -> bsql.Database: def get_database(config: Config, migrate: bool = True) -> bsql.Database:
options = { options = {
"connection_class": Connection, 'connection_class': Connection,
"pool_size": 5, 'pool_size': 5,
"tables": TABLES 'tables': TABLES
} }
if config.db_type == "sqlite": if config.db_type == 'sqlite':
db = bsql.Database.sqlite(config.sqlite_path, **options) db = bsql.Database.sqlite(config.sqlite_path, **options)
elif config.db_type == "postgres": elif config.db_type == 'postgres':
db = bsql.Database.postgresql( db = bsql.Database.postgresql(
config.pg_name, config.pg_name,
config.pg_host, config.pg_host,
@ -39,7 +34,7 @@ def get_database(config: Config, migrate: bool = True) -> bsql.Database:
**options **options
) )
db.load_prepared_statements(pkgfiles("relay").joinpath("data", "statements.sql")) db.load_prepared_statements(get_resource('data/statements.sql'))
db.connect() db.connect()
if not migrate: if not migrate:

View file

@ -10,7 +10,14 @@ from aputils.message import Message as ApMessage
from datetime import datetime from datetime import datetime
from uuid import uuid4 from uuid import uuid4
try:
from importlib.resources import files as pkgfiles
except ImportError:
from importlib_resources import files as pkgfiles
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from pathlib import Path
from typing import Any from typing import Any
from .application import Application from .application import Application
@ -75,6 +82,10 @@ def get_app() -> Application:
return Application.DEFAULT return Application.DEFAULT
def get_resource(path: str) -> Path:
return pkgfiles('relay').joinpath(path)
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, obj: Any) -> str: def default(self, obj: Any) -> str:
if isinstance(obj, datetime): if isinstance(obj, datetime):