Finish implementing retiring invites (#1166)

* Pass retired invites to the syncapi with the event ID of the invite

* Implement retire invite streaming

* Update whitelist
This commit is contained in:
Kegsay 2020-06-26 11:07:52 +01:00 committed by GitHub
parent c1d2382e6d
commit 4897beabee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 204 additions and 81 deletions

View file

@ -155,7 +155,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
// where we might think we know about a room in the following // where we might think we know about a room in the following
// section but don't know the latest state as all of our users // section but don't know the latest state as all of our users
// have left. // have left.
isInvitePending, inviteSender, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID) isInvitePending, inviteSender, _, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID)
if err == nil && isInvitePending { if err == nil && isInvitePending {
// Check if there's an invite pending. // Check if there's an invite pending.
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)

View file

@ -9,6 +9,7 @@ import (
fsAPI "github.com/matrix-org/dendrite/federationsender/api" fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -38,9 +39,9 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID(
) error { ) error {
// If there's an invite outstanding for the room then respond to // If there's an invite outstanding for the room then respond to
// that. // that.
isInvitePending, senderUser, err := r.isInvitePending(ctx, req.RoomID, req.UserID) isInvitePending, senderUser, eventID, err := r.isInvitePending(ctx, req.RoomID, req.UserID)
if err == nil && isInvitePending { if err == nil && isInvitePending {
return r.performRejectInvite(ctx, req, res, senderUser) return r.performRejectInvite(ctx, req, res, senderUser, eventID)
} }
// There's no invite pending, so first of all we want to find out // There's no invite pending, so first of all we want to find out
@ -134,7 +135,7 @@ func (r *RoomserverInternalAPI) performRejectInvite(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
senderUser string, senderUser, eventID string,
) error { ) error {
_, domain, err := gomatrixserverlib.SplitID('@', senderUser) _, domain, err := gomatrixserverlib.SplitID('@', senderUser)
if err != nil { if err != nil {
@ -152,56 +153,68 @@ func (r *RoomserverInternalAPI) performRejectInvite(
return err return err
} }
// TODO: Withdraw the invite, so that the sync API etc are // Withdraw the invite, so that the sync API etc are
// notified that we rejected it. // notified that we rejected it.
return r.WriteOutputEvents(req.RoomID, []api.OutputEvent{
return nil {
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
Membership: "leave",
TargetUserID: req.UserID,
},
},
})
} }
func (r *RoomserverInternalAPI) isInvitePending( func (r *RoomserverInternalAPI) isInvitePending(
ctx context.Context, ctx context.Context,
roomID, userID string, roomID, userID string,
) (bool, string, error) { ) (bool, string, string, error) {
// Look up the room NID for the supplied room ID. // Look up the room NID for the supplied room ID.
roomNID, err := r.DB.RoomNID(ctx, roomID) roomNID, err := r.DB.RoomNID(ctx, roomID)
if err != nil { if err != nil {
return false, "", fmt.Errorf("r.DB.RoomNID: %w", err) return false, "", "", fmt.Errorf("r.DB.RoomNID: %w", err)
} }
// Look up the state key NID for the supplied user ID. // Look up the state key NID for the supplied user ID.
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID}) targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID})
if err != nil { if err != nil {
return false, "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
} }
targetUserNID, targetUserFound := targetUserNIDs[userID] targetUserNID, targetUserFound := targetUserNIDs[userID]
if !targetUserFound { if !targetUserFound {
return false, "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
} }
// Let's see if we have an event active for the user in the room. If // Let's see if we have an event active for the user in the room. If
// we do then it will contain a server name that we can direct the // we do then it will contain a server name that we can direct the
// send_leave to. // send_leave to.
senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID) senderUserNIDs, eventIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
if err != nil { if err != nil {
return false, "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err) return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
} }
if len(senderUserNIDs) == 0 { if len(senderUserNIDs) == 0 {
return false, "", nil return false, "", "", nil
}
userNIDToEventID := make(map[types.EventStateKeyNID]string)
for i, nid := range senderUserNIDs {
userNIDToEventID[nid] = eventIDs[i]
} }
// Look up the user ID from the NID. // Look up the user ID from the NID.
senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs) senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
if err != nil { if err != nil {
return false, "", fmt.Errorf("r.DB.EventStateKeys: %w", err) return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
} }
if len(senderUsers) == 0 { if len(senderUsers) == 0 {
return false, "", fmt.Errorf("no senderUsers") return false, "", "", fmt.Errorf("no senderUsers")
} }
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]] senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
if !senderUserFound { if !senderUserFound {
return false, "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
} }
return true, senderUser, nil return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
} }

View file

@ -102,9 +102,9 @@ type Database interface {
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
// Look up the active invites targeting a user in a room and return the // Look up the active invites targeting a user in a room and return the
// numeric state key IDs for the user IDs who sent them. // numeric state key IDs for the user IDs who sent them along with the event IDs for the invites.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error)
// Save a given room alias with the room ID it refers to. // Save a given room alias with the room ID it refers to.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error

View file

@ -62,7 +62,7 @@ const insertInviteEventSQL = "" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectInviteActiveForUserInRoomSQL = "" + const selectInviteActiveForUserInRoomSQL = "" +
"SELECT sender_nid FROM roomserver_invites" + "SELECT invite_event_id, sender_nid FROM roomserver_invites" +
" WHERE target_nid = $1 AND room_nid = $2" + " WHERE target_nid = $1 AND room_nid = $2" +
" AND NOT retired" " AND NOT retired"
@ -141,21 +141,24 @@ func (s *inviteStatements) UpdateInviteRetired(
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
var result []types.EventStateKeyNID var result []types.EventStateKeyNID
var eventIDs []string
for rows.Next() { for rows.Next() {
var inviteEventID string
var senderUserNID int64 var senderUserNID int64
if err := rows.Scan(&senderUserNID); err != nil { if err := rows.Scan(&inviteEventID, &senderUserNID); err != nil {
return nil, err return nil, nil, err
} }
result = append(result, types.EventStateKeyNID(senderUserNID)) result = append(result, types.EventStateKeyNID(senderUserNID))
eventIDs = append(eventIDs, inviteEventID)
} }
return result, rows.Err() return result, eventIDs, rows.Err()
} }

View file

@ -265,7 +265,7 @@ func (d *Database) GetInvitesForUser(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, err error) { ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
} }

View file

@ -45,7 +45,7 @@ const insertInviteEventSQL = "" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectInviteActiveForUserInRoomSQL = "" + const selectInviteActiveForUserInRoomSQL = "" +
"SELECT sender_nid FROM roomserver_invites" + "SELECT invite_event_id, sender_nid FROM roomserver_invites" +
" WHERE target_nid = $1 AND room_nid = $2" + " WHERE target_nid = $1 AND room_nid = $2" +
" AND NOT retired" " AND NOT retired"
@ -133,21 +133,24 @@ func (s *inviteStatements) UpdateInviteRetired(
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
var result []types.EventStateKeyNID var result []types.EventStateKeyNID
var eventIDs []string
for rows.Next() { for rows.Next() {
var eventID string
var senderUserNID int64 var senderUserNID int64
if err := rows.Scan(&senderUserNID); err != nil { if err := rows.Scan(&eventID, &senderUserNID); err != nil {
return nil, err return nil, nil, err
} }
result = append(result, types.EventStateKeyNID(senderUserNID)) result = append(result, types.EventStateKeyNID(senderUserNID))
eventIDs = append(eventIDs, eventID)
} }
return result, nil return result, eventIDs, nil
} }

View file

@ -100,8 +100,8 @@ type PreviousEvents interface {
type Invites interface { type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error) SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
} }
type MembershipState int64 type MembershipState int64

View file

@ -157,7 +157,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
func (s *OutputRoomEventConsumer) onRetireInviteEvent( func (s *OutputRoomEventConsumer) onRetireInviteEvent(
ctx context.Context, msg api.OutputRetireInviteEvent, ctx context.Context, msg api.OutputRetireInviteEvent,
) error { ) error {
err := s.db.RetireInviteEvent(ctx, msg.EventID) sp, err := s.db.RetireInviteEvent(ctx, msg.EventID)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -166,8 +166,9 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
}).Panicf("roomserver output log: remove invite failure") }).Panicf("roomserver output log: remove invite failure")
return nil return nil
} }
// TODO: Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
// s.notifier.OnNewEvent(nil, msg.TargetUserID, syncStreamPos) // Invites share the same stream counter as PDUs
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0))
return nil return nil
} }

View file

@ -78,9 +78,9 @@ type Database interface {
// If the invite was successfully stored this returns the stream ID it was stored at. // If the invite was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error) AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error)
// RetireInviteEvent removes an old invite event from the database. // RetireInviteEvent removes an old invite event from the database. Returns the new position of the retired invite.
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
RetireInviteEvent(ctx context.Context, inviteEventID string) error RetireInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error)
// SetTypingTimeoutCallback sets a callback function that is called right after // SetTypingTimeoutCallback sets a callback function that is called right after
// a user is removed from the typing user list due to timeout. // a user is removed from the typing user list due to timeout.
SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn)

View file

@ -33,7 +33,8 @@ CREATE TABLE IF NOT EXISTS syncapi_invite_events (
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
target_user_id TEXT NOT NULL, target_user_id TEXT NOT NULL,
headered_event_json TEXT NOT NULL headered_event_json TEXT NOT NULL,
deleted BOOL NOT NULL
); );
-- For looking up the invites for a given user. -- For looking up the invites for a given user.
@ -47,14 +48,14 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx
const insertInviteEventSQL = "" + const insertInviteEventSQL = "" +
"INSERT INTO syncapi_invite_events (" + "INSERT INTO syncapi_invite_events (" +
" room_id, event_id, target_user_id, headered_event_json" + " room_id, event_id, target_user_id, headered_event_json, deleted" +
") VALUES ($1, $2, $3, $4) RETURNING id" ") VALUES ($1, $2, $3, $4, FALSE) RETURNING id"
const deleteInviteEventSQL = "" + const deleteInviteEventSQL = "" +
"DELETE FROM syncapi_invite_events WHERE event_id = $1" "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 RETURNING id"
const selectInviteEventsInRangeSQL = "" + const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json FROM syncapi_invite_events" + "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" + " WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC" " ORDER BY id DESC"
@ -110,40 +111,46 @@ func (s *inviteEventsStatements) InsertInviteEvent(
func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) error { ) (sp types.StreamPosition, err error) {
_, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID) err = s.deleteInviteEventStmt.QueryRowContext(ctx, inviteEventID).Scan(&sp)
return err return
} }
// selectInviteEventsInRange returns a map of room ID to invite event for the // selectInviteEventsInRange returns a map of room ID to invite event for the
// active invites for the target user ID in the supplied range. // active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) SelectInviteEventsInRange( func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
) (map[string]gomatrixserverlib.HeaderedEvent, error) { ) (map[string]gomatrixserverlib.HeaderedEvent, map[string]gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
result := map[string]gomatrixserverlib.HeaderedEvent{} result := map[string]gomatrixserverlib.HeaderedEvent{}
retired := map[string]gomatrixserverlib.HeaderedEvent{}
for rows.Next() { for rows.Next() {
var ( var (
roomID string roomID string
eventJSON []byte eventJSON []byte
deleted bool
) )
if err = rows.Scan(&roomID, &eventJSON); err != nil { if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
return nil, err return nil, nil, err
} }
var event gomatrixserverlib.HeaderedEvent var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventJSON, &event); err != nil { if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, err return nil, nil, err
} }
if deleted {
retired[roomID] = event
} else {
result[roomID] = event result[roomID] = event
} }
return result, rows.Err() }
return result, retired, rows.Err()
} }
func (s *inviteEventsStatements) SelectMaxInviteID( func (s *inviteEventsStatements) SelectMaxInviteID(

View file

@ -180,11 +180,8 @@ func (d *Database) AddInviteEvent(
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *Database) RetireInviteEvent( func (d *Database) RetireInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) error { ) (types.StreamPosition, error) {
// TODO: Record that invite has been retired in a stream so that we can return d.Invites.DeleteInviteEvent(ctx, inviteEventID)
// notify the user in an incremental sync.
err := d.Invites.DeleteInviteEvent(ctx, inviteEventID)
return err
} }
// GetAccountDataInRange returns all account data for a given user inserted or // GetAccountDataInRange returns all account data for a given user inserted or
@ -724,7 +721,7 @@ func (d *Database) addInvitesToResponse(
r types.Range, r types.Range,
res *types.Response, res *types.Response,
) error { ) error {
invites, err := d.Invites.SelectInviteEventsInRange( invites, retiredInvites, err := d.Invites.SelectInviteEventsInRange(
ctx, txn, userID, r, ctx, txn, userID, r,
) )
if err != nil { if err != nil {
@ -734,6 +731,10 @@ func (d *Database) addInvitesToResponse(
ir := types.NewInviteResponse(inviteEvent) ir := types.NewInviteResponse(inviteEvent)
res.Rooms.Invite[roomID] = *ir res.Rooms.Invite[roomID] = *ir
} }
for roomID := range retiredInvites {
lr := types.NewLeaveResponse()
res.Rooms.Leave[roomID] = *lr
}
return nil return nil
} }

View file

@ -33,7 +33,8 @@ CREATE TABLE IF NOT EXISTS syncapi_invite_events (
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
target_user_id TEXT NOT NULL, target_user_id TEXT NOT NULL,
headered_event_json TEXT NOT NULL headered_event_json TEXT NOT NULL,
deleted BOOL NOT NULL
); );
CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id); CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id);
@ -42,14 +43,14 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events
const insertInviteEventSQL = "" + const insertInviteEventSQL = "" +
"INSERT INTO syncapi_invite_events" + "INSERT INTO syncapi_invite_events" +
" (id, room_id, event_id, target_user_id, headered_event_json)" + " (id, room_id, event_id, target_user_id, headered_event_json, deleted)" +
" VALUES ($1, $2, $3, $4, $5)" " VALUES ($1, $2, $3, $4, $5, false)"
const deleteInviteEventSQL = "" + const deleteInviteEventSQL = "" +
"DELETE FROM syncapi_invite_events WHERE event_id = $1" "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2"
const selectInviteEventsInRangeSQL = "" + const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json FROM syncapi_invite_events" + "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" + " WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC" " ORDER BY id DESC"
@ -114,40 +115,49 @@ func (s *inviteEventsStatements) InsertInviteEvent(
func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) error { ) (types.StreamPosition, error) {
_, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID) streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
return err if err != nil {
return streamPos, err
}
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
return streamPos, err
} }
// selectInviteEventsInRange returns a map of room ID to invite event for the // selectInviteEventsInRange returns a map of room ID to invite event for the
// active invites for the target user ID in the supplied range. // active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) SelectInviteEventsInRange( func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
) (map[string]gomatrixserverlib.HeaderedEvent, error) { ) (map[string]gomatrixserverlib.HeaderedEvent, map[string]gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
result := map[string]gomatrixserverlib.HeaderedEvent{} result := map[string]gomatrixserverlib.HeaderedEvent{}
retired := map[string]gomatrixserverlib.HeaderedEvent{}
for rows.Next() { for rows.Next() {
var ( var (
roomID string roomID string
eventJSON []byte eventJSON []byte
deleted bool
) )
if err = rows.Scan(&roomID, &eventJSON); err != nil { if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
return nil, err return nil, nil, err
} }
var event gomatrixserverlib.HeaderedEvent var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventJSON, &event); err != nil { if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, err return nil, nil, err
} }
if deleted {
retired[roomID] = event
} else {
result[roomID] = event result[roomID] = event
} }
return result, nil }
return result, retired, nil
} }
func (s *inviteEventsStatements) SelectMaxInviteID( func (s *inviteEventsStatements) SelectMaxInviteID(

View file

@ -601,6 +601,83 @@ func TestSendToDeviceBehaviour(t *testing.T) {
} }
} }
func TestInviteBehaviour(t *testing.T) {
db := MustCreateDatabase(t)
inviteRoom1 := "!inviteRoom1:somewhere"
inviteEvent1 := MustCreateEvent(t, inviteRoom1, nil, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"membership":"invite"}`)),
Type: "m.room.member",
StateKey: &testUserIDA,
Sender: "@inviteUser1:somewhere",
})
inviteRoom2 := "!inviteRoom2:somewhere"
inviteEvent2 := MustCreateEvent(t, inviteRoom2, nil, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"membership":"invite"}`)),
Type: "m.room.member",
StateKey: &testUserIDA,
Sender: "@inviteUser2:somewhere",
})
for _, ev := range []gomatrixserverlib.HeaderedEvent{inviteEvent1, inviteEvent2} {
_, err := db.AddInviteEvent(ctx, ev)
if err != nil {
t.Fatalf("Failed to AddInviteEvent: %s", err)
}
}
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
// both invite events should appear in a new sync
beforeRetireRes := types.NewResponse()
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false)
if err != nil {
t.Fatalf("IncrementalSync failed: %s", err)
}
assertInvitedToRooms(t, beforeRetireRes, []string{inviteRoom1, inviteRoom2})
// retire one event: a fresh sync should just return 1 invite room
if _, err = db.RetireInviteEvent(ctx, inviteEvent1.EventID()); err != nil {
t.Fatalf("Failed to RetireInviteEvent: %s", err)
}
latest, err = db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false)
if err != nil {
t.Fatalf("IncrementalSync failed: %s", err)
}
assertInvitedToRooms(t, res, []string{inviteRoom2})
// a sync after we have received both invites should result in a leave for the retired room
beforeRetireTok, err := types.NewStreamTokenFromString(beforeRetireRes.NextBatch)
if err != nil {
t.Fatalf("NewStreamTokenFromString cannot parse next batch '%s' : %s", beforeRetireRes.NextBatch, err)
}
res = types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireTok, latest, 0, false)
if err != nil {
t.Fatalf("IncrementalSync failed: %s", err)
}
assertInvitedToRooms(t, res, []string{})
if _, ok := res.Rooms.Leave[inviteRoom1]; !ok {
t.Fatalf("IncrementalSync: expected to see room left after it was retired but it wasn't")
}
}
func assertInvitedToRooms(t *testing.T, res *types.Response, roomIDs []string) {
t.Helper()
if len(res.Rooms.Invite) != len(roomIDs) {
t.Fatalf("got %d invited rooms, want %d", len(res.Rooms.Invite), len(roomIDs))
}
for _, roomID := range roomIDs {
if _, ok := res.Rooms.Invite[roomID]; !ok {
t.Fatalf("missing room ID %s", roomID)
}
}
}
func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {
if len(gots) != len(wants) { if len(gots) != len(wants) {
t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants))

View file

@ -32,9 +32,9 @@ type AccountData interface {
type Invites interface { type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error)
DeleteInviteEvent(ctx context.Context, inviteEventID string) error DeleteInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error)
// SelectInviteEventsInRange returns a map of room ID to invite events. // SelectInviteEventsInRange returns a map of room ID to invite events.
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (map[string]gomatrixserverlib.HeaderedEvent, error) SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]gomatrixserverlib.HeaderedEvent, retired map[string]gomatrixserverlib.HeaderedEvent, err error)
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }

View file

@ -290,10 +290,10 @@ type Response struct {
NextBatch string `json:"next_batch"` NextBatch string `json:"next_batch"`
AccountData struct { AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data"` } `json:"account_data,omitempty"`
Presence struct { Presence struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"presence"` } `json:"presence,omitempty"`
Rooms struct { Rooms struct {
Join map[string]JoinResponse `json:"join"` Join map[string]JoinResponse `json:"join"`
Invite map[string]InviteResponse `json:"invite"` Invite map[string]InviteResponse `json:"invite"`

View file

@ -378,3 +378,11 @@ Outbound federation correctly handles unsupported room versions
Remote users may not join unfederated rooms Remote users may not join unfederated rooms
Guest users denied access over federation if guest access prohibited Guest users denied access over federation if guest access prohibited
Non-numeric ports in server names are rejected Non-numeric ports in server names are rejected
Invited user can reject invite over federation
Invited user can reject invite over federation for empty room
Can reject invites over federation for rooms with version 1
Can reject invites over federation for rooms with version 2
Can reject invites over federation for rooms with version 3
Can reject invites over federation for rooms with version 4
Can reject invites over federation for rooms with version 5
Can reject invites over federation for rooms with version 6