diff --git a/relay/application.py b/relay/application.py index d463244..f3d9810 100644 --- a/relay/application.py +++ b/relay/application.py @@ -26,6 +26,7 @@ from .views.api import handle_api_path from .views.frontend import handle_frontend_path if typing.TYPE_CHECKING: + from collections.abc import Coroutine from tinysql import Database, Row from .cache import Cache from .misc import Message, Response @@ -264,7 +265,7 @@ async def handle_response_headers(request: web.Request, handler: Coroutine) -> R # if not request.app['dev'] and request.path.endswith(('.css', '.js')): # resp.headers['Cache-Control'] = 'public,max-age=2628000,immutable' -# + # else: # resp.headers['Cache-Control'] = 'no-store' diff --git a/relay/config.py b/relay/config.py index e61c99a..84faab1 100644 --- a/relay/config.py +++ b/relay/config.py @@ -52,7 +52,11 @@ if IS_DOCKER: class Config: def __init__(self, path: str, load: bool = False): - self.path = Config.get_config_dir() + if path: + self.path = Path(path).expanduser().resolve() + + else: + self.path = Config.get_config_dir() self.listen = None self.port = None diff --git a/relay/views/base.py b/relay/views/base.py index 65859a5..f568525 100644 --- a/relay/views/base.py +++ b/relay/views/base.py @@ -11,9 +11,10 @@ from json.decoder import JSONDecodeError from ..misc import Response if typing.TYPE_CHECKING: + from aiohttp.web import Request from collections.abc import Callable, Coroutine, Generator from bsql import Database - from typing import Self + from typing import Any, Self from ..application import Application from ..cache import Cache from ..config import Config diff --git a/relay/views/frontend.py b/relay/views/frontend.py index 8710413..9df5e45 100644 --- a/relay/views/frontend.py +++ b/relay/views/frontend.py @@ -4,6 +4,7 @@ import typing from aiohttp import web from argon2.exceptions import VerifyMismatchError +from urllib.parse import urlparse from .base import View, register_route @@ -13,8 +14,11 @@ from ..misc import ACTOR_FORMATS, Message, Response if typing.TYPE_CHECKING: from aiohttp.web import Request + from collections.abc import Coroutine +# pylint: disable=no-self-use + UNAUTH_ROUTES = { '/', '/login' @@ -147,22 +151,22 @@ class AdminInstances(View): async def post(self, request: Request) -> Response: - data = {key: value for key, value in (await request.post()).items()} + data = await request.post() - if not data['actor'] and not data['domain']: + if not data.get('actor') and not data.get('domain'): return await self.get(request, error = 'Missing actor and/or domain') - if not data['domain']: + if not data.get('domain'): data['domain'] = urlparse(data['actor']).netloc - if not data['software']: + if not data.get('software'): nodeinfo = await self.client.fetch_nodeinfo(data['domain']) data['software'] = nodeinfo.sw_name - if not data['actor'] and data['software'] in ACTOR_FORMATS: + if not data.get('actor') and data['software'] in ACTOR_FORMATS: data['actor'] = ACTOR_FORMATS[data['software']].format(domain = data['domain']) - if not data['inbox'] and data['actor']: + if not data.get('inbox') and data['actor']: actor = await self.client.get(data['actor'], sign_headers = True, loads = Message.parse) data['inbox'] = actor.shared_inbox @@ -176,7 +180,7 @@ class AdminInstances(View): class AdminInstancesDelete(View): async def get(self, request: Request, domain: str) -> Response: with self.database.session() as conn: - if not (conn.get_inbox(domain)): + if not conn.get_inbox(domain): return await AdminInstances(request).get(request, message = 'Instance not found') conn.del_inbox(domain) @@ -213,7 +217,7 @@ class AdminWhitelist(View): return await self.get(request, error = 'Missing domain') with self.database.session(True) as conn: - if (ban := conn.get_domain_whitelist(data['domain'])): + if conn.get_domain_whitelist(data['domain']): return await self.get(request, message = "Domain already in whitelist") conn.put_domain_whitelist(data['domain']) @@ -225,7 +229,7 @@ class AdminWhitelist(View): class AdminWhitlistDelete(View): async def get(self, request: Request, domain: str) -> Response: with self.database.session() as conn: - if not (conn.get_domain_whitelist(domain)): + if not conn.get_domain_whitelist(domain): msg = 'Whitelisted domain not found' return await AdminWhitelist.run("GET", request, message = msg) @@ -263,7 +267,7 @@ class AdminDomainBans(View): return await self.get(request, error = 'Missing domain') with self.database.session(True) as conn: - if (ban := conn.get_domain_ban(data['domain'])): + if conn.get_domain_ban(data['domain']): conn.update_domain_ban( data['domain'], data.get('reason'), @@ -284,7 +288,7 @@ class AdminDomainBans(View): class AdminDomainBansDelete(View): async def get(self, request: Request, domain: str) -> Response: with self.database.session() as conn: - if not (conn.get_domain_ban(domain)): + if not conn.get_domain_ban(domain): return await AdminDomainBans.run("GET", request, message = 'Domain ban not found') conn.del_domain_ban(domain) @@ -321,7 +325,7 @@ class AdminSoftwareBans(View): return await self.get(request, error = 'Missing name') with self.database.session(True) as conn: - if (ban := conn.get_software_ban(data['name'])): + if conn.get_software_ban(data['name']): conn.update_software_ban( data['name'], data.get('reason'), @@ -342,7 +346,7 @@ class AdminSoftwareBans(View): class AdminSoftwareBansDelete(View): async def get(self, request: Request, name: str) -> Response: with self.database.session() as conn: - if not (conn.get_software_ban(name)): + if not conn.get_software_ban(name): return await AdminSoftwareBans.run("GET", request, message = 'Software ban not found') conn.del_software_ban(name) @@ -375,9 +379,8 @@ class AdminUsers(View): async def post(self, request: Request) -> Response: data = await request.post() required_fields = {'username', 'password', 'password2'} - print(data) - if not all(map(data.get, required_fields)): + if not all(data.get(field) for field in required_fields): return await self.get(request, error = 'Missing username and/or password') if data['password'] != data['password2']: @@ -396,7 +399,7 @@ class AdminUsers(View): class AdminUsersDelete(View): async def get(self, request: Request, name: str) -> Response: with self.database.session() as conn: - if not (conn.get_user(name)): + if not conn.get_user(name): return await AdminUsers.run("GET", request, message = 'User not found') conn.del_user(name)