convert current room state table

This commit is contained in:
Kegan Dougal 2020-05-14 14:40:18 +01:00
parent 170aecdd40
commit c850081815
2 changed files with 35 additions and 58 deletions

View file

@ -22,6 +22,7 @@ import (
"strings" "strings"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -91,35 +92,37 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
func (s *currentRoomStateStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s.streamIDStatements = streamID s := &currentRoomStateStatements{
_, err = db.Exec(currentRoomStateSchema) streamIDStatements: streamID,
}
_, err := db.Exec(currentRoomStateSchema)
if err != nil { if err != nil {
return return nil, err
} }
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
return return nil, err
} }
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
return return nil, err
} }
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return return nil, err
} }
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
return return nil, err
} }
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return return nil, err
} }
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return return nil, err
} }
return return s, nil
} }
// JoinedMemberLists returns a map of room ID to a list of joined user IDs. // JoinedMemberLists returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) selectJoinedUsers( func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context, ctx context.Context,
) (map[string][]string, error) { ) (map[string][]string, error) {
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
@ -143,7 +146,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers(
} }
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) selectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
userID string, userID string,
@ -168,7 +171,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(
} }
// CurrentState returns all the current state events for the given room. // CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) selectCurrentState( func (s *currentRoomStateStatements) SelectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrixserverlib.StateFilter, stateFilterPart *gomatrixserverlib.StateFilter,
) ([]gomatrixserverlib.HeaderedEvent, error) { ) ([]gomatrixserverlib.HeaderedEvent, error) {
@ -189,7 +192,7 @@ func (s *currentRoomStateStatements) selectCurrentState(
return rowsToEvents(rows) return rowsToEvents(rows)
} }
func (s *currentRoomStateStatements) deleteRoomStateByEventID( func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) error { ) error {
stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt) stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
@ -197,7 +200,7 @@ func (s *currentRoomStateStatements) deleteRoomStateByEventID(
return err return err
} }
func (s *currentRoomStateStatements) upsertRoomState( func (s *currentRoomStateStatements) UpsertRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition,
) error { ) error {
@ -231,7 +234,7 @@ func (s *currentRoomStateStatements) upsertRoomState(
return err return err
} }
func (s *currentRoomStateStatements) selectEventsWithEventIDs( func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
@ -264,7 +267,7 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) {
return result, nil return result, nil
} }
func (s *currentRoomStateStatements) selectStateEvent( func (s *currentRoomStateStatements) SelectStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
stmt := s.selectStateEventStmt stmt := s.selectStateEventStmt

View file

@ -54,9 +54,8 @@ type SyncServerDatasource struct {
shared.Database shared.Database
db *sql.DB db *sql.DB
common.PartitionOffsetStatements common.PartitionOffsetStatements
streamID streamIDStatements streamID streamIDStatements
roomstate currentRoomStateStatements topology outputRoomEventsTopologyStatements
topology outputRoomEventsTopologyStatements
} }
// NewSyncServerDatasource creates a new sync server database // NewSyncServerDatasource creates a new sync server database
@ -99,7 +98,8 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
if err = d.roomstate.prepare(d.db, &d.streamID); err != nil { roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID)
if err != nil {
return err return err
} }
invites, err := NewSqliteInvitesTable(d.db, &d.streamID) invites, err := NewSqliteInvitesTable(d.db, &d.streamID)
@ -119,16 +119,12 @@ func (d *SyncServerDatasource) prepare() (err error) {
AccountData: accountData, AccountData: accountData,
OutputEvents: events, OutputEvents: events,
BackwardExtremities: bwExtrem, BackwardExtremities: bwExtrem,
CurrentRoomState: roomState,
EDUCache: cache.New(), EDUCache: cache.New(),
} }
return nil return nil
} }
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.roomstate.selectJoinedUsers(ctx)
}
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
@ -210,7 +206,7 @@ func (d *SyncServerDatasource) updateRoomState(
) error { ) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs { for _, eventID := range removedEventIDs {
if err := d.roomstate.deleteRoomStateByEventID(ctx, txn, eventID); err != nil { if err := d.Database.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
return err return err
} }
} }
@ -228,7 +224,7 @@ func (d *SyncServerDatasource) updateRoomState(
} }
membership = &value membership = &value
} }
if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { if err := d.Database.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
return err return err
} }
} }
@ -249,28 +245,6 @@ func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.Stre
return return
} }
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
func (d *SyncServerDatasource) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey)
}
// GetStateEventsForRoom fetches the state events for a given room.
// Returns an empty slice if no state events could be found for this room.
// Returns an error if there was an issue with the retrieval.
func (d *SyncServerDatasource) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter,
) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
return err
})
return
}
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the
// given extremities and limit. // given extremities and limit.
func (d *SyncServerDatasource) GetEventsInTopologicalRange( func (d *SyncServerDatasource) GetEventsInTopologicalRange(
@ -530,7 +504,7 @@ func (d *SyncServerDatasource) IncrementalSync(
ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res, ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res,
) )
} else { } else {
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( joinedRoomIDs, err = d.Database.CurrentRoomState.SelectRoomIDsWithMembership(
ctx, nil, device.UserID, gomatrixserverlib.Join, ctx, nil, device.UserID, gomatrixserverlib.Join,
) )
} }
@ -585,7 +559,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
res = types.NewResponse(*toPos) res = types.NewResponse(*toPos)
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err = d.Database.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil { if err != nil {
return return
} }
@ -595,7 +569,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// Build up a /sync response. Add joined rooms. // Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
var stateEvents []gomatrixserverlib.HeaderedEvent var stateEvents []gomatrixserverlib.HeaderedEvent
stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart) stateEvents, err = d.Database.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, &stateFilterPart)
if err != nil { if err != nil {
return return
} }
@ -849,7 +823,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents(
// If they are missing from the events table then they should be state // If they are missing from the events table then they should be state
// events that we received from outside the main event stream. // events that we received from outside the main event stream.
// These should be in the room state table. // These should be in the room state table.
stateEvents, err := d.roomstate.selectEventsWithEventIDs(ctx, txn, missing) stateEvents, err := d.Database.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing)
if err != nil { if err != nil {
return nil, err return nil, err
@ -921,7 +895,7 @@ func (d *SyncServerDatasource) getStateDeltas(
} }
// Add in currently joined rooms // Add in currently joined rooms
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err := d.Database.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -945,7 +919,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
fromPos, toPos types.StreamPosition, userID string, fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrixserverlib.StateFilter, stateFilterPart *gomatrixserverlib.StateFilter,
) ([]stateDelta, []string, error) { ) ([]stateDelta, []string, error) {
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err := d.Database.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1000,7 +974,7 @@ func (d *SyncServerDatasource) currentStateStreamEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrixserverlib.StateFilter, stateFilterPart *gomatrixserverlib.StateFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) allState, err := d.Database.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilterPart)
if err != nil { if err != nil {
return nil, err return nil, err
} }