sqlite: Fix invite table by using the global stream pos rather than one specific to invites

If we don't use the global, clients don't get notified about any invites
because the position is too low.
This commit is contained in:
Kegan Dougal 2020-02-18 18:02:53 +00:00
parent bc37106e14
commit 677dc175d0
3 changed files with 30 additions and 34 deletions

View file

@ -52,16 +52,18 @@ const selectInviteActiveForUserInRoomSQL = "" +
// However the matrix protocol doesn't give us a way to reliably identify the
// invites that were retired, so we are forced to retire all of them.
const updateInviteRetiredSQL = `
UPDATE roomserver_invites SET retired = TRUE
WHERE room_nid = $1 AND target_nid = $2 AND NOT retired;
SELECT invite_event_id FROM roomserver_invites
WHERE rowid = last_insert_rowid();
UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
`
const selectInvitesAboutToRetireSQL = `
SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
`
type inviteStatements struct {
insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt
updateInviteRetiredStmt *sql.Stmt
selectInvitesAboutToRetireStmt *sql.Stmt
}
func (s *inviteStatements) prepare(db *sql.DB) (err error) {
@ -74,6 +76,7 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) {
{&s.insertInviteEventStmt, insertInviteEventSQL},
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
{&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL},
}.prepare(db)
}
@ -102,7 +105,8 @@ func (s *inviteStatements) updateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
// gather all the event IDs we will retire
stmt := txn.Stmt(s.selectInvitesAboutToRetireStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil {
return nil, err
@ -115,6 +119,10 @@ func (s *inviteStatements) updateInviteRetired(
}
eventIDs = append(eventIDs, inviteEventID)
}
// now retire the invites
stmt = txn.Stmt(s.updateInviteRetiredStmt)
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
return
}

View file

@ -26,7 +26,7 @@ import (
const inviteEventsSchema = `
CREATE TABLE IF NOT EXISTS syncapi_invite_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id INTEGER PRIMARY KEY,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
@ -39,8 +39,8 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events
const insertInviteEventSQL = "" +
"INSERT INTO syncapi_invite_events" +
" (room_id, event_id, target_user_id, event_json)" +
" VALUES ($1, $2, $3, $4)"
" (id, room_id, event_id, target_user_id, event_json)" +
" VALUES ($1, $2, $3, $4, $5)"
const deleteInviteEventSQL = "" +
"DELETE FROM syncapi_invite_events WHERE event_id = $1"
@ -83,25 +83,16 @@ func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatement
}
func (s *inviteEventsStatements) insertInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (streamPos types.StreamPosition, err error) {
var res sql.Result
res, err = s.insertInviteEventStmt.ExecContext(
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.Event, streamPos types.StreamPosition,
) (err error) {
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
ctx,
streamPos,
inviteEvent.RoomID(),
inviteEvent.EventID(),
*inviteEvent.StateKey(),
inviteEvent.JSON(),
)
if err != nil {
return
}
var rowID int64
rowID, err = res.LastInsertId()
if err != nil {
return
}
streamPos = types.StreamPosition(rowID)
return
}

View file

@ -193,24 +193,20 @@ func (d *SyncServerDatasource) WriteEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
)
if err != nil {
fmt.Println("d.events.insertEvent:", err)
return err
}
pduPosition = pos
if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil {
fmt.Println("d.topology.insertEventInTopology:", err)
return err
}
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
fmt.Println("d.handleBackwardExtremities:", err)
return err
}
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event.
fmt.Println("nothing to do")
return nil
}
@ -625,18 +621,15 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
if err != nil {
return
}
fmt.Println("Joined rooms:", joinedRoomIDs)
stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
// Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs {
fmt.Println("WE'RE ON", roomID)
var stateEvents []gomatrixserverlib.Event
stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart)
if err != nil {
fmt.Println("d.roomstate.selectCurrentState:", err)
return
}
//fmt.Println("State events:", stateEvents)
@ -648,7 +641,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom, true, true,
)
if err != nil {
fmt.Println("d.events.selectRecentEvents:", err)
return
}
//fmt.Println("Recent stream events:", recentStreamEvents)
@ -658,10 +650,9 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
var backwardTopologyPos types.StreamPosition
backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
if err != nil {
fmt.Println("d.topology.selectPositionInTopology:", err)
return nil, types.PaginationToken{}, []string{}, err
}
fmt.Println("Backward topology position:", backwardTopologyPos)
if backwardTopologyPos-1 <= 0 {
backwardTopologyPos = types.StreamPosition(1)
} else {
@ -683,7 +674,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
}
if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil {
fmt.Println("d.addInvitesToResponse:", err)
return
}
@ -764,8 +754,15 @@ func (d *SyncServerDatasource) UpsertAccountData(
// Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (types.StreamPosition, error) {
return d.invites.insertInviteEvent(ctx, inviteEvent)
) (streamPos types.StreamPosition, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
streamPos, err = d.streamID.nextStreamID(ctx, txn)
if err != nil {
return err
}
return d.invites.insertInviteEvent(ctx, txn, inviteEvent, streamPos)
})
return
}
// RetireInviteEvent removes an old invite event from the database.