diff --git a/relay/database/connection.py b/relay/database/connection.py index 67df706..4e575f6 100644 --- a/relay/database/connection.py +++ b/relay/database/connection.py @@ -192,26 +192,29 @@ class Connection(SqlConnection): def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row: if self.get_user(username): - data: dict[str, str | datetime | None] = { - 'username': username - } + data: dict[str, str] = {} if password: - data['password'] = password + data['hash'] = self.hasher.hash(password) if handle: - data['handler'] = handle + data['handle'] = handle - else: - if password is None: - raise ValueError('Password cannot be empty') + stmt = Update("users", data) + stmt.set_where("username", username) - data = { - 'username': username, - 'hash': self.hasher.hash(password), - 'handle': handle, - 'created': datetime.now(tz = timezone.utc) - } + with self.query(stmt) as cur: + return cur.one() + + if password is None: + raise ValueError('Password cannot be empty') + + data = { + 'username': username, + 'hash': self.hasher.hash(password), + 'handle': handle, + 'created': datetime.now(tz = timezone.utc) + } with self.run('put-user', data) as cur: return cur.one() # type: ignore