mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-02-12 05:54:29 -06:00
Fill some gaps
This commit is contained in:
parent
e8f58acf03
commit
250a0ee946
|
@ -483,7 +483,7 @@ func (r *Inputer) calculateAndSetState(
|
|||
isRejected bool,
|
||||
) error {
|
||||
var err error
|
||||
roomState := state.NewStateResolution(r.DB, roomInfo)
|
||||
roomState := state.NewStateResolution(updater, roomInfo)
|
||||
|
||||
if input.HasState {
|
||||
// Check here if we think we're in the room already.
|
||||
|
|
|
@ -191,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
|||
|
||||
func (u *latestEventsUpdater) latestState() error {
|
||||
var err error
|
||||
roomState := state.NewStateResolution(u.api.DB, u.roomInfo)
|
||||
roomState := state.NewStateResolution(u.updater, u.roomInfo)
|
||||
|
||||
// Work out if the state at the extremities has actually changed
|
||||
// or not. If they haven't then we won't bother doing all of the
|
||||
|
|
|
@ -22,7 +22,6 @@ import (
|
|||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
|
@ -30,13 +29,25 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type StateResolutionStorage interface {
|
||||
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
||||
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
||||
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||
}
|
||||
|
||||
type StateResolution struct {
|
||||
db storage.Database
|
||||
db StateResolutionStorage
|
||||
roomInfo *types.RoomInfo
|
||||
events map[types.EventNID]*gomatrixserverlib.Event
|
||||
}
|
||||
|
||||
func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution {
|
||||
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
|
||||
return StateResolution{
|
||||
db: db,
|
||||
roomInfo: roomInfo,
|
||||
|
|
|
@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||
ctx context.Context, eventStateKeys []string,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, pq.StringArray(eventStateKeys),
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
|
||||
for i := range eventStateKeyNIDs {
|
||||
nIDs[i] = int64(eventStateKeyNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt)
|
||||
rows, err := stmt.QueryContext(ctx, nIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID(
|
|||
}
|
||||
|
||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||
ctx context.Context, eventTypes []string,
|
||||
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ func (d *Database) EventTypeNIDs(
|
|||
}
|
||||
}
|
||||
if len(remaining) > 0 {
|
||||
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining)
|
||||
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, remaining)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ func (d *Database) EventTypeNIDs(
|
|||
func (d *Database) EventStateKeys(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs)
|
||||
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
|
||||
}
|
||||
|
||||
func (d *Database) EventStateKeyNIDs(
|
||||
|
@ -93,7 +93,7 @@ func (d *Database) EventStateKeyNIDs(
|
|||
}
|
||||
}
|
||||
if len(remaining) > 0 {
|
||||
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining)
|
||||
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, remaining)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -980,7 +980,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
|||
}
|
||||
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
|
||||
// isn't a failure.
|
||||
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes)
|
||||
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
|
||||
}
|
||||
|
@ -1000,7 +1000,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
|||
|
||||
}
|
||||
|
||||
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys)
|
||||
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
|
||||
}
|
||||
|
@ -1076,7 +1076,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
|
|||
stateKeyNIDs[i] = nid
|
||||
i++
|
||||
}
|
||||
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
|
||||
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -112,15 +112,19 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||
ctx context.Context, eventStateKeys []string,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
||||
for k, v := range eventStateKeys {
|
||||
iEventStateKeys[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
rows, err := stmt.QueryContext(ctx, iEventStateKeys...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -138,15 +142,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
||||
for k, v := range eventStateKeyNIDs {
|
||||
iEventStateKeyNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID(
|
|||
}
|
||||
|
||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||
ctx context.Context, eventTypes []string,
|
||||
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
///////////////
|
||||
iEventTypes := make([]interface{}, len(eventTypes))
|
||||
|
@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
///////////////
|
||||
|
||||
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
|
||||
rows, err := stmt.QueryContext(ctx, iEventTypes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -24,14 +24,14 @@ type EventJSON interface {
|
|||
type EventTypes interface {
|
||||
InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
|
||||
SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
|
||||
BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||
BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||
}
|
||||
|
||||
type EventStateKeys interface {
|
||||
InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
|
||||
SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
|
||||
BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||
BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
||||
BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||
BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
||||
}
|
||||
|
||||
type Events interface {
|
||||
|
|
Loading…
Reference in a new issue