replace misc functions

This commit is contained in:
Izalia Mae 2024-09-18 07:13:44 -04:00
parent c482677c32
commit 0d24aea764
10 changed files with 36 additions and 93 deletions

View file

@ -12,7 +12,7 @@ from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger
from aputils.signer import Signer
from base64 import b64encode
from blib import HttpError
from blib import File, HttpError, port_check
from bsql import Database
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
@ -27,7 +27,7 @@ from .config import Config
from .database import Connection, get_database
from .database.schema import Instance
from .http_client import HttpClient
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response, check_open_port, get_resource
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response
from .template import Template
from .views import VIEWS
from .views.api import handle_api_path
@ -90,7 +90,7 @@ class Application(web.Application):
setup_swagger(
self,
ui_version = 3,
swagger_from_file = get_resource('data/swagger.yaml')
swagger_from_file = File.from_resource('relay', 'data/swagger.yaml')
)
@ -154,10 +154,12 @@ class Application(web.Application):
def register_static_routes(self) -> None:
if self['dev']:
static = StaticResource('/static', get_resource('frontend/static'))
static = StaticResource('/static', File.from_resource('relay', 'frontend/static'))
else:
static = CachedStaticResource('/static', get_resource('frontend/static'))
static = CachedStaticResource(
'/static', Path(File.from_resource('relay', 'frontend/static'))
)
self.router.register_resource(static)
@ -170,7 +172,7 @@ class Application(web.Application):
host = self.config.listen
port = self.config.port
if not check_open_port(host, port):
if port_check(port, '127.0.0.1' if host == '0.0.0.0' else host):
logging.error(f'A server is already running on {host}:{port}')
return

View file

@ -4,7 +4,7 @@ import json
import os
from abc import ABC, abstractmethod
from blib import Date
from blib import Date, convert_to_boolean
from bsql import Database, Row
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
@ -13,7 +13,7 @@ from redis import Redis
from typing import TYPE_CHECKING, Any, TypedDict
from .database import Connection, get_database
from .misc import Message, boolean
from .misc import Message
if TYPE_CHECKING:
from .application import Application
@ -26,7 +26,7 @@ BACKENDS: dict[str, type[Cache]] = {}
CONVERTERS: dict[str, tuple[SerializerCallback, DeserializerCallback]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'bool': (str, convert_to_boolean),
'json': (json.dumps, json.loads),
'message': (lambda x: x.to_json(), Message.parse)
}

View file

@ -2,13 +2,12 @@ import json
import os
import yaml
from blib import convert_to_boolean
from functools import cached_property
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
from .misc import boolean
class RelayConfig(dict[str, Any]):
def __init__(self, path: str):
@ -31,7 +30,7 @@ class RelayConfig(dict[str, Any]):
elif key == 'whitelist_enabled':
if not isinstance(value, bool):
value = boolean(value)
value = convert_to_boolean(value)
super().__setitem__(key, value)

View file

@ -1,6 +1,6 @@
import sqlite3
from blib import Date
from blib import Date, File
from bsql import Database
from .config import THEMES, ConfigData
@ -9,7 +9,6 @@ from .schema import TABLES, VERSIONS, migrate_0
from .. import logger as logging
from ..config import Config
from ..misc import get_resource
sqlite3.register_adapter(Date, Date.timestamp)
@ -37,7 +36,7 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
**options
)
db.load_prepared_statements(get_resource('data/statements.sql'))
db.load_prepared_statements(File.from_resource('relay', 'data/statements.sql'))
db.connect()
if not migrate:

View file

