replace misc functions

This commit is contained in:
Izalia Mae 2024-09-18 07:13:44 -04:00
parent c482677c32
commit 0d24aea764
10 changed files with 36 additions and 93 deletions

View file

@ -12,7 +12,7 @@ from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from base64 import b64encode from base64 import b64encode
from blib import HttpError from blib import File, HttpError, port_check
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -27,7 +27,7 @@ from .config import Config
from .database import Connection, get_database from .database import Connection, get_database
from .database.schema import Instance from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response, check_open_port, get_resource from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
@ -90,7 +90,7 @@ class Application(web.Application):
setup_swagger( setup_swagger(
self, self,
ui_version = 3, ui_version = 3,
swagger_from_file = get_resource('data/swagger.yaml') swagger_from_file = File.from_resource('relay', 'data/swagger.yaml')
) )
@ -154,10 +154,12 @@ class Application(web.Application):
def register_static_routes(self) -> None: def register_static_routes(self) -> None:
if self['dev']: if self['dev']:
static = StaticResource('/static', get_resource('frontend/static')) static = StaticResource('/static', File.from_resource('relay', 'frontend/static'))
else: else:
static = CachedStaticResource('/static', get_resource('frontend/static')) static = CachedStaticResource(
'/static', Path(File.from_resource('relay', 'frontend/static'))
)
self.router.register_resource(static) self.router.register_resource(static)
@ -170,7 +172,7 @@ class Application(web.Application):
host = self.config.listen host = self.config.listen
port = self.config.port port = self.config.port
if not check_open_port(host, port): if port_check(port, '127.0.0.1' if host == '0.0.0.0' else host):
logging.error(f'A server is already running on {host}:{port}') logging.error(f'A server is already running on {host}:{port}')
return return

View file

@ -4,7 +4,7 @@ import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from blib import Date from blib import Date, convert_to_boolean
from bsql import Database, Row from bsql import Database, Row
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
@ -13,7 +13,7 @@ from redis import Redis
from typing import TYPE_CHECKING, Any, TypedDict from typing import TYPE_CHECKING, Any, TypedDict
from .database import Connection, get_database from .database import Connection, get_database
from .misc import Message, boolean from .misc import Message
if TYPE_CHECKING: if TYPE_CHECKING:
from .application import Application from .application import Application
@ -26,7 +26,7 @@ BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = { CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
'bool': (str, boolean), 'bool': (str, convert_to_boolean),
'json': (json.dumps, json.loads), 'json': (json.dumps, json.loads),
'message': (lambda x: x.to_json(), Message.parse) 'message': (lambda x: x.to_json(), Message.parse)
} }

View file

@ -2,13 +2,12 @@ import json
import os import os
import yaml import yaml
from blib import convert_to_boolean
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from .misc import boolean
class RelayConfig(dict[str, Any]): class RelayConfig(dict[str, Any]):
def __init__(self, path: str): def __init__(self, path: str):
@ -31,7 +30,7 @@ class RelayConfig(dict[str, Any]):
elif key == 'whitelist_enabled': elif key == 'whitelist_enabled':
if not isinstance(value, bool): if not isinstance(value, bool):
value = boolean(value) value = convert_to_boolean(value)
super().__setitem__(key, value) super().__setitem__(key, value)

View file

@ -1,6 +1,6 @@
import sqlite3 import sqlite3
from blib import Date from blib import Date, File
from bsql import Database from bsql import Database
from .config import THEMES, ConfigData from .config import THEMES, ConfigData
@ -9,7 +9,6 @@ from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging from .. import logger as logging
from ..config import Config from ..config import Config
from ..misc import get_resource
sqlite3.register_adapter(Date, Date.timestamp) sqlite3.register_adapter(Date, Date.timestamp)
@ -37,7 +36,7 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
**options **options
) )
db.load_prepared_statements(get_resource('data/statements.sql')) db.load_prepared_statements(File.from_resource('relay', 'data/statements.sql'))
db.connect() db.connect()
if not migrate: if not migrate:

View file

@ -2,13 +2,13 @@ from __future__ import annotations
# removing the above line turns annotations into types instead of str objects which messes with # removing the above line turns annotations into types instead of str objects which messes with
# `Field.type` # `Field.type`
from blib import convert_to_boolean
from bsql import Row from bsql import Row
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from dataclasses import Field, asdict, dataclass, fields from dataclasses import Field, asdict, dataclass, fields
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from .. import logger as logging from .. import logger as logging
from ..misc import boolean
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self from typing import Self
@ -66,7 +66,7 @@ THEMES = {
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = { CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
'str': (str, str), 'str': (str, str),
'int': (str, int), 'int': (str, int),
'bool': (str, boolean), 'bool': (str, convert_to_boolean),
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse) 'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse)
} }

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import secrets import secrets
from argon2 import PasswordHasher from argon2 import PasswordHasher
from blib import Date from blib import Date, convert_to_boolean
from bsql import Connection as SqlConnection, Row, Update from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator from collections.abc import Iterator
from datetime import datetime, timezone from datetime import datetime, timezone
@ -17,7 +17,7 @@ from .config import (
) )
from .. import logger as logging from .. import logger as logging
from ..misc import Message, boolean, get_app from ..misc import Message, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from ..application import Application from ..application import Application
@ -111,7 +111,7 @@ class Connection(SqlConnection):
self.app['workers'].set_log_level(value) self.app['workers'].set_log_level(value)
elif key in {'approval-required', 'whitelist-enabled'}: elif key in {'approval-required', 'whitelist-enabled'}:
value = boolean(value) value = convert_to_boolean(value)
elif key == 'theme': elif key == 'theme':
if value not in THEMES: if value not in THEMES:

View file

@ -4,13 +4,10 @@ import aputils
import json import json
import os import os
import platform import platform
import socket
from aiohttp.web import Response as AiohttpResponse from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from importlib.resources import files as pkgfiles
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4 from uuid import uuid4
@ -40,11 +37,6 @@ MIMETYPES = {
'webmanifest': 'application/manifest+json' 'webmanifest': 'application/manifest+json'
} }
NODEINFO_NS = {
'20': 'http://nodeinfo.diaspora.software/ns/schema/2.0',
'21': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
}
ACTOR_FORMATS = { ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor', 'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay', 'akkoma': 'https://{domain}/relay',
@ -84,43 +76,6 @@ TOKEN_PATHS: tuple[str, ...] = (
) )
def boolean(value: Any) -> bool:
if isinstance(value, str):
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}:
return True
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}:
return False
raise TypeError(f'Cannot parse string "{value}" as a boolean')
if isinstance(value, int):
if value == 1:
return True
if value == 0:
return False
raise ValueError('Integer value must be 1 or 0')
if value is None:
return False
return bool(value)
def check_open_port(host: str, port: int) -> bool:
if host == '0.0.0.0':
host = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
return s.connect_ex((host, port)) != 0
except socket.error:
return False
def get_app() -> Application: def get_app() -> Application:
from .application import Application from .application import Application
@ -130,10 +85,6 @@ def get_app() -> Application:
return Application.DEFAULT return Application.DEFAULT
def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path)
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str: def default(self, o: Any) -> str:
if isinstance(o, datetime): if isinstance(o, datetime):
@ -250,18 +201,6 @@ class Response(AiohttpResponse):
return cls(**kwargs) return cls(**kwargs)
@classmethod
def new_error(cls: type[Self],
status: int,
body: str | bytes | dict[str, Any],
ctype: str = 'text') -> Self:
if ctype == 'json':
body = {'error': body}
return cls.new(body=body, status=status, ctype=ctype)
@classmethod @classmethod
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self: def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>' body = f'Redirect to <a href="{path}">{path}</a>'

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import textwrap import textwrap
from aiohttp.web import Request from aiohttp.web import Request
from blib import File
from collections.abc import Callable from collections.abc import Callable
from hamlish_jinja import HamlishExtension from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
@ -13,7 +14,6 @@ from markdown import Markdown
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from . import __version__ from . import __version__
from .misc import get_resource
if TYPE_CHECKING: if TYPE_CHECKING:
from .application import Application from .application import Application
@ -33,7 +33,7 @@ class Template(Environment):
MarkdownExtension MarkdownExtension
], ],
loader = FileSystemLoader([ loader = FileSystemLoader([
get_resource('frontend'), File.from_resource('relay', 'frontend'),
app.config.path.parent.joinpath('template') app.config.path.parent.joinpath('template')
]) ])
) )

View file

@ -48,7 +48,7 @@ class ActorView(View):
# reject if actor is banned # reject if actor is banned
if conn.get_domain_ban(self.actor.domain): if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id) logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json') raise HttpError(403, 'access denied')
# reject if activity type isn't 'Follow' and the actor isn't following # reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance: if self.message.type != 'Follow' and not self.instance:
@ -57,7 +57,7 @@ class ActorView(View):
self.actor.id self.actor.id
) )
return Response.new_error(401, 'access denied', 'json') raise HttpError(401, 'access denied')
logging.debug('>> payload %s', self.message.to_json(4)) logging.debug('>> payload %s', self.message.to_json(4))
@ -78,7 +78,7 @@ class ActorView(View):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
logging.verbose('Failed to parse inbox message') logging.verbose('Failed to parse message from actor: %s', self.signature.keyid)
raise HttpError(400, 'failed to parse message') raise HttpError(400, 'failed to parse message')
if message is None: if message is None:
@ -94,13 +94,14 @@ class ActorView(View):
try: try:
self.actor = await self.client.get(self.signature.keyid, True, Message) self.actor = await self.client.get(self.signature.keyid, True, Message)
except HttpError: except HttpError as e:
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete': if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '') raise HttpError(202, '')
logging.verbose('Failed to fetch actor: %s', self.signature.keyid) logging.verbose('Failed to fetch actor: %s', self.signature.keyid)
logging.debug('HTTP Status %i: %s', e.status, e.message)
raise HttpError(400, 'failed to fetch actor') raise HttpError(400, 'failed to fetch actor')
except Exception: except Exception:
@ -162,10 +163,10 @@ class WebfingerView(View):
subject = request.query['resource'] subject = request.query['resource']
except KeyError: except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json') raise HttpError(400, 'missing "resource" query key')
if subject != f'acct:relay@{self.config.domain}': if subject != f'acct:relay@{self.config.domain}':
return Response.new_error(404, 'user not found', 'json') raise HttpError(404, 'user not found')
data = aputils.Webfinger.new( data = aputils.Webfinger.new(
handle = 'relay', handle = 'relay',

View file

@ -10,7 +10,7 @@ from .base import View, register_route
from .. import __version__ from .. import __version__
from ..database import ConfigData, schema from ..database import ConfigData, schema
from ..misc import Message, Response, boolean from ..misc import Message, Response
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -97,7 +97,7 @@ class OauthAuthorize(View):
with self.database.session(True) as conn: with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None: if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
return Response.new_error(404, 'Could not find app', 'json') raise HttpError(404, 'Could not find app')
if convert_to_boolean(data['response']): if convert_to_boolean(data['response']):
if app.token is not None: if app.token is not None:
@ -393,7 +393,10 @@ class RequestView(View):
try: try:
with self.database.session(True) as conn: with self.database.session(True) as conn:
instance = conn.put_request_response(data['domain'], boolean(data['accept'])) instance = conn.put_request_response(
data['domain'],
convert_to_boolean(data['accept'])
)
except KeyError: except KeyError:
raise HttpError(404, 'Request not found') from None raise HttpError(404, 'Request not found') from None
@ -402,7 +405,7 @@ class RequestView(View):
host = self.config.domain, host = self.config.domain,
actor = instance.actor, actor = instance.actor,
followid = instance.followid, followid = instance.followid,
accept = boolean(data['accept']) accept = convert_to_boolean(data['accept'])
) )
self.app.push_message(instance.inbox, message, instance) self.app.push_message(instance.inbox, message, instance)