fix postgres support

This commit is contained in:
Izalia Mae 2024-09-14 05:56:27 -04:00
parent 16fcea90f2
commit c54aeabc90
3 changed files with 21 additions and 38 deletions

View file

@ -40,7 +40,7 @@ WHERE domain = :value or inbox = :value or actor = :value;
-- name: get-request -- name: get-request
SELECT * FROM inboxes WHERE accepted = 0 and domain = :domain; SELECT * FROM inboxes WHERE accepted = false and domain = :domain;
-- name: get-user -- name: get-user
@ -64,7 +64,7 @@ RETURNING *;
-- name: del-user -- name: del-user
DELETE FROM users DELETE FROM users
WHERE username = :value or handle = :value; WHERE username = :username or handle = :username;
-- name: get-app -- name: get-app
@ -91,6 +91,10 @@ DELETE FROM apps
WHERE client_id = :id and client_secret = :secret and token = :token; WHERE client_id = :id and client_secret = :secret and token = :token;
-- name: del-token-user
DELETE FROM apps WHERE "user" = :username;
-- name: get-software-ban -- name: get-software-ban
SELECT * FROM software_bans WHERE name = :name; SELECT * FROM software_bans WHERE name = :name;

View file

@ -138,7 +138,7 @@ class Connection(SqlConnection):
def get_inboxes(self) -> Iterator[schema.Instance]: def get_inboxes(self) -> Iterator[schema.Instance]:
return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance) return self.execute("SELECT * FROM inboxes WHERE accepted = true").all(schema.Instance)
# todo: check if software is different than stored row # todo: check if software is different than stored row
@ -196,7 +196,7 @@ class Connection(SqlConnection):
def get_requests(self) -> Iterator[schema.Instance]: def get_requests(self) -> Iterator[schema.Instance]:
return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance) return self.execute('SELECT * FROM inboxes WHERE accepted = false').all(schema.Instance)
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance: def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
@ -275,10 +275,10 @@ class Connection(SqlConnection):
if (user := self.get_user(username)) is None: if (user := self.get_user(username)) is None:
raise KeyError(username) raise KeyError(username)
with self.run('del-user', {'value': user.username}): with self.run('del-token-user', {'username': user.username}):
pass pass
with self.run('del-token-user', {'username': user.username}): with self.run('del-user', {'username': user.username}):
pass pass
@ -315,8 +315,8 @@ class Connection(SqlConnection):
'website': website, 'website': website,
'client_id': secrets.token_hex(20), 'client_id': secrets.token_hex(20),
'client_secret': secrets.token_hex(20), 'client_secret': secrets.token_hex(20),
'created': Date.new_utc().timestamp(), 'created': Date.new_utc(),
'accessed': Date.new_utc().timestamp() 'accessed': Date.new_utc()
} }
with self.insert('apps', params) as cur: with self.insert('apps', params) as cur:
@ -336,8 +336,8 @@ class Connection(SqlConnection):
'client_secret': secrets.token_hex(20), 'client_secret': secrets.token_hex(20),
'auth_code': None, 'auth_code': None,
'token': secrets.token_hex(20), 'token': secrets.token_hex(20),
'created': Date.new_utc().timestamp(), 'created': Date.new_utc(),
'accessed': Date.new_utc().timestamp() 'accessed': Date.new_utc()
} }
with self.insert('apps', params) as cur: with self.insert('apps', params) as cur:

View file

@ -45,20 +45,14 @@ class Instance(Row):
followid: Column[str] = Column('followid', 'text') followid: Column[str] = Column('followid', 'text')
software: Column[str] = Column('software', 'text') software: Column[str] = Column('software', 'text')
accepted: Column[Date] = Column('accepted', 'boolean') accepted: Column[Date] = Column('accepted', 'boolean')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
class Whitelist(Row): class Whitelist(Row):
domain: Column[str] = Column( domain: Column[str] = Column(
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
@ -70,10 +64,7 @@ class DomainBan(Row):
'domain', 'text', primary_key = True, unique = True, nullable = True) 'domain', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
@ -84,10 +75,7 @@ class SoftwareBan(Row):
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True) name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
reason: Column[str] = Column('reason', 'text') reason: Column[str] = Column('reason', 'text')
note: Column[str] = Column('note', 'text') note: Column[str] = Column('note', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
@ -99,10 +87,7 @@ class User(Row):
'username', 'text', primary_key = True, unique = True, nullable = False) 'username', 'text', primary_key = True, unique = True, nullable = False)
hash: Column[str] = Column('hash', 'text', nullable = False) hash: Column[str] = Column('hash', 'text', nullable = False)
handle: Column[str] = Column('handle', 'text') handle: Column[str] = Column('handle', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
@TABLES.add_row @TABLES.add_row
@ -119,14 +104,8 @@ class App(Row):
token: Column[str | None] = Column('token', 'text') token: Column[str | None] = Column('token', 'text')
auth_code: Column[str | None] = Column('auth_code', 'text') auth_code: Column[str | None] = Column('auth_code', 'text')
user: Column[str | None] = Column('user', 'text') user: Column[str | None] = Column('user', 'text')
created: Column[Date] = Column( created: Column[Date] = Column('created', 'timestamp', nullable = False)
'created', 'timestamp', nullable = False, accessed: Column[Date] = Column('accessed', 'timestamp', nullable = False)
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
accessed: Column[Date] = Column(
'accessed', 'timestamp', nullable = False,
deserializer = deserialize_timestamp, serializer = Date.timestamp
)
def get_api_data(self, include_token: bool = False) -> dict[str, Any]: def get_api_data(self, include_token: bool = False) -> dict[str, Any]: