From 0cea1ff9e95fc78b05dcc61fa5511abb8374feec Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Sat, 14 Sep 2024 22:44:14 -0400 Subject: [PATCH] ensure Date objects returned from db have a timezone --- relay/cache.py | 8 +++++--- relay/database/schema.py | 31 +++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/relay/cache.py b/relay/cache.py index 1bf20d7..0c76b8e 100644 --- a/relay/cache.py +++ b/relay/cache.py @@ -8,7 +8,7 @@ from blib import Date from bsql import Database, Row from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass -from datetime import timedelta +from datetime import timedelta, timezone from redis import Redis from typing import TYPE_CHECKING, Any, TypedDict @@ -72,6 +72,9 @@ class Item: def __post_init__(self) -> None: self.updated = Date.parse(self.updated) + if self.updated.tzinfo is None: + self.updated = self.updated.replace(tzinfo = timezone.utc) + @classmethod def from_data(cls: type[Item], *args: Any) -> Item: @@ -82,8 +85,7 @@ class Item: def older_than(self, hours: int) -> bool: - delta = Date.new_utc() - self.updated - return (delta.total_seconds()) > hours * 3600 + return self.updated + timedelta(hours = hours) < Date.new_utc() def to_dict(self) -> dict[str, Any]: diff --git a/relay/database/schema.py b/relay/database/schema.py index ca73c92..55ca608 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -4,6 +4,7 @@ from blib import Date from bsql import Column, Row, Tables from collections.abc import Callable from copy import deepcopy +from datetime import timezone from typing import TYPE_CHECKING, Any from .config import ConfigData @@ -18,12 +19,15 @@ TABLES = Tables() def deserialize_timestamp(value: Any) -> Date: try: - return Date.parse(value) + date = Date.parse(value) except ValueError: - pass + date = Date.fromisoformat(value) - return Date.fromisoformat(value) + if date.tzinfo is None: + date = date.replace(tzinfo = timezone.utc) + + return date @TABLES.add_row @@ -45,14 +49,16 @@ class Instance(Row): followid: Column[str] = Column('followid', 'text') software: Column[str] = Column('software', 'text') accepted: Column[Date] = Column('accepted', 'boolean') - created: Column[Date] = Column('created', 'timestamp', nullable = False) + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row class Whitelist(Row): domain: Column[str] = Column( 'domain', 'text', primary_key = True, unique = True, nullable = True) - created: Column[Date] = Column('created', 'timestamp', nullable = False) + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row @@ -64,7 +70,8 @@ class DomainBan(Row): 'domain', 'text', primary_key = True, unique = True, nullable = True) reason: Column[str] = Column('reason', 'text') note: Column[str] = Column('note', 'text') - created: Column[Date] = Column('created', 'timestamp', nullable = False) + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row @@ -75,7 +82,8 @@ class SoftwareBan(Row): name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) reason: Column[str] = Column('reason', 'text') note: Column[str] = Column('note', 'text') - created: Column[Date] = Column('created', 'timestamp', nullable = False) + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row @@ -87,7 +95,8 @@ class User(Row): 'username', 'text', primary_key = True, unique = True, nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False) handle: Column[str] = Column('handle', 'text') - created: Column[Date] = Column('created', 'timestamp', nullable = False) + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) @TABLES.add_row @@ -104,8 +113,10 @@ class App(Row): token: Column[str | None] = Column('token', 'text') auth_code: Column[str | None] = Column('auth_code', 'text') user: Column[str | None] = Column('user', 'text') - created: Column[Date] = Column('created', 'timestamp', nullable = False) - accessed: Column[Date] = Column('accessed', 'timestamp', nullable = False) + created: Column[Date] = Column( + 'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp) + accessed: Column[Date] = Column( + 'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp) def get_api_data(self, include_token: bool = False) -> dict[str, Any]: