Compare commits

...

3 commits

Author SHA1 Message Date
Izalia Mae a6f1738b73 fix user updating 2024-03-16 06:11:47 -04:00
Izalia Mae ea0658e2ea don't set csp header on /api routes 2024-03-16 06:10:58 -04:00
Izalia Mae 5c210dc20f fix linting issues 2024-03-16 01:36:20 -04:00
6 changed files with 50 additions and 48 deletions

View file

@ -35,6 +35,21 @@ if typing.TYPE_CHECKING:
from .misc import Message, Response from .misc import Message, Response
def get_csp(request: web.Request) -> str:
data = [
"default-src 'none'",
f"script-src 'nonce-{request['hash']}'",
f"style-src 'self' 'nonce-{request['hash']}'",
"form-action 'self'",
"connect-src 'self'",
"img-src 'self'",
"object-src 'none'",
"frame-ancestors 'none'"
]
return '; '.join(data) + ';'
class Application(web.Application): class Application(web.Application):
DEFAULT: Application | None = None DEFAULT: Application | None = None
@ -127,21 +142,6 @@ class Application(web.Application):
return timedelta(seconds=uptime.seconds) return timedelta(seconds=uptime.seconds)
def get_csp(self, request: Request) -> str:
data = [
"default-src 'none'",
f"script-src 'nonce-{request['hash']}'",
f"style-src 'self' 'nonce-{request['hash']}'",
"form-action 'self'",
"connect-src 'self'",
"img-src 'self'",
"object-src 'none'",
"frame-ancestors 'none'"
]
return '; '.join(data) + ';'
def push_message(self, inbox: str, message: Message, instance: Row) -> None: def push_message(self, inbox: str, message: Message, instance: Row) -> None:
self['push_queue'].put((inbox, message, instance)) self['push_queue'].put((inbox, message, instance))
@ -240,7 +240,7 @@ class CachedStaticResource(StaticResource):
def __init__(self, prefix: str, path: Path): def __init__(self, prefix: str, path: Path):
StaticResource.__init__(self, prefix, path) StaticResource.__init__(self, prefix, path)
self.cache: dict[Path, bytes] = {} self.cache: dict[str, bytes] = {}
for filename in path.rglob('*'): for filename in path.rglob('*'):
if filename.is_dir(): if filename.is_dir():
@ -333,8 +333,8 @@ async def handle_response_headers(request: web.Request, handler: Callable) -> Re
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'
# Still have to figure out how csp headers work # Still have to figure out how csp headers work
if resp.content_type == 'text/html': if resp.content_type == 'text/html' and not request.path.startswith("/api"):
resp.headers['Content-Security-Policy'] = Application.DEFAULT.get_csp(request) resp.headers['Content-Security-Policy'] = get_csp(request)
if not request.app['dev'] and request.path.endswith(('.css', '.js')): if not request.app['dev'] and request.path.endswith(('.css', '.js')):
# cache for 2 weeks # cache for 2 weeks

View file

@ -192,26 +192,29 @@ class Connection(SqlConnection):
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row: def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row:
if self.get_user(username): if self.get_user(username):
data = { data: dict[str, str] = {}
'username': username
}
if password: if password:
data['password'] = password data['hash'] = self.hasher.hash(password)
if handle: if handle:
data['handler'] = handle data['handle'] = handle
else: stmt = Update("users", data)
if password is None: stmt.set_where("username", username)
raise ValueError('Password cannot be empty')
data = { with self.query(stmt) as cur:
'username': username, return cur.one()
'hash': self.hasher.hash(password),
'handle': handle, if password is None:
'created': datetime.now(tz = timezone.utc) raise ValueError('Password cannot be empty')
}
data = {
'username': username,
'hash': self.hasher.hash(password),
'handle': handle,
'created': datetime.now(tz = timezone.utc)
}
with self.run('put-user', data) as cur: with self.run('put-user', data) as cur:
return cur.one() # type: ignore return cur.one() # type: ignore

View file

@ -100,8 +100,8 @@ def cli_run(dev: bool):
try: try:
while True: while True:
handler.proc.stdin.write(sys.stdin.read().encode('UTF-8')) handler.proc.stdin.write(sys.stdin.read().encode('UTF-8')) # type: ignore
handler.proc.stdin.flush() handler.proc.stdin.flush() # type: ignore
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
@ -121,12 +121,12 @@ class WatchHandler(PatternMatchingEventHandler):
PatternMatchingEventHandler.__init__(self) PatternMatchingEventHandler.__init__(self)
self.dev: bool = dev self.dev: bool = dev
self.proc = None self.proc: subprocess.Popen | None = None
self.last_restart = None self.last_restart: datetime | None = None
def kill_proc(self): def kill_proc(self):
if self.proc.poll() is not None: if not self.proc or self.proc.poll() is not None:
return return
logging.info(f'Terminating process {self.proc.pid}') logging.info(f'Terminating process {self.proc.pid}')

View file

@ -15,6 +15,7 @@ from ..misc import Message, Response, boolean, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any
PUBLIC_API_PATHS: Sequence[tuple[str, str]] = ( PUBLIC_API_PATHS: Sequence[tuple[str, str]] = (
@ -149,7 +150,7 @@ class Config(View):
if isinstance(data, Response): if isinstance(data, Response):
return data return data
data['key'] = data['key'].replace('-', '_'); data['key'] = data['key'].replace('-', '_')
if data['key'] not in ConfigData.USER_KEYS(): if data['key'] not in ConfigData.USER_KEYS():
return Response.new_error(400, 'Invalid key', 'json') return Response.new_error(400, 'Invalid key', 'json')
@ -255,7 +256,7 @@ class RequestView(View):
async def post(self, request: Request) -> Response: async def post(self, request: Request) -> Response:
data = await self.get_api_data(['domain', 'accept'], []) data: dict[str, Any] | Response = await self.get_api_data(['domain', 'accept'], [])
data['accept'] = boolean(data['accept']) data['accept'] = boolean(data['accept'])
try: try:
@ -430,7 +431,7 @@ class User(View):
async def patch(self, request: Request) -> Response: async def patch(self, request: Request) -> Response:
data = await self.get_api_data(['username'], ['password', ['handle']]) data = await self.get_api_data(['username'], ['password', 'handle'])
if isinstance(data, Response): if isinstance(data, Response):
return data return data

View file

@ -126,7 +126,7 @@ class View(AbstractView):
return Response.new_error(400, 'Invalid JSON data', 'json') return Response.new_error(400, 'Invalid JSON data', 'json')
else: else:
post_data = convert_data(self.request.query) # type: ignore post_data = convert_data(self.request.query)
data = {} data = {}

View file

@ -3,14 +3,12 @@ from __future__ import annotations
import typing import typing
from aiohttp import web from aiohttp import web
from argon2.exceptions import VerifyMismatchError
from urllib.parse import urlparse
from .base import View, register_route from .base import View, register_route
from ..database import THEMES, ConfigData from ..database import THEMES
from ..logger import LogLevel from ..logger import LogLevel
from ..misc import ACTOR_FORMATS, Message, Response, get_app from ..misc import Response, get_app
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from aiohttp.web import Request from aiohttp.web import Request
@ -40,9 +38,9 @@ async def handle_frontend_path(request: web.Request, handler: Callable) -> Respo
return Response.new('', 302, {'Location': '/'}) return Response.new('', 302, {'Location': '/'})
if not request['user'] and request.path.startswith('/admin'): if not request['user'] and request.path.startswith('/admin'):
response = Response.new('', 302, {'Location': f'/login?redir={request.path}'}) response = Response.new('', 302, {'Location': f'/login?redir={request.path}'})
response.del_cookie('user-token') response.del_cookie('user-token')
return response return response
response = await handler(request) response = await handler(request)