From 28690b72f92229b5ae2c944e439ccb4116a5d850 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Wed, 1 Feb 2023 14:01:08 +0100 Subject: [PATCH] Make SQLite work again --- roomserver/internal/query/query.go | 104 ++++++++---------- syncapi/internal/history_visibility.go | 2 +- .../postgres/output_room_events_table.go | 6 +- .../sqlite3/output_room_events_table.go | 92 +++++++++------- 4 files changed, 105 insertions(+), 99 deletions(-) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 177c1352d..b960e0fa8 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -226,6 +227,7 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("no roomInfo found") } + // get the users stateKeyNID stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID}) if err != nil { return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err) @@ -238,74 +240,62 @@ func (r *Queryer) QueryMembershipAtEvent( switch err { case nil: return nil - //case tables.OptimisationNotSupportedError: // fallthrough + case tables.OptimisationNotSupportedError: // fallthrough default: return err } - /* - // get the users stateKeyNID - stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID}) - if err != nil { - return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err) + response.Memberships = make(map[string]*gomatrixserverlib.HeaderedEvent) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) + if err != nil { + return fmt.Errorf("unable to get state before event: %w", err) + } + + // If we only have one or less state entries, we can short circuit the below + // loop and avoid hitting the database + allStateEventNIDs := make(map[types.EventNID]types.StateEntry) + for _, eventID := range request.EventIDs { + stateEntry := stateEntries[eventID] + for _, s := range stateEntry { + allStateEventNIDs[s.EventNID] = s } - if _, ok := stateKeyNIDs[request.UserID]; !ok { - return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) + } + + var canShortCircuit bool + if len(allStateEventNIDs) <= 1 { + canShortCircuit = true + } + + var memberships []types.Event + for _, eventID := range request.EventIDs { + stateEntry, ok := stateEntries[eventID] + if !ok || len(stateEntry) == 0 { + response.Memberships[eventID] = nil + continue } - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) - if err != nil { - return fmt.Errorf("unable to get state before event: %w", err) - } - - // If we only have one or less state entries, we can short circuit the below - // loop and avoid hitting the database - allStateEventNIDs := make(map[types.EventNID]types.StateEntry) - for _, eventID := range request.EventIDs { - stateEntry := stateEntries[eventID] - for _, s := range stateEntry { - allStateEventNIDs[s.EventNID] = s - } - } - - var canShortCircuit bool - if len(allStateEventNIDs) <= 1 { - canShortCircuit = true - } - - var memberships []types.Event - for _, eventID := range request.EventIDs { - stateEntry, ok := stateEntries[eventID] - if !ok || len(stateEntry) == 0 { - response.Memberships[eventID] = nil - continue - } - - // If we can short circuit, e.g. we only have 0 or 1 membership events, we only get the memberships - // once. If we have more than one membership event, we need to get the state for each state entry. - if canShortCircuit { - if len(memberships) == 0 { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) - } - } else { + // If we can short circuit, e.g. we only have 0 or 1 membership events, we only get the memberships + // once. If we have more than one membership event, we need to get the state for each state entry. + if canShortCircuit { + if len(memberships) == 0 { memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) } - if err != nil { - return fmt.Errorf("unable to get memberships at state: %w", err) - } - - res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) - - for i := range memberships { - ev := memberships[i] - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { - res = append(res, ev.Headered(roomVersion)) - } - } - response.Memberships[eventID] = res + } else { + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + } + if err != nil { + return fmt.Errorf("unable to get memberships at state: %w", err) } - return nil*/ + for i := range memberships { + ev := memberships[i] + if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { + response.Memberships[eventID] = ev.Event.Headered(info.RoomVersion) + } + } + } + + return nil } // QueryMembershipsForRoom implements api.RoomserverInternalAPI diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 7b5fa29b9..2b4ebaa41 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -200,7 +200,7 @@ func visibilityForEvents( visibility: event.Visibility, } ev, ok := membershipResp.Memberships[eventID] - if !ok { + if !ok || ev == nil { result[eventID] = vis continue } diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 49dffade6..31b4753a1 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -407,7 +407,11 @@ func (s *outputRoomEventsStatements) InsertEvent( // selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude // from sync. -func (s *outputRoomEventsStatements) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, ra types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) { +func (s *outputRoomEventsStatements) SelectRecentEvents( + ctx context.Context, txn *sql.Tx, + roomIDs []string, ra types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + chronologicalOrder bool, onlySyncEvents bool, +) (map[string]types.RecentEvents, error) { var stmt *sql.Stmt if onlySyncEvents { stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index dabba0402..23bc68a41 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -366,7 +366,11 @@ func (s *outputRoomEventsStatements) InsertEvent( return streamPos, err } -func (s *outputRoomEventsStatements) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) { +func (s *outputRoomEventsStatements) SelectRecentEvents( + ctx context.Context, txn *sql.Tx, + roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + chronologicalOrder bool, onlySyncEvents bool, +) (map[string]types.RecentEvents, error) { var query string if onlySyncEvents { query = selectRecentEventsForSyncSQL @@ -374,47 +378,55 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(ctx context.Context, txn query = selectRecentEventsSQL } - stmt, params, err := prepareWithFilters( - s.db, txn, query, - []interface{}{ - roomIDs, r.Low(), r.High(), - }, - eventFilter.Senders, eventFilter.NotSenders, - eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, - ) - if err != nil { - return nil, fmt.Errorf("s.prepareWithFilters: %w", err) - } - defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") - - rows, err := stmt.QueryContext(ctx, params...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, err - } - if chronologicalOrder { - // The events need to be returned from oldest to latest, which isn't - // necessary the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - } - // we queried for 1 more than the limit, so if we returned one more mark limited=true - if len(events) > eventFilter.Limit { - // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. - if chronologicalOrder { - events = events[1:] - } else { - events = events[:len(events)-1] + result := make(map[string]types.RecentEvents, len(roomIDs)) + for _, roomID := range roomIDs { + stmt, params, err := prepareWithFilters( + s.db, txn, query, + []interface{}{ + roomID, r.Low(), r.High(), + }, + eventFilter.Senders, eventFilter.NotSenders, + eventFilter.Types, eventFilter.NotTypes, + nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, + ) + if err != nil { + return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if chronologicalOrder { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + } + res := types.RecentEvents{} + // we queried for 1 more than the limit, so if we returned one more mark limited=true + if len(events) > eventFilter.Limit { + res.Limited = true + // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. + if chronologicalOrder { + events = events[1:] + } else { + events = events[:len(events)-1] + } + } + res.Events = events + result[roomID] = res } - return map[string]types.RecentEvents{}, nil + + return result, nil } func (s *outputRoomEventsStatements) SelectEarlyEvents(