package storage import ( "database/sql" "fmt" "github.com/lib/pq" "github.com/matrix-org/dendrite/roomserver/types" ) const stateDataSchema = ` -- The state data map. -- Designed to give enough information to run the state resolution algorithm -- without hitting the database in the common case. -- TODO: Is it worth replacing the unique btree index with a covering index so -- that postgres could lookup the state using an index-only scan? -- The type and state_key are included in the index to make it easier to -- lookup a specific (type, state_key) pair for an event. It also makes it easy -- to read the state for a given state_block_nid ordered by (type, state_key) -- which in turn makes it easier to merge state data blocks. CREATE SEQUENCE IF NOT EXISTS state_block_nid_seq; CREATE TABLE IF NOT EXISTS state_block ( -- Local numeric ID for this state data. state_block_nid bigint NOT NULL, event_type_nid bigint NOT NULL, event_state_key_nid bigint NOT NULL, event_nid bigint NOT NULL, UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) ); ` const insertStateDataSQL = "" + "INSERT INTO state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + " VALUES ($1, $2, $3, $4)" const selectNextStateBlockNIDSQL = "" + "SELECT nextval('state_block_nid_seq')" // Bulk state lookup by numeric event ID. // Sort by the state_block_nid, event_type_nid, event_state_key_nid // This means that all the entries for a given state_block_nid will appear // together in the list and those entries will sorted by event_type_nid // and event_state_key_nid. This property makes it easier to merge two // state data blocks together. const bulkSelectStateDataEntriesSQL = "" + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + " FROM state_block WHERE state_block_nid = ANY($1)" + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" type stateBlockStatements struct { insertStateDataStmt *sql.Stmt selectNextStateBlockNIDStmt *sql.Stmt bulkSelectStateDataEntriesStmt *sql.Stmt } func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(stateDataSchema) if err != nil { return } if s.insertStateDataStmt, err = db.Prepare(insertStateDataSQL); err != nil { return } if s.selectNextStateBlockNIDStmt, err = db.Prepare(selectNextStateBlockNIDSQL); err != nil { return } if s.bulkSelectStateDataEntriesStmt, err = db.Prepare(bulkSelectStateDataEntriesSQL); err != nil { return } return } func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error { for _, entry := range entries { _, err := s.insertStateDataStmt.Exec( int64(stateBlockNID), int64(entry.EventTypeNID), int64(entry.EventStateKeyNID), int64(entry.EventNID), ) if err != nil { return err } } return nil } func (s *stateBlockStatements) selectNextStateBlockNID() (types.StateBlockNID, error) { var stateBlockNID int64 err := s.selectNextStateBlockNIDStmt.QueryRow().Scan(&stateBlockNID) return types.StateBlockNID(stateBlockNID), err } func (s *stateBlockStatements) bulkSelectStateDataEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(nids)) if err != nil { return nil, err } defer rows.Close() results := make([]types.StateEntryList, len(stateBlockNIDs)) // current is a pointer to the StateEntryList to append the state entries to. var current *types.StateEntryList i := 0 for rows.Next() { var ( stateBlockNID int64 eventTypeNID int64 eventStateKeyNID int64 eventNID int64 entry types.StateEntry ) if err := rows.Scan( &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, ); err != nil { return nil, err } entry.EventTypeNID = types.EventTypeNID(eventTypeNID) entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) entry.EventNID = types.EventNID(eventNID) if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { // The state entry row is for a different state data block to the current one. // So we start appending to the next entry in the list. current = &results[i] current.StateBlockNID = types.StateBlockNID(stateBlockNID) i++ } current.StateEntries = append(current.StateEntries, entry) } if i != len(stateBlockNIDs) { return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) } return results, nil }