Compare commits

..

No commits in common. "c482677c32c5637baecf1f3abb3edf0ebed4fa8e" and "0e89b9bb113cb5fada1a24442cecc9fd082488c1" have entirely different histories.

13 changed files with 72 additions and 86 deletions

View file

@ -18,7 +18,7 @@ dependencies = [
"aiohttp >= 3.9.5", "aiohttp >= 3.9.5",
"aiohttp-swagger[performance] == 1.0.16", "aiohttp-swagger[performance] == 1.0.16",
"argon2-cffi == 23.1.0", "argon2-cffi == 23.1.0",
"barkshark-lib >= 0.2.3, < 0.3.0", "barkshark-lib >= 0.2.2.post2, < 0.3.0",
"barkshark-sql >= 0.2.0, < 0.3.0", "barkshark-sql >= 0.2.0, < 0.3.0",
"click == 8.1.2", "click == 8.1.2",
"hiredis == 2.3.2", "hiredis == 2.3.2",

View file

@ -12,7 +12,6 @@ from aiohttp.web import HTTPException, StaticResource
from aiohttp_swagger import setup_swagger from aiohttp_swagger import setup_swagger
from aputils.signer import Signer from aputils.signer import Signer
from base64 import b64encode from base64 import b64encode
from blib import HttpError
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -27,7 +26,8 @@ from .config import Config
from .database import Connection, get_database from .database import Connection, get_database
from .database.schema import Instance from .database.schema import Instance
from .http_client import HttpClient from .http_client import HttpClient
from .misc import JSON_PATHS, TOKEN_PATHS, Message, Response, check_open_port, get_resource from .misc import HttpError, Message, Response, check_open_port, get_resource
from .misc import JSON_PATHS, TOKEN_PATHS
from .template import Template from .template import Template
from .views import VIEWS from .views import VIEWS
from .views.api import handle_api_path from .views.api import handle_api_path
@ -296,7 +296,7 @@ def format_error(request: web.Request, error: HttpError) -> Response:
app: Application = request.app # type: ignore[assignment] app: Application = request.app # type: ignore[assignment]
if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''): if request.path.startswith(JSON_PATHS) or 'json' in request.headers.get('accept', ''):
return Response.new({'error': error.message}, error.status, ctype = 'json') return Response.new({'error': error.body}, error.status, ctype = 'json')
else: else:
body = app.template.render('page/error.haml', request, e = error) body = app.template.render('page/error.haml', request, e = error)
@ -338,21 +338,21 @@ async def handle_response_headers(
except HttpError as e: except HttpError as e:
resp = format_error(request, e) resp = format_error(request, e)
except HTTPException as e: except HTTPException as ae:
if e.status == 404: if ae.status == 404:
try: try:
text = (e.text or "").split(":")[1].strip() text = (ae.text or "").split(":")[1].strip()
except IndexError: except IndexError:
text = e.text or "" text = ae.text or ""
resp = format_error(request, HttpError(e.status, text)) resp = format_error(request, HttpError(ae.status, text))
else: else:
raise raise
except Exception: except Exception as e:
resp = format_error(request, HttpError(500, 'Internal server error')) resp = format_error(request, HttpError(500, f'{type(e).__name__}: {str(e)}'))
traceback.print_exc() traceback.print_exc()
resp.headers['Server'] = 'ActivityRelay' resp.headers['Server'] = 'ActivityRelay'

View file

@ -8,7 +8,7 @@ from blib import Date
from bsql import Database, Row from bsql import Database, Row
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import timedelta, timezone from datetime import timedelta
from redis import Redis from redis import Redis
from typing import TYPE_CHECKING, Any, TypedDict from typing import TYPE_CHECKING, Any, TypedDict
@ -72,9 +72,6 @@ class Item:
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.updated = Date.parse(self.updated) self.updated = Date.parse(self.updated)
if self.updated.tzinfo is None:
self.updated = self.updated.replace(tzinfo = timezone.utc)
@classmethod @classmethod
def from_data(cls: type[Item], *args: Any) -> Item: def from_data(cls: type[Item], *args: Any) -> Item:
@ -85,7 +82,8 @@ class Item:
def older_than(self, hours: int) -> bool: def older_than(self, hours: int) -> bool:
return self.updated + timedelta(hours = hours) < Date.new_utc() delta = Date.new_utc() - self.updated
return (delta.total_seconds()) > hours * 3600
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
@ -243,10 +241,11 @@ class SqlCache(Cache):
if self._db is None: if self._db is None:
raise RuntimeError("Database has not been setup") raise RuntimeError("Database has not been setup")
date = Date.new_utc() - timedelta(days = days) limit = Date.new_utc() - timedelta(days = days)
params = {"limit": limit.timestamp()}
with self._db.session(True) as conn: with self._db.session(True) as conn:
with conn.execute("DELETE FROM cache WHERE updated < :limit", {"limit": date}): with conn.execute("DELETE FROM cache WHERE updated < :limit", params):
pass pass

View file

@ -4,7 +4,6 @@ from blib import Date
from bsql import Column, Row, Tables from bsql import Column, Row, Tables
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from datetime import timezone
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from .config import ConfigData from .config import ConfigData
@ -19,15 +18,12 @@ TABLES = Tables()
def deserialize_timestamp(value: Any) -> Date: def deserialize_timestamp(value: Any) -> Date:
try: try:
date = Date.parse(value) return Date.parse(value)
except ValueError: except ValueError:
date = Date.fromisoformat(value) pass
if date.tzinfo is None: return Date.fromisoformat(value)
date = date.replace(tzinfo = timezone.utc)
return date
@TABLES.add_row @TABLES.add_row
@ -49,16 +45,14 @@ class Instance(Row):
followid: Column[str] = Column('followid', 'text') followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text') software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean') accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
class Whitelist(Row): class Whitelist(Row):
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
@ -70,8 +64,7 @@ class DomainBan(Row):
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
@ -82,8 +75,7 @@ class SoftwareBan(Row):
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
@ -95,8 +87,7 @@ class User(Row):
'username', 'text', primary_key = True, unique = True, nullable = False) 'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text') handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row @TABLES.add_row
@ -113,10 +104,8 @@ class App(Row):
token: Column[str | None] = Column('token', 'text') token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text') auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text') user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) accessed: Column[Date] = Column('accessed', 'timestamp', nullable = False)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]: def get_api_data(self, include_token: bool = False) -> dict[str, Any]:

View file

@ -1,2 +0,0 @@
class EmptyBodyError(Exception):
pass

View file

@ -4,4 +4,4 @@
-block content -block content
.section.error .section.error
.title << HTTP Error {{e.status}} .title << HTTP Error {{e.status}}
.body -> =e.message .body -> =e.body

View file

@ -4,13 +4,12 @@ import json
from aiohttp import ClientSession, ClientTimeout, TCPConnector from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
from blib import HttpError, JsonBase from blib import JsonBase
from typing import TYPE_CHECKING, Any, TypeVar, overload from typing import TYPE_CHECKING, Any, TypeVar, overload
from . import __version__, logger as logging from . import __version__, logger as logging
from .cache import Cache from .cache import Cache
from .database.schema import Instance from .database.schema import Instance
from .errors import EmptyBodyError
from .misc import MIMETYPES, Message, get_app from .misc import MIMETYPES, Message, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
@ -108,7 +107,11 @@ class HttpClient:
if not self._session: if not self._session:
raise RuntimeError('Client not open') raise RuntimeError('Client not open')
url = url.split("#", 1)[0] try:
url, _ = url.split('#', 1)
except ValueError:
pass
if not force: if not force:
try: try:
@ -133,14 +136,10 @@ class HttpClient:
data = await resp.text() data = await resp.text()
if resp.status not in (200, 202): if resp.status != 200:
try: logging.verbose('Received error when requesting %s: %i', url, resp.status)
error = json.loads(data)["error"] logging.debug(data)
return None
except Exception:
error = data
raise HttpError(resp.status, error)
self.cache.set('request', url, data, 'str') self.cache.set('request', url, data, 'str')
return data return data
@ -152,7 +151,7 @@ class HttpClient:
sign_headers: bool, sign_headers: bool,
cls: None = None, cls: None = None,
force: bool = False, force: bool = False,
old_algo: bool = True) -> str | None: ... old_algo: bool = True) -> None: ...
@overload @overload
@ -169,7 +168,7 @@ class HttpClient:
sign_headers: bool, sign_headers: bool,
cls: type[T] | None = None, cls: type[T] | None = None,
force: bool = False, force: bool = False,
old_algo: bool = True) -> T | str | None: old_algo: bool = True) -> T | None:
if cls is not None and not issubclass(cls, JsonBase): if cls is not None and not issubclass(cls, JsonBase):
raise TypeError('cls must be a sub-class of "blib.JsonBase"') raise TypeError('cls must be a sub-class of "blib.JsonBase"')
@ -178,12 +177,11 @@ class HttpClient:
if cls is not None: if cls is not None:
if data is None: if data is None:
# this shouldn't actually get raised, but keeping just in case raise ValueError("Empty response")
raise EmptyBodyError(f"GET {url}")
return cls.parse(data) return cls.parse(data)
return data return None
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None: async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
@ -220,12 +218,16 @@ class HttpClient:
logging.verbose('Sending "%s" to %s', mtype, url) logging.verbose('Sending "%s" to %s', mtype, url)
async with self._session.post(url, headers = headers, data = body) as resp: async with self._session.post(url, headers = headers, data = body) as resp:
if resp.status not in (200, 202): # Not expecting a response, so just return
raise HttpError( if resp.status in {200, 202}:
resp.status, logging.verbose('Successfully sent "%s" to %s', mtype, url)
await resp.text(), return
headers = {k: v for k, v in resp.headers.items()}
) logging.error('Received error when pushing to %s: %i', url, resp.status)
logging.debug(await resp.read())
logging.debug("message: %s", body.decode("utf-8"))
logging.debug("headers: %s", json.dumps(headers, indent = 4))
return
async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo: async def fetch_nodeinfo(self, domain: str, force: bool = False) -> Nodeinfo:

View file

@ -224,9 +224,6 @@ def cli_db_maintenance(ctx: click.Context, fix_timestamps: bool) -> None:
with ctx.obj.database.session(True) as conn: with ctx.obj.database.session(True) as conn:
conn.fix_timestamps() conn.fix_timestamps()
if ctx.obj.config.db_type == "postgres":
return
with ctx.obj.database.session(False) as conn: with ctx.obj.database.session(False) as conn:
with conn.execute("VACUUM"): with conn.execute("VACUUM"):
pass pass

View file

@ -134,6 +134,17 @@ def get_resource(path: str) -> Path:
return Path(str(pkgfiles('relay'))).joinpath(path) return Path(str(pkgfiles('relay'))).joinpath(path)
class HttpError(Exception):
def __init__(self,
status: int,
body: str) -> None:
self.body: str = body
self.status: int = status
Exception.__init__(self, f"HTTP Error {status}: {body}")
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> str: def default(self, o: Any) -> str:
if isinstance(o, datetime): if isinstance(o, datetime):

View file

@ -2,13 +2,12 @@ import aputils
import traceback import traceback
from aiohttp.web import Request from aiohttp.web import Request
from blib import HttpError
from .base import View, register_route from .base import View, register_route
from .. import logger as logging from .. import logger as logging
from ..database import schema from ..database import schema
from ..misc import Message, Response from ..misc import HttpError, Message, Response
from ..processors import run_processor from ..processors import run_processor
@ -94,19 +93,15 @@ class ActorView(View):
try: try:
self.actor = await self.client.get(self.signature.keyid, True, Message) self.actor = await self.client.get(self.signature.keyid, True, Message)
except HttpError: except Exception:
# ld signatures aren't handled atm, so just ignore it # ld signatures aren't handled atm, so just ignore it
if self.message.type == 'Delete': if self.message.type == 'Delete':
logging.verbose('Instance sent a delete which cannot be handled') logging.verbose('Instance sent a delete which cannot be handled')
raise HttpError(202, '') raise HttpError(202, '')
logging.verbose('Failed to fetch actor: %s', self.signature.keyid) logging.verbose(f'Failed to fetch actor: {self.signature.keyid}')
raise HttpError(400, 'failed to fetch actor') raise HttpError(400, 'failed to fetch actor')
except Exception:
traceback.print_exc()
raise HttpError(500, 'unexpected error when fetching actor')
try: try:
self.signer = self.actor.signer self.signer = self.actor.signer

View file

@ -2,7 +2,7 @@ import traceback
from aiohttp.web import Request, middleware from aiohttp.web import Request, middleware
from argon2.exceptions import VerifyMismatchError from argon2.exceptions import VerifyMismatchError
from blib import HttpError, convert_to_boolean from blib import convert_to_boolean
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from urllib.parse import urlparse from urllib.parse import urlparse
@ -10,7 +10,7 @@ from .base import View, register_route
from .. import __version__ from .. import __version__
from ..database import ConfigData, schema from ..database import ConfigData, schema
from ..misc import Message, Response, boolean from ..misc import HttpError, Message, Response, boolean
DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob' DEFAULT_REDIRECT: str = 'urn:ietf:wg:oauth:2.0:oob'
@ -324,7 +324,7 @@ class Inbox(View):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
raise HttpError(500, 'Failed to fetch actor') from None raise HttpError(500, 'Failed to fetch actor')
data['inbox'] = actor_data.shared_inbox data['inbox'] = actor_data.shared_inbox
@ -396,7 +396,7 @@ class RequestView(View):
instance = conn.put_request_response(data['domain'], boolean(data['accept'])) instance = conn.put_request_response(data['domain'], boolean(data['accept']))
except KeyError: except KeyError:
raise HttpError(404, 'Request not found') from None raise HttpError(404, 'Request not found')
message = Message.new_response( message = Message.new_response(
host = self.config.domain, host = self.config.domain,

View file

@ -3,7 +3,6 @@ from __future__ import annotations
from aiohttp.abc import AbstractView from aiohttp.abc import AbstractView
from aiohttp.hdrs import METH_ALL as METHODS from aiohttp.hdrs import METH_ALL as METHODS
from aiohttp.web import Request from aiohttp.web import Request
from blib import HttpError
from bsql import Database from bsql import Database
from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping from collections.abc import Awaitable, Callable, Generator, Sequence, Mapping
from functools import cached_property from functools import cached_property
@ -14,7 +13,7 @@ from ..cache import Cache
from ..config import Config from ..config import Config
from ..database import Connection from ..database import Connection
from ..http_client import HttpClient from ..http_client import HttpClient
from ..misc import Response, get_app from ..misc import HttpError, Response, get_app
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self from typing import Self
@ -138,7 +137,7 @@ class View(AbstractView):
data[key] = post_data[key] data[key] = post_data[key]
except KeyError as e: except KeyError as e:
raise HttpError(400, f'Missing {str(e)} pararmeter') from None raise HttpError(400, f'Missing {str(e)} pararmeter')
for key in optional: for key in optional:
data[key] = post_data.get(key) # type: ignore[assignment] data[key] = post_data.get(key) # type: ignore[assignment]

View file

@ -5,7 +5,6 @@ import traceback
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
from asyncio.exceptions import TimeoutError as AsyncTimeoutError from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from blib import HttpError
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, Value from multiprocessing import Event, Process, Queue, Value
from multiprocessing.queues import Queue as QueueType from multiprocessing.queues import Queue as QueueType
@ -95,9 +94,6 @@ class PushWorker(Process):
try: try:
await self.client.post(item.inbox, item.message, item.instance) await self.client.post(item.inbox, item.message, item.instance)
except HttpError as e:
logging.error('HTTP Error when pushing to %s: %i %s', item.inbox, e.status, e.message)
except AsyncTimeoutError: except AsyncTimeoutError:
logging.error('Timeout when pushing to %s', item.domain) logging.error('Timeout when pushing to %s', item.domain)