@ -2,13 +2,13 @@ from __future__ import annotations
# removing the above line turns annotations into types instead of str objects which messes with
# `Field.type`
from blib import convert_to_boolean
from bsql import Row
from collections.abc import Callable, Sequence
from dataclasses import Field, asdict, dataclass, fields
from typing import TYPE_CHECKING, Any
from .. import logger as logging
from ..misc import boolean
if TYPE_CHECKING:
from typing import Self
@ -66,7 +66,7 @@ THEMES = {
CONFIG_CONVERT: dict[str, tuple[Callable[[Any], str], Callable[[str], Any]]] = {
'str': (str, str),
'int': (str, int),
'bool': (str, boolean),
'bool': (str, convert_to_boolean),
'logging.LogLevel': (lambda x: x.name, logging.LogLevel.parse)
}

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import secrets
from argon2 import PasswordHasher
from blib import Date
from blib import Date, convert_to_boolean
from bsql import Connection as SqlConnection, Row, Update
from collections.abc import Iterator
from datetime import datetime, timezone
@ -17,7 +17,7 @@ from .config import (
)
from .. import logger as logging
from ..misc import Message, boolean, get_app
from ..misc import Message, get_app
if TYPE_CHECKING:
from ..application import Application
@ -111,7 +111,7 @@ class Connection(SqlConnection):
self.app['workers'].set_log_level(value)
elif key in {'approval-required', 'whitelist-enabled'}:
value = boolean(value)
value = convert_to_boolean(value)
elif key == 'theme':
if value not in THEMES:

View file

@ -4,13 +4,10 @@ import aputils
import json
import os
import platform
import socket
from aiohttp.web import Response as AiohttpResponse
from collections.abc import Sequence
from datetime import datetime
from importlib.resources import files as pkgfiles
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar
from uuid import uuid4
@ -40,11 +37,6 @@ MIMETYPES = {
'webmanifest': 'application/manifest+json'
}
NODEINFO_NS = {
'20': 'http://nodeinfo.diaspora.software/ns/schema/2.0',
'21': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
}
ACTOR_FORMATS = {
'mastodon': 'https://{domain}/actor',
'akkoma': 'https://{domain}/relay',
@ -84,43 +76,6 @@ TOKEN_PATHS: tuple[str, ...] = (
)
def boolean(value: Any) -> bool:
if isinstance(value, str):
if value.lower() in {'on', 'y', 'yes', 'true', 'enable', 'enabled', '1'}:
return True
if value.lower() in {'off', 'n', 'no', 'false', 'disable', 'disabled', '0'}:
return False
raise TypeError(f'Cannot parse string "{value}" as a boolean')
if isinstance(value, int):
if value == 1:
return True
if value == 0:
return False
raise ValueError('Integer value must be 1 or 0')
if value is None:
return False
return bool(value)
def check_open_port(host: str, port: int) -> bool:
if host == '0.0.0.0':
host = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
return s.connect_ex((host, port)) != 0
except socket.error:
return False
def get_app() -> Application:
from .application import Application
@ -130,10 +85,6 @@ def get_app() -> Application:
return Application.DEFAULT
def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path)
class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str:
if isinstance(o, datetime):
@ -250,18 +201,6 @@ class Response(AiohttpResponse):
return cls(**kwargs)
@classmethod
def new_error(cls: type[Self],
status: int,
body: str | bytes | dict[str, Any],
ctype: str = 'text') -> Self:
if ctype == 'json':
body = {'error': body}
return cls.new(body=body, status=status, ctype=ctype)
@classmethod
def new_redir(cls: type[Self], path: str, status: int = 307) -> Self:
body = f'Redirect to <a href="{path}">{path}</a>'

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import textwrap
from aiohttp.web import Request
from blib import File
from collections.abc import Callable
from hamlish_jinja import HamlishExtension
from jinja2 import Environment, FileSystemLoader
@ -13,7 +14,6 @@ from markdown import Markdown
from typing import TYPE_CHECKING, Any
from . import __version__
from .misc import get_resource
if TYPE_CHECKING:
from .application import Application
@ -33,7 +33,7 @@ class Template(Environment):
MarkdownExtension
],
loader = FileSystemLoader([
get_resource('frontend'),
File.from_resource('relay', 'frontend'),
app.config.path.parent.joinpath('template')
])
)

View file

@ -48,7 +48,7 @@ class ActorView(View):
# reject if actor is banned
if conn.get_domain_ban(self.actor.domain):
logging.verbose('Ignored request from banned actor: %s', self.actor.id)
return Response.new_error(403, 'access denied', 'json')
raise HttpError(403, 'access denied')
# reject if activity type isn't 'Follow' and the actor isn't following
if self.message.type != 'Follow' and not self.instance:
@ -57,7 +57,7 @@ class ActorView(View):
self.actor.id
)
return Response.new_error(401, 'access denied', 'json')
raise HttpError(401, 'access denied')
logging.debug('>> payload %s', self.message.to_json(4))
@ -78,7 +78,7 @@ class ActorView(View):
except Exception:
traceback.print_exc()
logging.verbose('Failed to parse inbox message')
logging.verbose('Failed to parse message from actor: %s', self.signature.keyid)
raise HttpError(400, 'failed to parse message')
if message is None:
@ -94,13 +94,14 @@ class ActorView(View):
try:
self.actor = await self.client.get(self.signature.keyid, True, Message)
except HttpError:
except HttpError as e:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '')
logging.verbose('Failed to fetch actor: %s', self.signature.keyid)
logging.debug('HTTP Status %i: %s', e.status, e.message)
raise HttpError(400, 'failed to fetch actor')
except Exception:
@ -162,10 +163,10 @@ class WebfingerView(View):
subject = request.query['resource']
except KeyError:
return Response.new_error(400, 'missing "resource" query key', 'json')
raise HttpError(400, 'missing "resource" query key')
if subject != f'acct:relay@{self.config.domain}':
return Response.new_error(404, 'user not found', 'json')
raise HttpError(404, 'user not found')
data = aputils.Webfinger.new(
handle = 'relay',

View file

@ -10,7 +10,7 @@ from .base import View, register_route
from .. import __version__
from ..database import ConfigData, schema
from ..misc import Message, Response, boolean
from ..misc import Message, Response
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -97,7 +97,7 @@ class OauthAuthorize(View):
with self.database.session(True) as conn:
if (app := conn.get_app(data['client_id'], data['client_secret'])) is None:
return Response.new_error(404, 'Could not find app', 'json')
raise HttpError(404, 'Could not find app')
if convert_to_boolean(data['response']):
if app.token is not None:
@ -393,7 +393,10 @@ class RequestView(View):
try:
with self.database.session(True) as conn:
instance = conn.put_request_response(data['domain'], boolean(data['accept']))
instance = conn.put_request_response(
data['domain'],
convert_to_boolean(data['accept'])
)
except KeyError:
raise HttpError(404, 'Request not found') from None
@ -402,7 +405,7 @@ class RequestView(View):
host = self.config.domain,
actor = instance.actor,
followid = instance.followid,
accept = boolean(data['accept'])
accept = convert_to_boolean(data['accept'])
)
self.app.push_message(instance.inbox, message, instance)