diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go index 0abe7b291..51a462b32 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -56,7 +56,7 @@ const selectKeyBackupSQL = "" + "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + - "SELECT COALESCE(MAX(version),0) FROM account_e2e_room_keys_versions WHERE user_id = $1" + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" type keyBackupVersionStatements struct { insertKeyBackupStmt *sql.Stmt @@ -146,12 +146,19 @@ func (s *keyBackupVersionStatements) selectKeyBackup( ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 if version == "" { - err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) + var v *int64 // allows nulls + if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + return + } + if v == nil { + err = sql.ErrNoRows + return + } + versionInt = *v } else { - versionInt, err = strconv.ParseInt(version, 10, 64) - } - if err != nil { - return + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } } versionResult = strconv.FormatInt(versionInt, 10) var deletedInt int diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go index b0e074839..a9e7bf5db 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -54,7 +54,7 @@ const selectKeyBackupSQL = "" + "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" const selectLatestVersionSQL = "" + - "SELECT IFNULL(MAX(version),0) FROM account_e2e_room_keys_versions WHERE user_id = $1" + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" type keyBackupVersionStatements struct { insertKeyBackupStmt *sql.Stmt @@ -144,12 +144,19 @@ func (s *keyBackupVersionStatements) selectKeyBackup( ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 if version == "" { - err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) + var v *int64 // allows nulls + if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + return + } + if v == nil { + err = sql.ErrNoRows + return + } + versionInt = *v } else { - versionInt, err = strconv.ParseInt(version, 10, 64) - } - if err != nil { - return + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } } versionResult = strconv.FormatInt(versionInt, 10) var deletedInt int