sqlite: Fix invite table by using the global stream pos rather than one specific to invites

If we don't use the global, clients don't get notified about any invites
because the position is too low.
This commit is contained in:
Kegan Dougal 2020-02-18 18:02:53 +00:00
parent bc37106e14
commit 677dc175d0
3 changed files with 30 additions and 34 deletions

View file

@ -52,16 +52,18 @@ const selectInviteActiveForUserInRoomSQL = "" +
// However the matrix protocol doesn't give us a way to reliably identify the // However the matrix protocol doesn't give us a way to reliably identify the
// invites that were retired, so we are forced to retire all of them. // invites that were retired, so we are forced to retire all of them.
const updateInviteRetiredSQL = ` const updateInviteRetiredSQL = `
UPDATE roomserver_invites SET retired = TRUE UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
WHERE room_nid = $1 AND target_nid = $2 AND NOT retired; `
SELECT invite_event_id FROM roomserver_invites
WHERE rowid = last_insert_rowid(); const selectInvitesAboutToRetireSQL = `
SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired
` `
type inviteStatements struct { type inviteStatements struct {
insertInviteEventStmt *sql.Stmt insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt
updateInviteRetiredStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt
selectInvitesAboutToRetireStmt *sql.Stmt
} }
func (s *inviteStatements) prepare(db *sql.DB) (err error) { func (s *inviteStatements) prepare(db *sql.DB) (err error) {
@ -74,6 +76,7 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) {
{&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.insertInviteEventStmt, insertInviteEventSQL},
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
{&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL},
}.prepare(db) }.prepare(db)
} }
@ -102,7 +105,8 @@ func (s *inviteStatements) updateInviteRetired(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) // gather all the event IDs we will retire
stmt := txn.Stmt(s.selectInvitesAboutToRetireStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -115,6 +119,10 @@ func (s *inviteStatements) updateInviteRetired(
} }
eventIDs = append(eventIDs, inviteEventID) eventIDs = append(eventIDs, inviteEventID)
} }
// now retire the invites
stmt = txn.Stmt(s.updateInviteRetiredStmt)
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
return return
} }

View file

@ -26,7 +26,7 @@ import (
const inviteEventsSchema = ` const inviteEventsSchema = `
CREATE TABLE IF NOT EXISTS syncapi_invite_events ( CREATE TABLE IF NOT EXISTS syncapi_invite_events (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY,
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
target_user_id TEXT NOT NULL, target_user_id TEXT NOT NULL,
@ -39,8 +39,8 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events
const insertInviteEventSQL = "" + const insertInviteEventSQL = "" +
"INSERT INTO syncapi_invite_events" + "INSERT INTO syncapi_invite_events" +
" (room_id, event_id, target_user_id, event_json)" + " (id, room_id, event_id, target_user_id, event_json)" +
" VALUES ($1, $2, $3, $4)" " VALUES ($1, $2, $3, $4, $5)"
const deleteInviteEventSQL = "" + const deleteInviteEventSQL = "" +
"DELETE FROM syncapi_invite_events WHERE event_id = $1" "DELETE FROM syncapi_invite_events WHERE event_id = $1"
@ -83,25 +83,16 @@ func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatement
} }
func (s *inviteEventsStatements) insertInviteEvent( func (s *inviteEventsStatements) insertInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.Event, streamPos types.StreamPosition,
) (streamPos types.StreamPosition, err error) { ) (err error) {
var res sql.Result _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
res, err = s.insertInviteEventStmt.ExecContext(
ctx, ctx,
streamPos,
inviteEvent.RoomID(), inviteEvent.RoomID(),
inviteEvent.EventID(), inviteEvent.EventID(),
*inviteEvent.StateKey(), *inviteEvent.StateKey(),
inviteEvent.JSON(), inviteEvent.JSON(),
) )
if err != nil {
return
}
var rowID int64
rowID, err = res.LastInsertId()
if err != nil {
return
}
streamPos = types.StreamPosition(rowID)
return return
} }

View file

@ -193,24 +193,20 @@ func (d *SyncServerDatasource) WriteEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
) )
if err != nil { if err != nil {
fmt.Println("d.events.insertEvent:", err)
return err return err
} }
pduPosition = pos pduPosition = pos
if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil { if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil {
fmt.Println("d.topology.insertEventInTopology:", err)
return err return err
} }
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
fmt.Println("d.handleBackwardExtremities:", err)
return err return err
} }
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event. // Nothing to do, the event may have just been a message event.
fmt.Println("nothing to do")
return nil return nil
} }
@ -625,18 +621,15 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
if err != nil { if err != nil {
return return
} }
fmt.Println("Joined rooms:", joinedRoomIDs)
stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
// Build up a /sync response. Add joined rooms. // Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
fmt.Println("WE'RE ON", roomID)
var stateEvents []gomatrixserverlib.Event var stateEvents []gomatrixserverlib.Event
stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart) stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart)
if err != nil { if err != nil {
fmt.Println("d.roomstate.selectCurrentState:", err)
return return
} }
//fmt.Println("State events:", stateEvents) //fmt.Println("State events:", stateEvents)
@ -648,7 +641,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom, true, true, numRecentEventsPerRoom, true, true,
) )
if err != nil { if err != nil {
fmt.Println("d.events.selectRecentEvents:", err)
return return
} }
//fmt.Println("Recent stream events:", recentStreamEvents) //fmt.Println("Recent stream events:", recentStreamEvents)
@ -658,10 +650,9 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
var backwardTopologyPos types.StreamPosition var backwardTopologyPos types.StreamPosition
backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
if err != nil { if err != nil {
fmt.Println("d.topology.selectPositionInTopology:", err)
return nil, types.PaginationToken{}, []string{}, err return nil, types.PaginationToken{}, []string{}, err
} }
fmt.Println("Backward topology position:", backwardTopologyPos)
if backwardTopologyPos-1 <= 0 { if backwardTopologyPos-1 <= 0 {
backwardTopologyPos = types.StreamPosition(1) backwardTopologyPos = types.StreamPosition(1)
} else { } else {
@ -683,7 +674,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
} }
if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil {
fmt.Println("d.addInvitesToResponse:", err)
return return
} }
@ -764,8 +754,15 @@ func (d *SyncServerDatasource) UpsertAccountData(
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatasource) AddInviteEvent( func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (types.StreamPosition, error) { ) (streamPos types.StreamPosition, err error) {
return d.invites.insertInviteEvent(ctx, inviteEvent) err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
streamPos, err = d.streamID.nextStreamID(ctx, txn)
if err != nil {
return err
}
return d.invites.insertInviteEvent(ctx, txn, inviteEvent, streamPos)
})
return
} }
// RetireInviteEvent removes an old invite event from the database. // RetireInviteEvent removes an old invite event from the database.