ensure Date objects returned from db have a timezone

This commit is contained in:
Izalia Mae 2024-09-14 22:44:14 -04:00
parent 0e89b9bb11
commit 0cea1ff9e9
2 changed files with 26 additions and 13 deletions

View file

@ -8,7 +8,7 @@ from blib import Date
from bsql import Database, Row
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
from datetime import timedelta
from datetime import timedelta, timezone
from redis import Redis
from typing import TYPE_CHECKING, Any, TypedDict
@ -72,6 +72,9 @@ class Item:
def __post_init__(self) -> None:
self.updated = Date.parse(self.updated)
if self.updated.tzinfo is None:
self.updated = self.updated.replace(tzinfo = timezone.utc)
@classmethod
def from_data(cls: type[Item], *args: Any) -> Item:
@ -82,8 +85,7 @@ class Item:
def older_than(self, hours: int) -> bool:
delta = Date.new_utc() - self.updated
return (delta.total_seconds()) > hours * 3600
return self.updated + timedelta(hours = hours) < Date.new_utc()
def to_dict(self) -> dict[str, Any]:

View file

@ -4,6 +4,7 @@ from blib import Date
from bsql import Column, Row, Tables
from collections.abc import Callable
from copy import deepcopy
from datetime import timezone
from typing import TYPE_CHECKING, Any
from .config import ConfigData
@ -18,12 +19,15 @@ TABLES = Tables()
def deserialize_timestamp(value: Any) -> Date:
try:
return Date.parse(value)
date = Date.parse(value)
except ValueError:
pass
date = Date.fromisoformat(value)
return Date.fromisoformat(value)
if date.tzinfo is None:
date = date.replace(tzinfo = timezone.utc)
return date
@TABLES.add_row
@ -45,14 +49,16 @@ class Instance(Row):
followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[Date] = Column('created', 'timestamp', nullable = False)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
class Whitelist(Row):
domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column('created', 'timestamp', nullable = False)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
@ -64,7 +70,8 @@ class DomainBan(Row):
'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[Date] = Column('created', 'timestamp', nullable = False)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
@ -75,7 +82,8 @@ class SoftwareBan(Row):
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text')
created: Column[Date] = Column('created', 'timestamp', nullable = False)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
@ -87,7 +95,8 @@ class User(Row):
'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column('created', 'timestamp', nullable = False)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
@TABLES.add_row
@ -104,8 +113,10 @@ class App(Row):
token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column('created', 'timestamp', nullable = False)
accessed: Column[Date] = Column('accessed', 'timestamp', nullable = False)
created: Column[Date] = Column(
'created', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False, deserializer = deserialize_timestamp)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]: