Check if we're in the room already before resetting latest events/state

This commit is contained in:
Neil Alexander 2020-05-19 13:27:45 +01:00
parent dd91d3d6f5
commit c5e3ac9537
8 changed files with 81 additions and 34 deletions

View file

@ -117,8 +117,12 @@ func (r *RoomserverInternalAPI) calculateAndSetState(
roomState := state.NewStateResolution(r.DB) roomState := state.NewStateResolution(r.DB)
if input.HasState { if input.HasState {
// TODO: Check here if we think we're in the room already. // Check here if we think we're in the room already.
stateAtEvent.Overwrite = true stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil {
stateAtEvent.Overwrite = len(joinEventNIDs) == 0
}
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.

View file

@ -267,7 +267,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
var stateEntries []types.StateEntry var stateEntries []types.StateEntry
if stillInRoom { if stillInRoom {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly, false)
if err != nil { if err != nil {
return err return err
} }
@ -592,7 +592,7 @@ func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, ser
return false, err return false, err
} }
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true) eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -270,7 +270,7 @@ func joinEventsFromHistoryVisibility(
if err != nil { if err != nil {
return nil, err return nil, err
} }
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true) joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -85,7 +85,7 @@ type Database interface {
RemoveRoomAlias(ctx context.Context, alias string) error RemoveRoomAlias(ctx context.Context, alias string) error
MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error)
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
} }

View file

@ -62,7 +62,7 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
-- Local target is true if the target_nid refers to a local user rather than -- Local target is true if the target_nid refers to a local user rather than
-- a federated one. This is an optimisation for resetting state on federated -- a federated one. This is an optimisation for resetting state on federated
-- room joins. -- room joins.
local_target BOOLEAN NOT NULL DEFAULT false, target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid) UNIQUE (room_nid, target_nid)
); );
` `
@ -70,7 +70,7 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
const insertMembershipSQL = "" + const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid, local_target)" + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
" VALUES ($1, $2, $3)" + " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
@ -82,10 +82,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" " WHERE room_nid = $1 AND membership_nid = $2"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" " WHERE room_nid = $1"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" +
" AND target_local = true"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" + "SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE"
@ -95,12 +105,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
type membershipStatements struct { type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
updateMembershipStmt *sql.Stmt selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -114,7 +126,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL}, {&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db) }.prepare(db)
} }
@ -150,7 +164,7 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
} }
func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
@ -170,10 +184,14 @@ func (s *membershipStatements) selectMembershipsFromRoom(
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, membership membershipState, roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows
stmt := s.selectMembershipsFromRoomAndMembershipStmt stmt := s.selectMembershipsFromRoomAndMembershipStmt
rows, err := stmt.QueryContext(ctx, roomNID, membership) if localOnly {
stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
}
rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
} }

View file

@ -749,15 +749,15 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
if joinOnly { if joinOnly {
return d.statements.selectMembershipsFromRoomAndMembership( return d.statements.selectMembershipsFromRoomAndMembership(
ctx, roomNID, membershipStateJoin, ctx, roomNID, membershipStateJoin, localOnly,
) )
} }
return d.statements.selectMembershipsFromRoom(ctx, roomNID) return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly)
} }
// EventsFromIDs implements query.RoomserverQueryAPIEventDB // EventsFromIDs implements query.RoomserverQueryAPIEventDB

View file

@ -38,7 +38,7 @@ const membershipSchema = `
sender_nid INTEGER NOT NULL DEFAULT 0, sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1, membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0, event_nid INTEGER NOT NULL DEFAULT 0,
local_target BOOLEAN NOT NULL DEFAULT false, target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid) UNIQUE (room_nid, target_nid)
); );
` `
@ -46,7 +46,7 @@ const membershipSchema = `
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
const insertMembershipSQL = "" + const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid, local_target)" + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
" VALUES ($1, $2, $3)" + " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
@ -58,10 +58,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" " WHERE room_nid = $1 AND membership_nid = $2"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" " WHERE room_nid = $1"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" +
" AND target_local = true"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" + "SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
@ -71,12 +81,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $4 AND target_nid = $5" " WHERE room_nid = $4 AND target_nid = $5"
type membershipStatements struct { type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
updateMembershipStmt *sql.Stmt selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -90,7 +102,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL}, {&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db) }.prepare(db)
} }
@ -129,9 +143,14 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt) var selectStmt *sql.Stmt
if localOnly {
selectStmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt)
} else {
selectStmt = common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
}
rows, err := selectStmt.QueryContext(ctx, roomNID) rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -147,11 +166,17 @@ func (s *membershipStatements) selectMembershipsFromRoom(
} }
return return
} }
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership membershipState, roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) var stmt *sql.Stmt
if localOnly {
stmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt)
} else {
stmt = common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
}
rows, err := stmt.QueryContext(ctx, roomNID, membership) rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return

View file

@ -897,17 +897,17 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if joinOnly { if joinOnly {
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
ctx, txn, roomNID, membershipStateJoin, ctx, txn, roomNID, membershipStateJoin, localOnly,
) )
return nil return nil
} }
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
return nil return nil
}) })
return return