merge and resolve conflicts with main

This commit is contained in:
Tak Wai Wong 2022-06-01 09:34:29 -07:00
parent 63c169958d
commit ffb7581ef2
19 changed files with 202 additions and 60 deletions

View file

@ -1,5 +1,20 @@
# Changelog # Changelog
<<<<<<< HEAD
=======
## Dendrite 0.8.7 (2022-06-01)
### Features
* Support added for room version 10
### Fixes
* A number of state handling bugs have been fixed, which previously resulted in missing state events, unexpected state deletions, reverted memberships and unexpectedly rejected/soft-failed events in some specific cases
* Fixed destination queue performance issues as a result of missing indexes, which speeds up outbound federation considerably
* A bug which could cause the `/register` endpoint to return HTTP 500 has been fixed
>>>>>>> main
## Dendrite 0.8.6 (2022-05-26) ## Dendrite 0.8.6 (2022-05-26)
### Features ### Features

View file

@ -96,10 +96,9 @@ than features that massive deployments may be interested in (User Directory, Ope
This means Dendrite supports amongst others: This means Dendrite supports amongst others:
- Core room functionality (creating rooms, invites, auth rules) - Core room functionality (creating rooms, invites, auth rules)
- Full support for room versions 1 to 7 - Room versions 1 to 10 supported
- Experimental support for room versions 8 to 9
- Backfilling locally and via federation - Backfilling locally and via federation
- Accounts, Profiles and Devices - Accounts, profiles and devices
- Published room lists - Published room lists
- Typing - Typing
- Media APIs - Media APIs

View file

@ -89,6 +89,9 @@ func Setup(
"r0.4.0", "r0.4.0",
"r0.5.0", "r0.5.0",
"r0.6.1", "r0.6.1",
"v1.0",
"v1.1",
"v1.2",
}, UnstableFeatures: unstableFeatures}, }, UnstableFeatures: unstableFeatures},
} }
}), }),

View file

@ -4,9 +4,12 @@ import (
"context" "context"
"flag" "flag"
"fmt" "fmt"
"sort"
"strconv" "strconv"
"strings"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
@ -57,25 +60,23 @@ func main() {
panic(err) panic(err)
} }
blockNIDs, err := roomserverDB.StateBlockNIDs(ctx, snapshotNIDs) stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{
if err != nil { RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
panic(err) })
}
var stateEntries []types.StateEntryList var stateEntries []types.StateEntry
for _, list := range blockNIDs { for _, snapshotNID := range snapshotNIDs {
entries, err2 := roomserverDB.StateEntries(ctx, list.StateBlockNIDs) var entries []types.StateEntry
if err2 != nil { entries, err = stateres.LoadStateAtSnapshot(ctx, snapshotNID)
panic(err2) if err != nil {
panic(err)
} }
stateEntries = append(stateEntries, entries...) stateEntries = append(stateEntries, entries...)
} }
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
for _, entry := range stateEntries { for _, entry := range stateEntries {
for _, e := range entry.StateEntries { eventNIDs = append(eventNIDs, entry.EventNID)
eventNIDs = append(eventNIDs, e.EventNID)
}
} }
fmt.Println("Fetching", len(eventNIDs), "state events") fmt.Println("Fetching", len(eventNIDs), "state events")
@ -110,7 +111,8 @@ func main() {
} }
fmt.Println("Resolving state") fmt.Println("Resolving state")
resolved, err := gomatrixserverlib.ResolveConflicts( var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), gomatrixserverlib.RoomVersion(*roomVersion),
events, events,
authEvents, authEvents,
@ -120,6 +122,7 @@ func main() {
} }
fmt.Println("Resolved state contains", len(resolved), "events") fmt.Println("Resolved state contains", len(resolved), "events")
sort.Sort(resolved)
filteringEventType := *filterType filteringEventType := *filterType
count := 0 count := 0
for _, event := range resolved { for _, event := range resolved {
@ -135,3 +138,25 @@ func main() {
fmt.Println() fmt.Println()
fmt.Println("Returned", count, "state events after filtering") fmt.Println("Returned", count, "state events after filtering")
} }
type Events []*gomatrixserverlib.Event
func (e Events) Len() int {
return len(e)
}
func (e Events) Swap(i, j int) {
e[i], e[j] = e[j], e[i]
}
func (e Events) Less(i, j int) bool {
typeDelta := strings.Compare(e[i].Type(), e[j].Type())
if typeDelta < 0 {
return true
}
if typeDelta > 0 {
return false
}
stateKeyDelta := strings.Compare(*e[i].StateKey(), *e[j].StateKey())
return stateKeyDelta < 0
}

View file

@ -36,6 +36,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name); ON federationsender_queue_edus (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_edus_nid_idx
ON federationsender_queue_edus (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx
ON federationsender_queue_edus (server_name);
` `
const insertQueueEDUSQL = "" + const insertQueueEDUSQL = "" +

View file

@ -33,6 +33,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_json (
-- The JSON body. Text so that we preserve UTF-8. -- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL json_body TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_json_json_nid_idx
ON federationsender_queue_json (json_nid);
` `
const insertJSONSQL = "" + const insertJSONSQL = "" +

View file

@ -36,6 +36,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid, server_name); ON federationsender_queue_pdus (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_server_name_idx
ON federationsender_queue_pdus (server_name);
` `
const insertQueuePDUSQL = "" + const insertQueuePDUSQL = "" +

View file

@ -37,6 +37,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name); ON federationsender_queue_edus (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_edus_nid_idx
ON federationsender_queue_edus (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx
ON federationsender_queue_edus (server_name);
` `
const insertQueueEDUSQL = "" + const insertQueueEDUSQL = "" +

View file

@ -35,6 +35,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_json (
-- The JSON body. Text so that we preserve UTF-8. -- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL json_body TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_json_json_nid_idx
ON federationsender_queue_json (json_nid);
` `
const insertJSONSQL = "" + const insertJSONSQL = "" +

View file

@ -38,6 +38,10 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid, server_name); ON federationsender_queue_pdus (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_pdus_server_name_idx
ON federationsender_queue_pdus (server_name);
` `
const insertQueuePDUSQL = "" + const insertQueuePDUSQL = "" +

2
go.mod
View file

@ -35,7 +35,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220526140030-dcfbb70ff32d github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.13 github.com/mattn/go-sqlite3 v1.14.13

4
go.sum
View file

@ -542,8 +542,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220526140030-dcfbb70ff32d h1:IwyG/58rFn0/ugD0A/IdSIo7D/oLJ4+k3NznlYhzyHs= github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee h1:56sxEWrwB3eOmwjP2S4JsrQf29uBUaf+8WrbQJmjaGE=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220526140030-dcfbb70ff32d/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE=
github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 8 VersionMinor = 8
VersionPatch = 6 VersionPatch = 7
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -374,7 +374,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
// fetch stale device lists // fetch stale device lists
userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
if err != nil { if err != nil {
logger.WithError(err).Error("failed to load stale device lists") logger.WithError(err).Error("Failed to load stale device lists")
return waitTime, true return waitTime, true
} }
failCount := 0 failCount := 0
@ -399,7 +399,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
} }
} else { } else {
waitTime = time.Hour waitTime = time.Hour
logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type") logger.WithError(err).WithField("user_id", userID).Debug("GetUserDevices returned unknown error type")
} }
continue continue
} }
@ -422,12 +422,12 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
} }
err = u.updateDeviceList(&res) err = u.updateDeviceList(&res)
if err != nil { if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it") logger.WithError(err).WithField("user_id", userID).Error("Fetched device list but failed to store/emit it")
failCount += 1 failCount += 1
} }
} }
if failCount > 0 { if failCount > 0 {
logger.WithField("total", len(userIDs)).WithField("failed", failCount).WithField("wait", waitTime).Error("failed to query device keys for some users") logger.WithField("total", len(userIDs)).WithField("failed", failCount).WithField("wait", waitTime).Warn("Failed to query device keys for some users")
} }
for _, userID := range userIDs { for _, userID := range userIDs {
// always clear the channel to unblock Update calls regardless of success/failure // always clear the channel to unblock Update calls regardless of success/failure

View file

@ -206,7 +206,7 @@ func (u *latestEventsUpdater) latestState() error {
// Work out if the state at the extremities has actually changed // Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the // or not. If they haven't then we won't bother doing all of the
// hard work. // hard work.
if u.event.StateKey() == nil { if !u.stateAtEvent.IsStateEvent() {
stateChanged := false stateChanged := false
oldStateNIDs := make([]types.StateSnapshotNID, 0, len(u.oldLatest)) oldStateNIDs := make([]types.StateSnapshotNID, 0, len(u.oldLatest))
newStateNIDs := make([]types.StateSnapshotNID, 0, len(u.latest)) newStateNIDs := make([]types.StateSnapshotNID, 0, len(u.latest))

View file

@ -39,6 +39,7 @@ type StateResolutionStorage interface {
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
} }
type StateResolution struct { type StateResolution struct {
@ -659,15 +660,13 @@ func (v *StateResolution) calculateStateAfterManyEvents(
} }
// Collect all the entries with the same type and key together. // Collect all the entries with the same type and key together.
// We don't care about the order here because the conflict resolution // This is done so findDuplicateStateKeys can work in groups.
// algorithm doesn't depend on the order of the prev events. // We remove duplicates (same type, state key and event NID) too.
// Remove duplicate entires.
combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] combined = combined[:util.SortAndUnique(stateEntrySorter(combined))]
// Find the conflicts // Find the conflicts
conflicts := findDuplicateStateKeys(combined) if conflicts := findDuplicateStateKeys(combined); len(conflicts) > 0 {
conflictMap := stateEntryMap(conflicts)
if len(conflicts) > 0 {
conflictLength = len(conflicts) conflictLength = len(conflicts)
// 5) There are conflicting state events, for each conflict workout // 5) There are conflicting state events, for each conflict workout
@ -676,7 +675,7 @@ func (v *StateResolution) calculateStateAfterManyEvents(
// Work out which entries aren't conflicted. // Work out which entries aren't conflicted.
var notConflicted []types.StateEntry var notConflicted []types.StateEntry
for _, entry := range combined { for _, entry := range combined {
if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { if _, ok := conflictMap.lookup(entry.StateKeyTuple); !ok {
notConflicted = append(notConflicted, entry) notConflicted = append(notConflicted, entry)
} }
} }
@ -689,7 +688,7 @@ func (v *StateResolution) calculateStateAfterManyEvents(
return return
} }
algorithm = "full_state_with_conflicts" algorithm = "full_state_with_conflicts"
state = resolved[:util.SortAndUnique(stateEntrySorter(resolved))] state = resolved
} else { } else {
algorithm = "full_state_no_conflicts" algorithm = "full_state_no_conflicts"
// 6) There weren't any conflicts // 6) There weren't any conflicts
@ -818,39 +817,19 @@ func (v *StateResolution) resolveConflictsV2(
authDifference := make([]*gomatrixserverlib.Event, 0, estimate) authDifference := make([]*gomatrixserverlib.Event, 0, estimate)
// For each conflicted event, let's try and get the needed auth events. // For each conflicted event, let's try and get the needed auth events.
neededStateKeys := make([]string, 16)
authEntries := make([]types.StateEntry, 16)
for _, conflictedEvent := range conflictedEvents { for _, conflictedEvent := range conflictedEvents {
// Work out which auth events we need to load. // Work out which auth events we need to load.
key := conflictedEvent.EventID() key := conflictedEvent.EventID()
needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{conflictedEvent})
// Find the numeric IDs for the necessary state keys.
neededStateKeys = neededStateKeys[:0]
neededStateKeys = append(neededStateKeys, needed.Member...)
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys)
if err != nil {
return nil, err
}
// Load the necessary auth events.
tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed)
authEntries = authEntries[:0]
for _, tuple := range tuplesNeeded {
if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
authEntries = append(authEntries, types.StateEntry{
StateKeyTuple: tuple,
EventNID: eventNID,
})
}
}
// Store the newly found auth events in the auth set for this event. // Store the newly found auth events in the auth set for this event.
authSets[key], _, err = v.loadStateEvents(ctx, authEntries) var authEventMap map[string]types.StateEntry
authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for k, v := range authEventMap {
eventIDMap[k] = v
}
// Only add auth events into the authEvents slice once, otherwise the // Only add auth events into the authEvents slice once, otherwise the
// check for the auth difference can become expensive and produce // check for the auth difference can become expensive and produce
@ -909,7 +888,7 @@ func (v *StateResolution) resolveConflictsV2(
for _, resolvedEvent := range resolvedEvents { for _, resolvedEvent := range resolvedEvents {
entry, ok := eventIDMap[resolvedEvent.EventID()] entry, ok := eventIDMap[resolvedEvent.EventID()]
if !ok { if !ok {
panic(fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID())) return nil, fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID())
} }
notConflicted = append(notConflicted, entry) notConflicted = append(notConflicted, entry)
} }
@ -996,6 +975,84 @@ func (v *StateResolution) loadStateEvents(
return result, eventIDMap, nil return result, eventIDMap, nil
} }
// loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events.
func (v *StateResolution) loadAuthEvents(
ctx context.Context, event *gomatrixserverlib.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
eventMap := map[string]struct{}{}
var lookup []string
var authEvents []types.Event
queue := event.AuthEventIDs()
for i := 0; i < len(queue); i++ {
lookup = lookup[:0]
for _, authEventID := range queue {
if _, ok := eventMap[authEventID]; ok {
continue
}
lookup = append(lookup, authEventID)
}
if len(lookup) == 0 {
break
}
events, err := v.db.EventsFromIDs(ctx, lookup)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
}
add := map[string]struct{}{}
for _, event := range events {
eventMap[event.EventID()] = struct{}{}
authEvents = append(authEvents, event)
for _, authEventID := range event.AuthEventIDs() {
if _, ok := eventMap[authEventID]; ok {
continue
}
add[authEventID] = struct{}{}
}
for authEventID := range add {
queue = append(queue, authEventID)
}
}
}
authEventTypes := map[string]struct{}{}
authEventStateKeys := map[string]struct{}{}
for _, authEvent := range authEvents {
authEventTypes[authEvent.Type()] = struct{}{}
authEventStateKeys[*authEvent.StateKey()] = struct{}{}
}
lookupAuthEventTypes := make([]string, 0, len(authEventTypes))
lookupAuthEventStateKeys := make([]string, 0, len(authEventStateKeys))
for eventType := range authEventTypes {
lookupAuthEventTypes = append(lookupAuthEventTypes, eventType)
}
for eventStateKey := range authEventStateKeys {
lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey)
}
eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err)
}
eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err)
}
stateEntryMap := map[string]types.StateEntry{}
for _, authEvent := range authEvents {
stateEntryMap[authEvent.EventID()] = types.StateEntry{
EventNID: authEvent.EventNID,
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypes[authEvent.Type()],
EventStateKeyNID: eventStateKeys[*authEvent.StateKey()],
},
}
}
nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents))
for _, authEvent := range authEvents {
nakedEvents = append(nakedEvents, authEvent.Event)
}
return nakedEvents, stateEntryMap, nil
}
// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. // findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
// Returns a sorted list of those state entries. // Returns a sorted list of those state entries.
func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry {

View file

@ -192,6 +192,10 @@ func (u *RoomUpdater) StateAtEventIDs(
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
} }
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
}
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
} }

View file

@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'" "SELECT COALESCE(MAX(localpart::bigint), 0) FROM account_accounts WHERE localpart ~ '^[0-9]{1,}$'"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt

View file

@ -124,6 +124,23 @@ func Test_Accounts(t *testing.T) {
_, err = db.GetAccountByLocalpart(ctx, "unusename") _, err = db.GetAccountByLocalpart(ctx, "unusename")
assert.Error(t, err, "expected an error for non existent localpart") assert.Error(t, err, "expected an error for non existent localpart")
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
// if there's already a user without a localpart in the database
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser)
assert.NoError(t, err)
// test getting a numeric localpart, with an existing user without a localpart
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
assert.NoError(t, err)
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser)
assert.NoError(t, err)
// Now try to create a new guest user
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
assert.NoError(t, err)
}) })
} }