Check if we're in the room already before resetting latest events/state

This commit is contained in:
Neil Alexander 2020-05-19 13:27:45 +01:00
parent dd91d3d6f5
commit c5e3ac9537
8 changed files with 81 additions and 34 deletions

View file

@ -117,8 +117,12 @@ func (r *RoomserverInternalAPI) calculateAndSetState(
roomState := state.NewStateResolution(r.DB)
if input.HasState {
// TODO: Check here if we think we're in the room already.
// Check here if we think we're in the room already.
stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil {
stateAtEvent.Overwrite = len(joinEventNIDs) == 0
}
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.

View file

@ -267,7 +267,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
var stateEntries []types.StateEntry
if stillInRoom {
var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly, false)
if err != nil {
return err
}
@ -592,7 +592,7 @@ func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, ser
return false, err
}
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false)
if err != nil {
return false, err
}

View file

@ -270,7 +270,7 @@ func joinEventsFromHistoryVisibility(
if err != nil {
return nil, err
}
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false)
if err != nil {
return nil, err
}

View file

@ -85,7 +85,7 @@ type Database interface {
RemoveRoomAlias(ctx context.Context, alias string) error
MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error)
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
}

View file

@ -62,7 +62,7 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
-- Local target is true if the target_nid refers to a local user rather than
-- a federated one. This is an optimisation for resetting state on federated
-- room joins.
local_target BOOLEAN NOT NULL DEFAULT false,
target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
`
@ -70,7 +70,7 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid, local_target)" +
"INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING"
@ -82,10 +82,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true"
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" +
" AND target_local = true"
const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE"
@ -95,12 +105,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $1 AND target_nid = $2"
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -114,7 +126,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db)
}
@ -150,7 +164,7 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
}
func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID,
ctx context.Context, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID)
if err != nil {
@ -170,10 +184,14 @@ func (s *membershipStatements) selectMembershipsFromRoom(
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context,
roomNID types.RoomNID, membership membershipState,
roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows
stmt := s.selectMembershipsFromRoomAndMembershipStmt
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if localOnly {
stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
}
rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
}

View file

@ -749,15 +749,15 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
if joinOnly {
return d.statements.selectMembershipsFromRoomAndMembership(
ctx, roomNID, membershipStateJoin,
ctx, roomNID, membershipStateJoin, localOnly,
)
}
return d.statements.selectMembershipsFromRoom(ctx, roomNID)
return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly)
}
// EventsFromIDs implements query.RoomserverQueryAPIEventDB

View file

@ -38,7 +38,7 @@ const membershipSchema = `
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
local_target BOOLEAN NOT NULL DEFAULT false,
target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
`
@ -46,7 +46,7 @@ const membershipSchema = `
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid, local_target)" +
"INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING"
@ -58,10 +58,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true"
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" +
" AND target_local = true"
const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2"
@ -71,12 +81,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $4 AND target_nid = $5"
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -90,7 +102,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db)
}
@ -129,9 +143,14 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
var selectStmt *sql.Stmt
if localOnly {
selectStmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt)
} else {
selectStmt = common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
}
rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil {
return nil, err
@ -147,11 +166,17 @@ func (s *membershipStatements) selectMembershipsFromRoom(
}
return
}
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership membershipState,
roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
var stmt *sql.Stmt
if localOnly {
stmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt)
} else {
stmt = common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
}
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return

View file

@ -897,17 +897,17 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if joinOnly {
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
ctx, txn, roomNID, membershipStateJoin,
ctx, txn, roomNID, membershipStateJoin, localOnly,
)
return nil
}
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID)
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
return nil
})
return