diff --git a/CHANGES.md b/CHANGES.md index c5e4f7f8c..9c2496428 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,20 @@ # Changelog +<<<<<<< HEAD +======= +## Dendrite 0.8.7 (2022-06-01) + +### Features + +* Support added for room version 10 + +### Fixes + +* A number of state handling bugs have been fixed, which previously resulted in missing state events, unexpected state deletions, reverted memberships and unexpectedly rejected/soft-failed events in some specific cases +* Fixed destination queue performance issues as a result of missing indexes, which speeds up outbound federation considerably +* A bug which could cause the `/register` endpoint to return HTTP 500 has been fixed + +>>>>>>> main ## Dendrite 0.8.6 (2022-05-26) ### Features diff --git a/README.md b/README.md index e8b7bd0e2..7c22b3692 100644 --- a/README.md +++ b/README.md @@ -96,10 +96,9 @@ than features that massive deployments may be interested in (User Directory, Ope This means Dendrite supports amongst others: - Core room functionality (creating rooms, invites, auth rules) -- Full support for room versions 1 to 7 -- Experimental support for room versions 8 to 9 +- Room versions 1 to 10 supported - Backfilling locally and via federation -- Accounts, Profiles and Devices +- Accounts, profiles and devices - Published room lists - Typing - Media APIs diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 749036c24..bfd17633a 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -89,6 +89,9 @@ func Setup( "r0.4.0", "r0.5.0", "r0.6.1", + "v1.0", + "v1.1", + "v1.2", }, UnstableFeatures: unstableFeatures}, } }), diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 3a9f3ce4f..6ed6ebdb8 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -4,9 +4,12 @@ import ( "context" "flag" "fmt" + "sort" "strconv" + "strings" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup" @@ -57,25 +60,23 @@ func main() { panic(err) } - blockNIDs, err := roomserverDB.StateBlockNIDs(ctx, snapshotNIDs) - if err != nil { - panic(err) - } + stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{ + RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), + }) - var stateEntries []types.StateEntryList - for _, list := range blockNIDs { - entries, err2 := roomserverDB.StateEntries(ctx, list.StateBlockNIDs) - if err2 != nil { - panic(err2) + var stateEntries []types.StateEntry + for _, snapshotNID := range snapshotNIDs { + var entries []types.StateEntry + entries, err = stateres.LoadStateAtSnapshot(ctx, snapshotNID) + if err != nil { + panic(err) } stateEntries = append(stateEntries, entries...) } var eventNIDs []types.EventNID for _, entry := range stateEntries { - for _, e := range entry.StateEntries { - eventNIDs = append(eventNIDs, e.EventNID) - } + eventNIDs = append(eventNIDs, entry.EventNID) } fmt.Println("Fetching", len(eventNIDs), "state events") @@ -110,7 +111,8 @@ func main() { } fmt.Println("Resolving state") - resolved, err := gomatrixserverlib.ResolveConflicts( + var resolved Events + resolved, err = gomatrixserverlib.ResolveConflicts( gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, @@ -120,6 +122,7 @@ func main() { } fmt.Println("Resolved state contains", len(resolved), "events") + sort.Sort(resolved) filteringEventType := *filterType count := 0 for _, event := range resolved { @@ -135,3 +138,25 @@ func main() { fmt.Println() fmt.Println("Returned", count, "state events after filtering") } + +type Events []*gomatrixserverlib.Event + +func (e Events) Len() int { + return len(e) +} + +func (e Events) Swap(i, j int) { + e[i], e[j] = e[j], e[i] +} + +func (e Events) Less(i, j int) bool { + typeDelta := strings.Compare(e[i].Type(), e[j].Type()) + if typeDelta < 0 { + return true + } + if typeDelta > 0 { + return false + } + stateKeyDelta := strings.Compare(*e[i].StateKey(), *e[j].StateKey()) + return stateKeyDelta < 0 +} diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index 6cac489bf..1fedf0ef1 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -36,6 +36,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx ON federationsender_queue_edus (json_nid, server_name); +CREATE INDEX IF NOT EXISTS federationsender_queue_edus_nid_idx + ON federationsender_queue_edus (json_nid); +CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx + ON federationsender_queue_edus (server_name); ` const insertQueueEDUSQL = "" + diff --git a/federationapi/storage/postgres/queue_json_table.go b/federationapi/storage/postgres/queue_json_table.go index 853073741..e33074182 100644 --- a/federationapi/storage/postgres/queue_json_table.go +++ b/federationapi/storage/postgres/queue_json_table.go @@ -33,6 +33,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_json ( -- The JSON body. Text so that we preserve UTF-8. json_body TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_json_json_nid_idx + ON federationsender_queue_json (json_nid); ` const insertJSONSQL = "" + diff --git a/federationapi/storage/postgres/queue_pdus_table.go b/federationapi/storage/postgres/queue_pdus_table.go index f9a477483..38ac5a6eb 100644 --- a/federationapi/storage/postgres/queue_pdus_table.go +++ b/federationapi/storage/postgres/queue_pdus_table.go @@ -36,6 +36,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_pdus ( CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx ON federationsender_queue_pdus (json_nid, server_name); +CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_json_nid_idx + ON federationsender_queue_pdus (json_nid); +CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_server_name_idx + ON federationsender_queue_pdus (server_name); ` const insertQueuePDUSQL = "" + diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index a6d609508..f4c84f094 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -37,6 +37,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx ON federationsender_queue_edus (json_nid, server_name); +CREATE INDEX IF NOT EXISTS federationsender_queue_edus_nid_idx + ON federationsender_queue_edus (json_nid); +CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx + ON federationsender_queue_edus (server_name); ` const insertQueueEDUSQL = "" + diff --git a/federationapi/storage/sqlite3/queue_json_table.go b/federationapi/storage/sqlite3/queue_json_table.go index 3e3f60f63..fe5896c7f 100644 --- a/federationapi/storage/sqlite3/queue_json_table.go +++ b/federationapi/storage/sqlite3/queue_json_table.go @@ -35,6 +35,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_json ( -- The JSON body. Text so that we preserve UTF-8. json_body TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_json_json_nid_idx + ON federationsender_queue_json (json_nid); ` const insertJSONSQL = "" + diff --git a/federationapi/storage/sqlite3/queue_pdus_table.go b/federationapi/storage/sqlite3/queue_pdus_table.go index e0fdbda5f..e818585a5 100644 --- a/federationapi/storage/sqlite3/queue_pdus_table.go +++ b/federationapi/storage/sqlite3/queue_pdus_table.go @@ -38,6 +38,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_pdus ( CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx ON federationsender_queue_pdus (json_nid, server_name); +CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_json_nid_idx + ON federationsender_queue_pdus (json_nid); +CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_server_name_idx + ON federationsender_queue_pdus (server_name); ` const insertQueuePDUSQL = "" + diff --git a/go.mod b/go.mod index a732f679a..bc77fcf53 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220526140030-dcfbb70ff32d + github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.13 diff --git a/go.sum b/go.sum index 5453b0c5e..dab1f0090 100644 --- a/go.sum +++ b/go.sum @@ -542,8 +542,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220526140030-dcfbb70ff32d h1:IwyG/58rFn0/ugD0A/IdSIo7D/oLJ4+k3NznlYhzyHs= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220526140030-dcfbb70ff32d/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee h1:56sxEWrwB3eOmwjP2S4JsrQf29uBUaf+8WrbQJmjaGE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/internal/version.go b/internal/version.go index 0957b4545..2543ec90c 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 8 - VersionPatch = 6 + VersionPatch = 7 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 23f3e1a67..acbcd5b8f 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -374,7 +374,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam // fetch stale device lists userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) if err != nil { - logger.WithError(err).Error("failed to load stale device lists") + logger.WithError(err).Error("Failed to load stale device lists") return waitTime, true } failCount := 0 @@ -399,7 +399,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam } } else { waitTime = time.Hour - logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type") + logger.WithError(err).WithField("user_id", userID).Debug("GetUserDevices returned unknown error type") } continue } @@ -422,12 +422,12 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam } err = u.updateDeviceList(&res) if err != nil { - logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it") + logger.WithError(err).WithField("user_id", userID).Error("Fetched device list but failed to store/emit it") failCount += 1 } } if failCount > 0 { - logger.WithField("total", len(userIDs)).WithField("failed", failCount).WithField("wait", waitTime).Error("failed to query device keys for some users") + logger.WithField("total", len(userIDs)).WithField("failed", failCount).WithField("wait", waitTime).Warn("Failed to query device keys for some users") } for _, userID := range userIDs { // always clear the channel to unblock Update calls regardless of success/failure diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index a1b094871..59b3fcb12 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -206,7 +206,7 @@ func (u *latestEventsUpdater) latestState() error { // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the // hard work. - if u.event.StateKey() == nil { + if !u.stateAtEvent.IsStateEvent() { stateChanged := false oldStateNIDs := make([]types.StateSnapshotNID, 0, len(u.oldLatest)) newStateNIDs := make([]types.StateSnapshotNID, 0, len(u.latest)) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 187b996cd..95abdcb36 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -39,6 +39,7 @@ type StateResolutionStorage interface { StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) } type StateResolution struct { @@ -659,15 +660,13 @@ func (v *StateResolution) calculateStateAfterManyEvents( } // Collect all the entries with the same type and key together. - // We don't care about the order here because the conflict resolution - // algorithm doesn't depend on the order of the prev events. - // Remove duplicate entires. + // This is done so findDuplicateStateKeys can work in groups. + // We remove duplicates (same type, state key and event NID) too. combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] // Find the conflicts - conflicts := findDuplicateStateKeys(combined) - - if len(conflicts) > 0 { + if conflicts := findDuplicateStateKeys(combined); len(conflicts) > 0 { + conflictMap := stateEntryMap(conflicts) conflictLength = len(conflicts) // 5) There are conflicting state events, for each conflict workout @@ -676,7 +675,7 @@ func (v *StateResolution) calculateStateAfterManyEvents( // Work out which entries aren't conflicted. var notConflicted []types.StateEntry for _, entry := range combined { - if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { + if _, ok := conflictMap.lookup(entry.StateKeyTuple); !ok { notConflicted = append(notConflicted, entry) } } @@ -689,7 +688,7 @@ func (v *StateResolution) calculateStateAfterManyEvents( return } algorithm = "full_state_with_conflicts" - state = resolved[:util.SortAndUnique(stateEntrySorter(resolved))] + state = resolved } else { algorithm = "full_state_no_conflicts" // 6) There weren't any conflicts @@ -818,39 +817,19 @@ func (v *StateResolution) resolveConflictsV2( authDifference := make([]*gomatrixserverlib.Event, 0, estimate) // For each conflicted event, let's try and get the needed auth events. - neededStateKeys := make([]string, 16) - authEntries := make([]types.StateEntry, 16) for _, conflictedEvent := range conflictedEvents { // Work out which auth events we need to load. key := conflictedEvent.EventID() - needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{conflictedEvent}) - - // Find the numeric IDs for the necessary state keys. - neededStateKeys = neededStateKeys[:0] - neededStateKeys = append(neededStateKeys, needed.Member...) - neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) - stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) - if err != nil { - return nil, err - } - - // Load the necessary auth events. - tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) - authEntries = authEntries[:0] - for _, tuple := range tuplesNeeded { - if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { - authEntries = append(authEntries, types.StateEntry{ - StateKeyTuple: tuple, - EventNID: eventNID, - }) - } - } // Store the newly found auth events in the auth set for this event. - authSets[key], _, err = v.loadStateEvents(ctx, authEntries) + var authEventMap map[string]types.StateEntry + authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent) if err != nil { return nil, err } + for k, v := range authEventMap { + eventIDMap[k] = v + } // Only add auth events into the authEvents slice once, otherwise the // check for the auth difference can become expensive and produce @@ -909,7 +888,7 @@ func (v *StateResolution) resolveConflictsV2( for _, resolvedEvent := range resolvedEvents { entry, ok := eventIDMap[resolvedEvent.EventID()] if !ok { - panic(fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID())) + return nil, fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID()) } notConflicted = append(notConflicted, entry) } @@ -996,6 +975,84 @@ func (v *StateResolution) loadStateEvents( return result, eventIDMap, nil } +// loadAuthEvents loads all of the auth events for a given event recursively, +// along with a map that contains state entries for all of the auth events. +func (v *StateResolution) loadAuthEvents( + ctx context.Context, event *gomatrixserverlib.Event, +) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { + eventMap := map[string]struct{}{} + var lookup []string + var authEvents []types.Event + queue := event.AuthEventIDs() + for i := 0; i < len(queue); i++ { + lookup = lookup[:0] + for _, authEventID := range queue { + if _, ok := eventMap[authEventID]; ok { + continue + } + lookup = append(lookup, authEventID) + } + if len(lookup) == 0 { + break + } + events, err := v.db.EventsFromIDs(ctx, lookup) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) + } + add := map[string]struct{}{} + for _, event := range events { + eventMap[event.EventID()] = struct{}{} + authEvents = append(authEvents, event) + for _, authEventID := range event.AuthEventIDs() { + if _, ok := eventMap[authEventID]; ok { + continue + } + add[authEventID] = struct{}{} + } + for authEventID := range add { + queue = append(queue, authEventID) + } + } + } + authEventTypes := map[string]struct{}{} + authEventStateKeys := map[string]struct{}{} + for _, authEvent := range authEvents { + authEventTypes[authEvent.Type()] = struct{}{} + authEventStateKeys[*authEvent.StateKey()] = struct{}{} + } + lookupAuthEventTypes := make([]string, 0, len(authEventTypes)) + lookupAuthEventStateKeys := make([]string, 0, len(authEventStateKeys)) + for eventType := range authEventTypes { + lookupAuthEventTypes = append(lookupAuthEventTypes, eventType) + } + for eventStateKey := range authEventStateKeys { + lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey) + } + eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err) + } + eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err) + } + stateEntryMap := map[string]types.StateEntry{} + for _, authEvent := range authEvents { + stateEntryMap[authEvent.EventID()] = types.StateEntry{ + EventNID: authEvent.EventNID, + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypes[authEvent.Type()], + EventStateKeyNID: eventStateKeys[*authEvent.StateKey()], + }, + } + } + nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents)) + for _, authEvent := range authEvents { + nakedEvents = append(nakedEvents, authEvent.Event) + } + return nakedEvents, stateEntryMap, nil +} + // findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. // Returns a sorted list of those state entries. func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index d4a2ee3b9..8f4e011bf 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -192,6 +192,10 @@ func (u *RoomUpdater) StateAtEventIDs( return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) +} + func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) } diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index f86812f17..e3cab56ee 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'" type accountsStatements struct { insertAccountStmt *sql.Stmt diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 5bee880d3..a26097338 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -124,6 +124,23 @@ func Test_Accounts(t *testing.T) { _, err = db.GetAccountByLocalpart(ctx, "unusename") assert.Error(t, err, "expected an error for non existent localpart") + + // create an empty localpart; this should never happen, but is required to test getting a numeric localpart + // if there's already a user without a localpart in the database + _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser) + assert.NoError(t, err) + + // test getting a numeric localpart, with an existing user without a localpart + _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) + assert.NoError(t, err) + + // Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type + _, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser) + assert.NoError(t, err) + + // Now try to create a new guest user + _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) + assert.NoError(t, err) }) }