Compare commits

..

6 commits

Author SHA1 Message Date
Izalia Mae cc76256b12 Merge branch 'dev' into 'main'
Draft: version 0.3.3

See merge request pleroma/relay!59
2024-09-15 09:06:13 +00:00
Izalia Mae 7f0a1a4e08 replace HttpError with blib.HttpError 2024-09-15 04:54:37 -04:00
Izalia Mae ca70c1b293 change how HttpClient.get handles errors
* raise `blib.HttpError` on non 200 and 202 statuses
* use `EmptyBodyError` instead of `ValueError` for empty bodies
2024-09-15 04:35:00 -04:00
Izalia Mae 619b1d5560 update barkshark-lib to 0.2.3 2024-09-15 04:20:09 -04:00
Izalia Mae cf44b0dafe don't vacuum postgresql database 2024-09-14 22:44:51 -04:00
Izalia Mae 0cea1ff9e9 ensure Date objects returned from db have a timezone 2024-09-14 22:44:14 -04:00
12 changed files with 75 additions and 57 deletions

View file

@ -18,7 +18,7 @@ dependencies = [
"aiohttp >= 3.9.5", "aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.2.2.post2, < 0.3.0", "barkshark-lib >= 0.2.3, < 0.3.0",
"barkshark-sql >= 0.2.0, < 0.3.0", "barkshark-sql >= 0.2.0, < 0.3.0",
"click == 8.1.2", "click == 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",

View file

@ -12,6 +12,7 @@ from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from base64 import b64encode from base64 import b64encode
from blib import HttpError
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -26,8 +27,7 @@ from .config import Config
from .database import Connection, get_database from .database import Connection, get_database
from .database.schema import Instance from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import HttpError, Message, Response, check_open_port, get_resource from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response, check_open_port, get_resource
from .misc import JSON_PATHS, TOKEN_PATHS
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
@ -296,7 +296,7 @@ def format_error(request: web.Request, error: HttpError) -> Response:
app: Application = request.app # type: ignore[assignment] app: Application = request.app # type: ignore[assignment]
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''): if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.body}, error.status, ctype = 'json') return Response.new({'error': error.message}, error.status, ctype = 'json')
else: else:
body = app.template.render('page/error.haml', request, e = error) body = app.template.render('page/error.haml', request, e = error)
@ -338,21 +338,21 @@ async def handle_response_headers(
except HttpError as e: except HttpError as e:
resp = format_error(request, e) resp = format_error(request, e)
except HTTPException as ae: except HTTPException as e:
if ae.status == 404: if e.status == 404:
try: try:
text = (ae.text or "").split(":")[1].strip() text = (e.text or "").split(":")[1].strip()
except IndexError: except IndexError:
text = ae.text or "" text = e.text or ""
resp = format_error(request, HttpError(ae.status, text)) resp = format_error(request, HttpError(e.status, text))
else: else:
raise raise
except Exception as e: except Exception:
resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}')) resp = format_error(request, HttpError(500, 'Internal server error'))
traceback.print_exc() traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'

View file

@ -8,7 +8,7 @@ from blib import Date
from bsql import Database, Row from bsql import Database, Row
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import timedelta from datetime import timedelta, timezone
from redis import Redis from redis import Redis
from typing import TYPE_CHECKING, Any, TypedDict from typing import TYPE_CHECKING, Any, TypedDict
@ -72,6 +72,9 @@ class Item:
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.updated = Date.parse(self.updated) self.updated = Date.parse(self.updated)
if self.updated.tzinfo is None:
self.updated = self.updated.replace(tzinfo = timezone.utc)
@classmethod @classmethod
def from_data(cls: type[Item], *args: Any) -> Item: def from_data(cls: type[Item], *args: Any) -> Item:
@ -82,8 +85,7 @@ class Item:
def older_than(self, hours: int) -> bool: def older_than(self, hours: int) -> bool:
delta = Date.new_utc() - self.updated return self.updated + timedelta(hours = hours) < Date.new_utc()
return (delta.total_seconds()) > hours * 3600
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:

View file

@ -4,6 +4,7 @@ from blib import Date
from bsql import Column, Row, Tables from bsql import Column, Row, Tables
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from datetime import timezone
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from .config import ConfigData from .config import ConfigData
@ -18,12 +19,15 @@ TABLES = Tables()
def deserialize_timestamp(value: Any) -> Date: def deserialize_timestamp(value: Any) -> Date:
try: try:
return Date.parse(value) date = Date.parse(value)
except ValueError: except ValueError:
pass date = Date.fromisoformat(value)
return Date.fromisoformat(value) if date.tzinfo is None:
date = date.replace(tzinfo = timezone.utc)
return date
@TABLES.add_row @TABLES.add_row
@ -45,14 +49,16 @@ class Instance(Row):
followid: Column[str] = Column('followid', 'text') followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text') software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean') accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[Date] = Column('created', 'timestamp', nullable = False) created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
class Whitelist(Row): class Whitelist(Row):
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column('created', 'timestamp', nullable = False) created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
@ -64,7 +70,8 @@ class DomainBan(Row):
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column('created', 'timestamp', nullable = False) created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
@ -75,7 +82,8 @@ class SoftwareBan(Row):
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column('created', 'timestamp', nullable = False) created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
@ -87,7 +95,8 @@ class User(Row):
'username', 'text', primary_key = True, unique = True, nullable = False) 'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text') handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column('created', 'timestamp', nullable = False) created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
@ -104,8 +113,10 @@ class App(Row):
token: Column[str | None] = Column('token', 'text') token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text') auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text') user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column('created', 'timestamp', nullable = False) created: Column[Date] = Column(
accessed: Column[Date] = Column('accessed', 'timestamp', nullable = False) 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]: def get_api_data(self, include_token: bool = False) -> dict[str, Any]:

2
relay/errors.py Normal file
View file

@ -0,0 +1,2 @@
class EmptyBodyError(Exception):
pass

View file

@ -4,4 +4,4 @@
-block content -block content
.section.error .section.error
.title << HTTP Error {{e.status}} .title << HTTP Error {{e.status}}
.body -> =e.body .body -> =e.message

View file

@ -4,12 +4,13 @@ import json
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from blib import JsonBase from blib import HttpError, JsonBase
from typing import TYPE_CHECKING, Any, TypeVar, overload from typing import TYPE_CHECKING, Any, TypeVar, overload
from . import __version__, logger as logging from . import __version__, logger as logging
from .cache import Cache from .cache import Cache
from .database.schema import Instance from .database.schema import Instance
from .errors import EmptyBodyError
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
@ -107,11 +108,7 @@ class HttpClient:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
try: url = url.split("#", 1)[0]
url, _ = url.split('#', 1)
except ValueError:
pass
if not force: if not force:
try: try:
@ -136,10 +133,17 @@ class HttpClient:
data = await resp.text() data = await resp.text()
if resp.status != 200: if resp.status not in (200, 202):
logging.verbose('Received error when requesting %s: %i', url, resp.status) logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(data) logging.debug(data)
return None
try:
error = json.loads(data)["error"]
except Exception:
error = data
raise HttpError(resp.status, error)
self.cache.set('request', url, data, 'str') self.cache.set('request', url, data, 'str')
return data return data
@ -151,7 +155,7 @@ class HttpClient:
sign_headers: bool, sign_headers: bool,
cls: None = None, cls: None = None,
force: bool = False, force: bool = False,
old_algo: bool = True) -> None: ... old_algo: bool = True) -> str | None: ...
@overload @overload
@ -168,7 +172,7 @@ class HttpClient:
sign_headers: bool, sign_headers: bool,
cls: type[T] | None = None, cls: type[T] | None = None,
force: bool = False, force: bool = False,
old_algo: bool = True) -> T | None: old_algo: bool = True) -> T | str | None:
if cls is not None and not issubclass(cls, JsonBase): if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"') raise TypeError('cls must be a sub-class of "blib.JsonBase"')
@ -177,11 +181,12 @@ class HttpClient:
if cls is not None: if cls is not None:
if data is None: if data is None:
raise ValueError("Empty response") # this shouldn't actually get raised, but keeping just in case
raise EmptyBodyError(f"GET {url}")
return cls.parse(data) return cls.parse(data)
return None return data
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None: async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:

View file

@ -224,6 +224,9 @@ def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
with ctx.obj.database.session(True) as conn: with ctx.obj.database.session(True) as conn:
conn.fix_timestamps() conn.fix_timestamps()
if ctx.obj.config.db_type == "postgres":
return
with ctx.obj.database.session(False) as conn: with ctx.obj.database.session(False) as conn:
with conn.execute("VACUUM"): with conn.execute("VACUUM"):
pass pass

View file

@ -134,17 +134,6 @@ def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path) return Path(str(pkgfiles('relay'))).joinpath(path)
class HttpError(Exception):
def __init__(self,
status: int,
body: str) -> None:
self.body: str = body
self.status: int = status
Exception.__init__(self, f"HTTP Error {status}: {body}")
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str: def default(self, o: Any) -> str:
if isinstance(o, datetime): if isinstance(o, datetime):

View file

@ -2,12 +2,13 @@ import aputils
import traceback import traceback
from aiohttp.web import Request from aiohttp.web import Request
from blib import HttpError
from .base import View, register_route from .base import View, register_route
from .. import logger as logging from .. import logger as logging
from ..database import schema from ..database import schema
from ..misc import HttpError, Message, Response from ..misc import Message, Response
from ..processors import run_processor from ..processors import run_processor
@ -93,15 +94,19 @@ class ActorView(View):
try: try:
self.actor = await self.client.get(self.signature.keyid, True, Message) self.actor = await self.client.get(self.signature.keyid, True, Message)
except Exception: except HttpError:
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete': if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '') raise HttpError(202, '')
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}') logging.verbose('Failed to fetch actor: %s', self.signature.keyid)
raise HttpError(400, 'failed to fetch actor') raise HttpError(400, 'failed to fetch actor')
except Exception:
traceback.print_exc()
raise HttpError(500, 'unexpected error when fetching actor')
try: try:
self.signer = self.actor.signer self.signer = self.actor.signer

View file

@ -2,7 +2,7 @@ import traceback
from aiohttp.web import Request, middleware from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from blib import convert_to_boolean from blib import HttpError, convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from urllib.parse import urlparse from urllib.parse import urlparse
@ -10,7 +10,7 @@ from .base import View, register_route
from .. import __version__ from .. import __version__
from ..database import ConfigData, schema from ..database import ConfigData, schema
from ..misc import HttpError, Message, Response, boolean from ..misc import Message, Response, boolean
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -324,7 +324,7 @@ class Inbox(View):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
raise HttpError(500, 'Failed to fetch actor') raise HttpError(500, 'Failed to fetch actor') from None
data['inbox'] = actor_data.shared_inbox data['inbox'] = actor_data.shared_inbox
@ -396,7 +396,7 @@ class RequestView(View):
instance = conn.put_request_response(data['domain'], boolean(data['accept'])) instance = conn.put_request_response(data['domain'], boolean(data['accept']))
except KeyError: except KeyError:
raise HttpError(404, 'Request not found') raise HttpError(404, 'Request not found') from None
message = Message.new_response( message = Message.new_response(
host = self.config.domain, host = self.config.domain,

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request from aiohttp.web import Request
from blib import HttpError
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property from functools import cached_property
@ -13,7 +14,7 @@ from ..cache import Cache
from ..config import Config from ..config import Config
from ..database import Connection from ..database import Connection
from ..http_client import HttpClient from ..http_client import HttpClient
from ..misc import HttpError, Response, get_app from ..misc import Response, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self from typing import Self
@ -137,7 +138,7 @@ class View(AbstractView):
data[key] = post_data[key] data[key] = post_data[key]
except KeyError as e: except KeyError as e:
raise HttpError(400, f'Missing {str(e)} pararmeter') raise HttpError(400, f'Missing {str(e)} pararmeter') from None
for key in optional: for key in optional:
data[key] = post_data.get(key) # type: ignore[assignment] data[key] = post_data.get(key) # type: ignore[assignment]