Optimize history visibility checks (#2848)
This optimizes history visibility checks by (mostly) avoiding database hits. Possibly solves https://github.com/matrix-org/dendrite/issues/2777 Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
parent
0b21cb78aa
commit
2acc1d65fb
|
@ -5,6 +5,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -159,7 +160,7 @@ func GetMembershipsAtState(
|
||||||
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
|
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
|
||||||
) ([]types.Event, error) {
|
) ([]types.Event, error) {
|
||||||
|
|
||||||
var eventNIDs []types.EventNID
|
var eventNIDs types.EventNIDs
|
||||||
for _, entry := range stateEntries {
|
for _, entry := range stateEntries {
|
||||||
// Filter the events to retrieve to only keep the membership events
|
// Filter the events to retrieve to only keep the membership events
|
||||||
if entry.EventTypeNID == types.MRoomMemberNID {
|
if entry.EventTypeNID == types.MRoomMemberNID {
|
||||||
|
@ -167,6 +168,14 @@ func GetMembershipsAtState(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// There are no events to get, don't bother asking the database
|
||||||
|
if len(eventNIDs) == 0 {
|
||||||
|
return []types.Event{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(eventNIDs)
|
||||||
|
util.Unique(eventNIDs)
|
||||||
|
|
||||||
// Get all of the events in this state
|
// Get all of the events in this state
|
||||||
stateEvents, err := db.Events(ctx, eventNIDs)
|
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -239,16 +239,42 @@ func (r *Queryer) QueryMembershipAtEvent(
|
||||||
return fmt.Errorf("unable to get state before event: %w", err)
|
return fmt.Errorf("unable to get state before event: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we only have one or less state entries, we can short circuit the below
|
||||||
|
// loop and avoid hitting the database
|
||||||
|
allStateEventNIDs := make(map[types.EventNID]types.StateEntry)
|
||||||
|
for _, eventID := range request.EventIDs {
|
||||||
|
stateEntry := stateEntries[eventID]
|
||||||
|
for _, s := range stateEntry {
|
||||||
|
allStateEventNIDs[s.EventNID] = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var canShortCircuit bool
|
||||||
|
if len(allStateEventNIDs) <= 1 {
|
||||||
|
canShortCircuit = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var memberships []types.Event
|
||||||
for _, eventID := range request.EventIDs {
|
for _, eventID := range request.EventIDs {
|
||||||
stateEntry, ok := stateEntries[eventID]
|
stateEntry, ok := stateEntries[eventID]
|
||||||
if !ok {
|
if !ok || len(stateEntry) == 0 {
|
||||||
response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{}
|
response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
|
|
||||||
|
// If we can short circuit, e.g. we only have 0 or 1 membership events, we only get the memberships
|
||||||
|
// once. If we have more than one membership event, we need to get the state for each state entry.
|
||||||
|
if canShortCircuit {
|
||||||
|
if len(memberships) == 0 {
|
||||||
|
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to get memberships at state: %w", err)
|
return fmt.Errorf("unable to get memberships at state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
|
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
|
||||||
|
|
||||||
for i := range memberships {
|
for i := range memberships {
|
||||||
|
|
|
@ -18,17 +18,17 @@ package state
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/opentracing/opentracing-go"
|
"github.com/opentracing/opentracing-go"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StateResolutionStorage interface {
|
type StateResolutionStorage interface {
|
||||||
|
@ -37,6 +37,7 @@ type StateResolutionStorage interface {
|
||||||
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||||
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||||
|
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
|
||||||
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
||||||
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
||||||
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||||
|
@ -130,21 +131,10 @@ func (v *StateResolution) LoadMembershipAtEvent(
|
||||||
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
|
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
// De-dupe snapshotNIDs
|
// Get a mapping from snapshotNID -> eventIDs
|
||||||
snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs
|
snapshotNIDMap, err := v.db.BulkSelectSnapshotsFromEventIDs(ctx, eventIDs)
|
||||||
for i := range eventIDs {
|
if err != nil {
|
||||||
eventID := eventIDs[i]
|
return nil, err
|
||||||
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
|
|
||||||
}
|
|
||||||
if snapshotNID == 0 {
|
|
||||||
// If we don't know a state snapshot for this event then we can't calculate
|
|
||||||
// memberships at the time of the event, so skip over it. This means that
|
|
||||||
// it isn't guaranteed that the response map will contain every single event.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap))
|
snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap))
|
||||||
|
@ -157,24 +147,45 @@ func (v *StateResolution) LoadMembershipAtEvent(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var wantStateBlocks []types.StateBlockNID
|
||||||
|
for _, x := range stateBlockNIDLists {
|
||||||
|
wantStateBlocks = append(wantStateBlocks, x.StateBlockNIDs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEntryLists, err := v.db.StateEntriesForTuples(ctx, uniqueStateBlockNIDs(wantStateBlocks), []types.StateKeyTuple{
|
||||||
|
{
|
||||||
|
EventTypeNID: types.MRoomMemberNID,
|
||||||
|
EventStateKeyNID: stateKeyNID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists)
|
||||||
|
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
||||||
|
|
||||||
result := make(map[string][]types.StateEntry)
|
result := make(map[string][]types.StateEntry)
|
||||||
for _, stateBlockNIDList := range stateBlockNIDLists {
|
for _, stateBlockNIDList := range stateBlockNIDLists {
|
||||||
// Query the membership event for the user at the given stateblocks
|
stateBlockNIDs, ok := stateBlockNIDsMap.lookup(stateBlockNIDList.StateSnapshotNID)
|
||||||
stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{
|
if !ok {
|
||||||
{
|
// This should only get hit if the database is corrupt.
|
||||||
EventTypeNID: types.MRoomMemberNID,
|
// It should be impossible for an event to reference a NID that doesn't exist
|
||||||
EventStateKeyNID: stateKeyNID,
|
return nil, fmt.Errorf("corrupt DB: Missing state snapshot numeric ID %d", stateBlockNIDList.StateSnapshotNID)
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID]
|
for _, stateBlockNID := range stateBlockNIDs {
|
||||||
|
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
||||||
|
if !ok {
|
||||||
|
// This should only get hit if the database is corrupt.
|
||||||
|
// It should be impossible for an event to reference a NID that doesn't exist
|
||||||
|
return nil, fmt.Errorf("corrupt DB: Missing state block numeric ID %d", stateBlockNID)
|
||||||
|
}
|
||||||
|
|
||||||
for _, evID := range evIDs {
|
evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID]
|
||||||
for _, x := range stateEntryLists {
|
|
||||||
result[evID] = append(result[evID], x.StateEntries...)
|
for _, evID := range evIDs {
|
||||||
|
result[evID] = append(result[evID], entries...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,6 +72,7 @@ type Database interface {
|
||||||
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
// Look up snapshot NID for an event ID string
|
// Look up snapshot NID for an event ID string
|
||||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||||
|
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
|
||||||
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
||||||
StoreEvent(
|
StoreEvent(
|
||||||
ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID,
|
ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID,
|
||||||
|
|
|
@ -22,11 +22,12 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const eventsSchema = `
|
const eventsSchema = `
|
||||||
|
@ -80,6 +81,9 @@ const insertEventSQL = "" +
|
||||||
const selectEventSQL = "" +
|
const selectEventSQL = "" +
|
||||||
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
|
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
|
||||||
|
|
||||||
|
const bulkSelectSnapshotsForEventIDsSQL = "" +
|
||||||
|
"SELECT event_id, state_snapshot_nid FROM roomserver_events WHERE event_id = ANY($1)"
|
||||||
|
|
||||||
// Bulk lookup of events by string ID.
|
// Bulk lookup of events by string ID.
|
||||||
// Sort by the numeric IDs for event type and state key.
|
// Sort by the numeric IDs for event type and state key.
|
||||||
// This means we can use binary search to lookup entries by type and state key.
|
// This means we can use binary search to lookup entries by type and state key.
|
||||||
|
@ -150,6 +154,7 @@ const selectEventRejectedSQL = "" +
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
|
bulkSelectSnapshotsForEventIDsStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt
|
bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt
|
||||||
bulkSelectStateEventByNIDStmt *sql.Stmt
|
bulkSelectStateEventByNIDStmt *sql.Stmt
|
||||||
|
@ -179,6 +184,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertEventStmt, insertEventSQL},
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
{&s.selectEventStmt, selectEventSQL},
|
{&s.selectEventStmt, selectEventSQL},
|
||||||
|
{&s.bulkSelectSnapshotsForEventIDsStmt, bulkSelectSnapshotsForEventIDsSQL},
|
||||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||||
{&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL},
|
{&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL},
|
||||||
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
|
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
|
||||||
|
@ -230,6 +236,29 @@ func (s *eventStatements) SelectEvent(
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) BulkSelectSnapshotsFromEventIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) (map[types.StateSnapshotNID][]string, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectSnapshotsForEventIDsStmt)
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.Array(eventIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var eventID string
|
||||||
|
var stateNID types.StateSnapshotNID
|
||||||
|
result := make(map[types.StateSnapshotNID][]string)
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&eventID, &stateNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[stateNID] = append(result[stateNID], eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If not excluding rejected events, and any of the requested events are missing from
|
// If not excluding rejected events, and any of the requested events are missing from
|
||||||
// the database it returns a types.MissingEventError. If excluding rejected events,
|
// the database it returns a types.MissingEventError. If excluding rejected events,
|
||||||
|
|
|
@ -5,8 +5,9 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RoomUpdater struct {
|
type RoomUpdater struct {
|
||||||
|
@ -186,6 +187,10 @@ func (u *RoomUpdater) EventIDs(
|
||||||
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
|
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) {
|
||||||
|
return u.d.EventsTable.BulkSelectSnapshotsFromEventIDs(ctx, u.txn, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *RoomUpdater) StateAtEventIDs(
|
func (u *RoomUpdater) StateAtEventIDs(
|
||||||
ctx context.Context, eventIDs []string,
|
ctx context.Context, eventIDs []string,
|
||||||
) ([]types.StateAtEvent, error) {
|
) ([]types.StateAtEvent, error) {
|
||||||
|
|
|
@ -469,6 +469,23 @@ func (d *Database) events(
|
||||||
eventNIDs = append(eventNIDs, nid)
|
eventNIDs = append(eventNIDs, nid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// If we don't need to get any events from the database, short circuit now
|
||||||
|
if len(eventNIDs) == 0 {
|
||||||
|
results := make([]types.Event, 0, len(inputEventNIDs))
|
||||||
|
for _, nid := range inputEventNIDs {
|
||||||
|
event, ok := events[nid]
|
||||||
|
if !ok || event == nil {
|
||||||
|
return nil, fmt.Errorf("event %d missing", nid)
|
||||||
|
}
|
||||||
|
results = append(results, types.Event{
|
||||||
|
EventNID: nid,
|
||||||
|
Event: event,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if !redactionsArePermanent {
|
||||||
|
d.applyRedactions(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
|
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -534,6 +551,12 @@ func (d *Database) events(
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) BulkSelectSnapshotsFromEventIDs(
|
||||||
|
ctx context.Context, eventIDs []string,
|
||||||
|
) (map[types.StateSnapshotNID][]string, error) {
|
||||||
|
return d.EventsTable.BulkSelectSnapshotsFromEventIDs(ctx, nil, eventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) MembershipUpdater(
|
func (d *Database) MembershipUpdater(
|
||||||
ctx context.Context, roomID, targetUserID string,
|
ctx context.Context, roomID, targetUserID string,
|
||||||
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
|
|
|
@ -23,11 +23,12 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const eventsSchema = `
|
const eventsSchema = `
|
||||||
|
@ -57,6 +58,9 @@ const insertEventSQL = `
|
||||||
const selectEventSQL = "" +
|
const selectEventSQL = "" +
|
||||||
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
|
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
|
||||||
|
|
||||||
|
const bulkSelectSnapshotsForEventIDsSQL = "" +
|
||||||
|
"SELECT event_id, state_snapshot_nid FROM roomserver_events WHERE event_id IN ($1)"
|
||||||
|
|
||||||
// Bulk lookup of events by string ID.
|
// Bulk lookup of events by string ID.
|
||||||
// Sort by the numeric IDs for event type and state key.
|
// Sort by the numeric IDs for event type and state key.
|
||||||
// This means we can use binary search to lookup entries by type and state key.
|
// This means we can use binary search to lookup entries by type and state key.
|
||||||
|
@ -124,6 +128,7 @@ type eventStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
|
bulkSelectSnapshotsForEventIDsStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt
|
bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||||
|
@ -153,6 +158,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertEventStmt, insertEventSQL},
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
{&s.selectEventStmt, selectEventSQL},
|
{&s.selectEventStmt, selectEventSQL},
|
||||||
|
{&s.bulkSelectSnapshotsForEventIDsStmt, bulkSelectSnapshotsForEventIDsSQL},
|
||||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||||
{&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL},
|
{&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL},
|
||||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||||
|
@ -203,6 +209,40 @@ func (s *eventStatements) SelectEvent(
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) BulkSelectSnapshotsFromEventIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) (map[types.StateSnapshotNID][]string, error) {
|
||||||
|
qry := strings.Replace(bulkSelectSnapshotsForEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
|
||||||
|
stmt, err := s.db.Prepare(qry)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, stmt, "BulkSelectSnapshotsFromEventIDs: stmt.close() failed")
|
||||||
|
|
||||||
|
params := make([]interface{}, len(eventIDs))
|
||||||
|
for i := range eventIDs {
|
||||||
|
params[i] = eventIDs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "BulkSelectSnapshotsFromEventIDs: rows.close() failed")
|
||||||
|
|
||||||
|
var eventID string
|
||||||
|
var stateNID types.StateSnapshotNID
|
||||||
|
result := make(map[types.StateSnapshotNID][]string)
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&eventID, &stateNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[stateNID] = append(result[stateNID], eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If not excluding rejected events, and any of the requested events are missing from
|
// If not excluding rejected events, and any of the requested events are missing from
|
||||||
// the database it returns a types.MissingEventError. If excluding rejected events,
|
// the database it returns a types.MissingEventError. If excluding rejected events,
|
||||||
|
|
|
@ -44,6 +44,7 @@ type Events interface {
|
||||||
referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool,
|
referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool,
|
||||||
) (types.EventNID, types.StateSnapshotNID, error)
|
) (types.EventNID, types.StateSnapshotNID, error)
|
||||||
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
|
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
|
||||||
|
BulkSelectSnapshotsFromEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
|
BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
|
||||||
|
|
Loading…
Reference in a new issue