Compare commits

..

No commits in common. "c482677c32c5637baecf1f3abb3edf0ebed4fa8e" and "0e89b9bb113cb5fada1a24442cecc9fd082488c1" have entirely different histories.

13 changed files with 72 additions and 86 deletions

View file

@ -18,7 +18,7 @@ dependencies = [
"aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0",
"barkshark-lib >= 0.2.3, < 0.3.0",
"barkshark-lib >= 0.2.2.post2, < 0.3.0",
"barkshark-sql >= 0.2.0, < 0.3.0",
"click == 8.1.2",
"hiredis == 2.3.2",

View file

@ -12,7 +12,6 @@ from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger
from aputils.signer import Signer
from base64 import b64encode
from blib import HttpError
from bsql import Database
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
@ -27,7 +26,8 @@ from .config import Config
from .database import Connection, get_database
from .database.schema import Instance
from .http_client import HttpClient
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response, check_open_port, get_resource
from .misc import HttpError, Message, Response, check_open_port, get_resource
from .misc import JSON_PATHS, TOKEN_PATHS
from .template import Template
from .views import VIEWS
from .views.api import handle_api_path
@ -296,7 +296,7 @@ def format_error(request: web.Request, error: HttpError) -> Response:
app: Application = request.app # type: ignore[assignment]
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.message}, error.status, ctype = 'json')
return Response.new({'error': error.body}, error.status, ctype = 'json')
else:
body = app.template.render('page/error.haml', request, e = error)
@ -338,21 +338,21 @@ async def handle_response_headers(
except HttpError as e:
resp = format_error(request, e)
except HTTPException as e:
if e.status == 404:
except HTTPException as ae:
if ae.status == 404:
try:
text = (e.text or "").split(":")[1].strip()
text = (ae.text or "").split(":")[1].strip()
except IndexError:
text = e.text or ""
text = ae.text or ""
resp = format_error(request, HttpError(e.status, text))
resp = format_error(request, HttpError(ae.status, text))
else:
raise
except Exception:
resp = format_error(request, HttpError(500, 'Internal server error'))
except Exception as e:
resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}'))
traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay'

View file

@ -8,7 +8,7 @@ from blib import Date
from bsql import Database, Row
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
from datetime import timedelta, timezone
from datetime import timedelta
from redis import Redis
from typing import TYPE_CHECKING, Any, TypedDict
@ -72,9 +72,6 @@ class Item:
def __post_init__(self) -> None:
self.updated = Date.parse(self.updated)
if self.updated.tzinfo is None:
self.updated = self.updated.replace(tzinfo = timezone.utc)
@classmethod
def from_data(cls: type[Item], *args: Any) -> Item:
@ -85,7 +82,8 @@ class Item:
def older_than(self, hours: int) -> bool:
return self.updated + timedelta(hours = hours) < Date.new_utc()
delta = Date.new_utc() - self.updated
return (delta.total_seconds()) > hours * 3600
def to_dict(self) -> dict[str, Any]:
@ -243,10 +241,11 @@ class SqlCache(Cache):
if self._db is None:
raise RuntimeError("Database has not been setup")
date = Date.new_utc() - timedelta(days = days)
limit = Date.new_utc() - timedelta(days = days)
params = {"limit": limit.timestamp()}
with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache WHERE updated < :limit", {"limit": date}):
with conn.execute("DELETE FROM cache WHERE updated < :limit", params):
pass

View file

@ -4,7 +4,6 @@ from blib import Date
from bsql import Column, Row, Tables
from collections.abc import Callable
from copy import deepcopy
from datetime import timezone
from typing import TYPE_CHECKING, Any
from .config import ConfigData
@ -19,15 +18,12 @@ TABLES = Tables()
def deserialize_timestamp(value: Any) -> Date:
try:
date = Date.parse(value)
return Date.parse(value)
except ValueError:
date = Date.fromisoformat(value)
pass
if date.tzinfo is None:
date = date.replace(tzinfo = timezone.utc)
return date
return Date.fromisoformat(value)
@TABLES.add_row
@ -49,16 +45,14 @@ class Instance(Row):
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
created: Column[Date] = Column('created', 'timestamp', nullable = False)
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
created: Column[Date] = Column('created', 'timestamp', nullable = False)
@TABLES.add_row
@ -70,8 +64,7 @@ class DomainBan(Row):
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
created: Column[Date] = Column('created', 'timestamp', nullable = False)
@TABLES.add_row
@ -82,8 +75,7 @@ class SoftwareBan(Row):
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
created: Column[Date] = Column('created', 'timestamp', nullable = False)
@TABLES.add_row
@ -95,8 +87,7 @@ class User(Row):
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
created: Column[Date] = Column('created', 'timestamp', nullable = False)
@TABLES.add_row
@ -113,10 +104,8 @@ class App(Row):
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
created: Column[Date] = Column('created', 'timestamp', nullable = False)
accessed: Column[Date] = Column('accessed', 'timestamp', nullable = False)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]:

View file

@ -1,2 +0,0 @@
class EmptyBodyError(Exception):
pass

View file

@ -4,4 +4,4 @@
-block content
.section.error
.title << HTTP Error {{e.status}}
.body -> =e.message
.body -> =e.body

View file

@ -4,13 +4,12 @@ import json
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from blib import HttpError, JsonBase
from blib import JsonBase
from typing import TYPE_CHECKING, Any, TypeVar, overload
from . import __version__, logger as logging
from .cache import Cache
from .database.schema import Instance
from .errors import EmptyBodyError
from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING:
@ -108,7 +107,11 @@ class HttpClient:
if not self._session:
raise RuntimeError('Client not open')
url = url.split("#", 1)[0]
try:
url, _ = url.split('#', 1)
except ValueError:
pass
if not force:
try:
@ -133,14 +136,10 @@ class HttpClient:
data = await resp.text()
if resp.status not in (200, 202):
try:
error = json.loads(data)["error"]
except Exception:
error = data
raise HttpError(resp.status, error)
if resp.status != 200:
logging.verbose('Received error when requesting %s: %i', url, resp.status)
logging.debug(data)
return None
self.cache.set('request', url, data, 'str')
return data
@ -152,7 +151,7 @@ class HttpClient:
sign_headers: bool,
cls: None = None,
force: bool = False,
old_algo: bool = True) -> str | None: ...
old_algo: bool = True) -> None: ...
@overload
@ -169,7 +168,7 @@ class HttpClient:
sign_headers: bool,
cls: type[T] | None = None,
force: bool = False,
old_algo: bool = True) -> T | str | None:
old_algo: bool = True) -> T | None:
if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"')
@ -178,12 +177,11 @@ class HttpClient:
if cls is not None:
if data is None:
# this shouldn't actually get raised, but keeping just in case
raise EmptyBodyError(f"GET {url}")
raise ValueError("Empty response")
return cls.parse(data)
return data
return None
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
@ -220,12 +218,16 @@ class HttpClient:
logging.verbose('Sending "%s" to %s', mtype, url)
async with self._session.post(url, headers = headers, data = body) as resp:
if resp.status not in (200, 202):
raise HttpError(
resp.status,
await resp.text(),
headers = {k: v for k, v in resp.headers.items()}
)
# Not expecting a response, so just return
if resp.status in {200, 202}:
logging.verbose('Successfully sent "%s" to %s', mtype, url)
return
logging.error('Received error when pushing to %s: %i', url, resp.status)
logging.debug(await resp.read())
logging.debug("message: %s", body.decode("utf-8"))
logging.debug("headers: %s", json.dumps(headers, indent = 4))
return
async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:

View file

@ -224,9 +224,6 @@ def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
with ctx.obj.database.session(True) as conn:
conn.fix_timestamps()
if ctx.obj.config.db_type == "postgres":
return
with ctx.obj.database.session(False) as conn:
with conn.execute("VACUUM"):
pass

View file

@ -134,6 +134,17 @@ def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path)
class HttpError(Exception):
def __init__(self,
status: int,
body: str) -> None:
self.body: str = body
self.status: int = status
Exception.__init__(self, f"HTTP Error {status}: {body}")
class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str:
if isinstance(o, datetime):

View file

@ -2,13 +2,12 @@ import aputils
import traceback
from aiohttp.web import Request
from blib import HttpError
from .base import View, register_route
from .. import logger as logging
from ..database import schema
from ..misc import Message, Response
from ..misc import HttpError, Message, Response
from ..processors import run_processor
@ -94,19 +93,15 @@ class ActorView(View):
try:
self.actor = await self.client.get(self.signature.keyid, True, Message)
except HttpError:
except Exception:
# ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '')
logging.verbose('Failed to fetch actor: %s', self.signature.keyid)
logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
raise HttpError(400, 'failed to fetch actor')
except Exception:
traceback.print_exc()
raise HttpError(500, 'unexpected error when fetching actor')
try:
self.signer = self.actor.signer

View file

@ -2,7 +2,7 @@ import traceback
from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError
from blib import HttpError, convert_to_boolean
from blib import convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence
from urllib.parse import urlparse
@ -10,7 +10,7 @@ from .base import View, register_route
from .. import __version__
from ..database import ConfigData, schema
from ..misc import Message, Response, boolean
from ..misc import HttpError, Message, Response, boolean
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -324,7 +324,7 @@ class Inbox(View):
except Exception:
traceback.print_exc()
raise HttpError(500, 'Failed to fetch actor') from None
raise HttpError(500, 'Failed to fetch actor')
data['inbox'] = actor_data.shared_inbox
@ -396,7 +396,7 @@ class RequestView(View):
instance = conn.put_request_response(data['domain'], boolean(data['accept']))
except KeyError:
raise HttpError(404, 'Request not found') from None
raise HttpError(404, 'Request not found')
message = Message.new_response(
host = self.config.domain,

View file

@ -3,7 +3,6 @@ from __future__ import annotations
from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request
from blib import HttpError
from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property
@ -14,7 +13,7 @@ from ..cache import Cache
from ..config import Config
from ..database import Connection
from ..http_client import HttpClient
from ..misc import Response, get_app
from ..misc import HttpError, Response, get_app
if TYPE_CHECKING:
from typing import Self
@ -138,7 +137,7 @@ class View(AbstractView):
data[key] = post_data[key]
except KeyError as e:
raise HttpError(400, f'Missing {str(e)} pararmeter') from None
raise HttpError(400, f'Missing {str(e)} pararmeter')
for key in optional:
data[key] = post_data.get(key) # type: ignore[assignment]

View file

@ -5,7 +5,6 @@ import traceback
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from blib import HttpError
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value
from multiprocessing.queues import Queue as QueueType
@ -95,9 +94,6 @@ class PushWorker(Process):
try:
await self.client.post(item.inbox, item.message, item.instance)
except HttpError as e:
logging.error('HTTP Error when pushing to %s: %i %s', item.inbox, e.status, e.message)
except AsyncTimeoutError:
logging.error('Timeout when pushing to %s', item.domain)