From c54aeabc90e9970bab7adde6aefc11d6b2f08de8 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Sat, 14 Sep 2024 05:56:27 -0400 Subject: [PATCH] fix postgres support --- relay/data/statements.sql | 8 ++++++-- relay/database/connection.py | 16 ++++++++-------- relay/database/schema.py | 35 +++++++---------------------------- 3 files changed, 21 insertions(+), 38 deletions(-) diff --git a/relay/data/statements.sql b/relay/data/statements.sql index dde6a29..894bb40 100644 --- a/relay/data/statements.sql +++ b/relay/data/statements.sql @@ -40,7 +40,7 @@ WHERE domain = :value or inbox = :value or actor = :value; -- name: get-request -SELECT * FROM inboxes WHERE accepted = 0 and domain = :domain; +SELECT * FROM inboxes WHERE accepted = false and domain = :domain; -- name: get-user @@ -64,7 +64,7 @@ RETURNING *; -- name: del-user DELETE FROM users -WHERE username = :value or handle = :value; +WHERE username = :username or handle = :username; -- name: get-app @@ -91,6 +91,10 @@ DELETE FROM apps WHERE client_id = :id and client_secret = :secret and token = :token; +-- name: del-token-user +DELETE FROM apps WHERE "user" = :username; + + -- name: get-software-ban SELECT * FROM software_bans WHERE name = :name; diff --git a/relay/database/connection.py b/relay/database/connection.py index 1053294..c061d3b 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -138,7 +138,7 @@ class Connection(SqlConnection): def get_inboxes(self) -> Iterator[schema.Instance]: - return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance) + return self.execute("SELECT * FROM inboxes WHERE accepted = true").all(schema.Instance) # todo: check if software is different than stored row @@ -196,7 +196,7 @@ class Connection(SqlConnection): def get_requests(self) -> Iterator[schema.Instance]: - return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance) + return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance) def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: @@ -275,10 +275,10 @@ class Connection(SqlConnection): if (user := self.get_user(username)) is None: raise KeyError(username) - with self.run('del-user', {'value': user.username}): + with self.run('del-token-user', {'username': user.username}): pass - with self.run('del-token-user', {'username': user.username}): + with self.run('del-user', {'username': user.username}): pass @@ -315,8 +315,8 @@ class Connection(SqlConnection): 'website': website, 'client_id': secrets.token_hex(20), 'client_secret': secrets.token_hex(20), - 'created': Date.new_utc().timestamp(), - 'accessed': Date.new_utc().timestamp() + 'created': Date.new_utc(), + 'accessed': Date.new_utc() } with self.insert('apps', params) as cur: @@ -336,8 +336,8 @@ class Connection(SqlConnection): 'client_secret': secrets.token_hex(20), 'auth_code': None, 'token': secrets.token_hex(20), - 'created': Date.new_utc().timestamp(), - 'accessed': Date.new_utc().timestamp() + 'created': Date.new_utc(), + 'accessed': Date.new_utc() } with self.insert('apps', params) as cur: diff --git a/relay/database/schema.py b/relay/database/schema.py index a6016bb..7565346 100644 --- a/relay/database/schema.py +++ b/relay/database/schema.py @@ -45,20 +45,14 @@ class Instance(Row): followid: Column[str] = Column('followid', 'text') software: Column[str] = Column('software', 'text') accepted: Column[Date] = Column('accepted', 'boolean') - created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) + created: Column[Date] = Column('created', 'timestamp', nullable = False) @TABLES.add_row class Whitelist(Row): domain: Column[str] = Column( 'domain', 'text', primary_key = True, unique = True, nullable = True) - created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) + created: Column[Date] = Column('created', 'timestamp', nullable = False) @TABLES.add_row @@ -70,10 +64,7 @@ class DomainBan(Row): 'domain', 'text', primary_key = True, unique = True, nullable = True) reason: Column[str] = Column('reason', 'text') note: Column[str] = Column('note', 'text') - created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) + created: Column[Date] = Column('created', 'timestamp', nullable = False) @TABLES.add_row @@ -84,10 +75,7 @@ class SoftwareBan(Row): name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) reason: Column[str] = Column('reason', 'text') note: Column[str] = Column('note', 'text') - created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) + created: Column[Date] = Column('created', 'timestamp', nullable = False) @TABLES.add_row @@ -99,10 +87,7 @@ class User(Row): 'username', 'text', primary_key = True, unique = True, nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False) handle: Column[str] = Column('handle', 'text') - created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) + created: Column[Date] = Column('created', 'timestamp', nullable = False) @TABLES.add_row @@ -119,14 +104,8 @@ class App(Row): token: Column[str | None] = Column('token', 'text') auth_code: Column[str | None] = Column('auth_code', 'text') user: Column[str | None] = Column('user', 'text') - created: Column[Date] = Column( - 'created', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) - accessed: Column[Date] = Column( - 'accessed', 'timestamp', nullable = False, - deserializer = deserialize_timestamp, serializer = Date.timestamp - ) + created: Column[Date] = Column('created', 'timestamp', nullable = False) + accessed: Column[Date] = Column('accessed', 'timestamp', nullable = False) def get_api_data(self, include_token: bool = False) -> dict[str, Any]: