Fill some gaps

This commit is contained in:
Neil Alexander 2022-02-02 14:12:55 +00:00
parent e8f58acf03
commit 250a0ee946
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
9 changed files with 51 additions and 28 deletions

View file

@ -483,7 +483,7 @@ func (r *Inputer) calculateAndSetState(
isRejected bool,
) error {
var err error
roomState := state.NewStateResolution(r.DB, roomInfo)
roomState := state.NewStateResolution(updater, roomInfo)
if input.HasState {
// Check here if we think we're in the room already.

View file

@ -191,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
func (u *latestEventsUpdater) latestState() error {
var err error
roomState := state.NewStateResolution(u.api.DB, u.roomInfo)
roomState := state.NewStateResolution(u.updater, u.roomInfo)
// Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the

View file

@ -22,7 +22,6 @@ import (
"sort"
"time"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
@ -30,13 +29,25 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
type StateResolutionStorage interface {
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
}
type StateResolution struct {
db storage.Database
db StateResolutionStorage
roomInfo *types.RoomInfo
events map[types.EventNID]*gomatrixserverlib.Event
}
func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution {
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
return StateResolution{
db: db,
roomInfo: roomInfo,

View file

@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
ctx context.Context, eventStateKeys []string,
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt)
rows, err := stmt.QueryContext(
ctx, pq.StringArray(eventStateKeys),
)
if err != nil {
@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt)
rows, err := stmt.QueryContext(ctx, nIDs)
if err != nil {
return nil, err
}

View file

@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID(
}
func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string,
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes))
if err != nil {
return nil, err
}

View file

@ -62,7 +62,7 @@ func (d *Database) EventTypeNIDs(
}
}
if len(remaining) > 0 {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining)
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, remaining)
if err != nil {
return nil, err
}
@ -77,7 +77,7 @@ func (d *Database) EventTypeNIDs(
func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs)
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
}
func (d *Database) EventStateKeyNIDs(
@ -93,7 +93,7 @@ func (d *Database) EventStateKeyNIDs(
}
}
if len(remaining) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining)
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, remaining)
if err != nil {
return nil, err
}
@ -980,7 +980,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
// isn't a failure.
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes)
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
}
@ -1000,7 +1000,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys)
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
}
@ -1076,7 +1076,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}

View file

@ -112,15 +112,19 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
ctx context.Context, eventStateKeys []string,
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
iEventStateKeys := make([]interface{}, len(eventStateKeys))
for k, v := range eventStateKeys {
iEventStateKeys[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, selectPrep)
rows, err := stmt.QueryContext(ctx, iEventStateKeys...)
if err != nil {
return nil, err
}
@ -138,15 +142,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, selectPrep)
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
if err != nil {
return nil, err
}

View file

@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID(
}
func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string,
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
///////////////
iEventTypes := make([]interface{}, len(eventTypes))
@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, selectPrep)
///////////////
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
rows, err := stmt.QueryContext(ctx, iEventTypes...)
if err != nil {
return nil, err
}

View file

@ -24,14 +24,14 @@ type EventJSON interface {
type EventTypes interface {
InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error)
}
type EventStateKeys interface {
InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
}
type Events interface {