mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
Finish implementing retiring invites (#1166)
* Pass retired invites to the syncapi with the event ID of the invite * Implement retire invite streaming * Update whitelist
This commit is contained in:
parent
c1d2382e6d
commit
4897beabee
|
@ -155,7 +155,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
||||||
// where we might think we know about a room in the following
|
// where we might think we know about a room in the following
|
||||||
// section but don't know the latest state as all of our users
|
// section but don't know the latest state as all of our users
|
||||||
// have left.
|
// have left.
|
||||||
isInvitePending, inviteSender, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID)
|
isInvitePending, inviteSender, _, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID)
|
||||||
if err == nil && isInvitePending {
|
if err == nil && isInvitePending {
|
||||||
// Check if there's an invite pending.
|
// Check if there's an invite pending.
|
||||||
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
|
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,9 +39,9 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID(
|
||||||
) error {
|
) error {
|
||||||
// If there's an invite outstanding for the room then respond to
|
// If there's an invite outstanding for the room then respond to
|
||||||
// that.
|
// that.
|
||||||
isInvitePending, senderUser, err := r.isInvitePending(ctx, req.RoomID, req.UserID)
|
isInvitePending, senderUser, eventID, err := r.isInvitePending(ctx, req.RoomID, req.UserID)
|
||||||
if err == nil && isInvitePending {
|
if err == nil && isInvitePending {
|
||||||
return r.performRejectInvite(ctx, req, res, senderUser)
|
return r.performRejectInvite(ctx, req, res, senderUser, eventID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// There's no invite pending, so first of all we want to find out
|
// There's no invite pending, so first of all we want to find out
|
||||||
|
@ -134,7 +135,7 @@ func (r *RoomserverInternalAPI) performRejectInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformLeaveRequest,
|
req *api.PerformLeaveRequest,
|
||||||
res *api.PerformLeaveResponse, // nolint:unparam
|
res *api.PerformLeaveResponse, // nolint:unparam
|
||||||
senderUser string,
|
senderUser, eventID string,
|
||||||
) error {
|
) error {
|
||||||
_, domain, err := gomatrixserverlib.SplitID('@', senderUser)
|
_, domain, err := gomatrixserverlib.SplitID('@', senderUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -152,56 +153,68 @@ func (r *RoomserverInternalAPI) performRejectInvite(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Withdraw the invite, so that the sync API etc are
|
// Withdraw the invite, so that the sync API etc are
|
||||||
// notified that we rejected it.
|
// notified that we rejected it.
|
||||||
|
return r.WriteOutputEvents(req.RoomID, []api.OutputEvent{
|
||||||
return nil
|
{
|
||||||
|
Type: api.OutputTypeRetireInviteEvent,
|
||||||
|
RetireInviteEvent: &api.OutputRetireInviteEvent{
|
||||||
|
EventID: eventID,
|
||||||
|
Membership: "leave",
|
||||||
|
TargetUserID: req.UserID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) isInvitePending(
|
func (r *RoomserverInternalAPI) isInvitePending(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomID, userID string,
|
roomID, userID string,
|
||||||
) (bool, string, error) {
|
) (bool, string, string, error) {
|
||||||
// Look up the room NID for the supplied room ID.
|
// Look up the room NID for the supplied room ID.
|
||||||
roomNID, err := r.DB.RoomNID(ctx, roomID)
|
roomNID, err := r.DB.RoomNID(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", fmt.Errorf("r.DB.RoomNID: %w", err)
|
return false, "", "", fmt.Errorf("r.DB.RoomNID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the state key NID for the supplied user ID.
|
// Look up the state key NID for the supplied user ID.
|
||||||
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID})
|
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
||||||
}
|
}
|
||||||
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
||||||
if !targetUserFound {
|
if !targetUserFound {
|
||||||
return false, "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let's see if we have an event active for the user in the room. If
|
// Let's see if we have an event active for the user in the room. If
|
||||||
// we do then it will contain a server name that we can direct the
|
// we do then it will contain a server name that we can direct the
|
||||||
// send_leave to.
|
// send_leave to.
|
||||||
senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
|
senderUserNIDs, eventIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
|
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
|
||||||
}
|
}
|
||||||
if len(senderUserNIDs) == 0 {
|
if len(senderUserNIDs) == 0 {
|
||||||
return false, "", nil
|
return false, "", "", nil
|
||||||
|
}
|
||||||
|
userNIDToEventID := make(map[types.EventStateKeyNID]string)
|
||||||
|
for i, nid := range senderUserNIDs {
|
||||||
|
userNIDToEventID[nid] = eventIDs[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the user ID from the NID.
|
// Look up the user ID from the NID.
|
||||||
senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
|
senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
|
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
|
||||||
}
|
}
|
||||||
if len(senderUsers) == 0 {
|
if len(senderUsers) == 0 {
|
||||||
return false, "", fmt.Errorf("no senderUsers")
|
return false, "", "", fmt.Errorf("no senderUsers")
|
||||||
}
|
}
|
||||||
|
|
||||||
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
|
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
|
||||||
if !senderUserFound {
|
if !senderUserFound {
|
||||||
return false, "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
|
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, senderUser, nil
|
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -102,9 +102,9 @@ type Database interface {
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
|
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
|
||||||
// Look up the active invites targeting a user in a room and return the
|
// Look up the active invites targeting a user in a room and return the
|
||||||
// numeric state key IDs for the user IDs who sent them.
|
// numeric state key IDs for the user IDs who sent them along with the event IDs for the invites.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error)
|
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error)
|
||||||
// Save a given room alias with the room ID it refers to.
|
// Save a given room alias with the room ID it refers to.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
|
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
|
||||||
|
|
|
@ -62,7 +62,7 @@ const insertInviteEventSQL = "" +
|
||||||
" ON CONFLICT DO NOTHING"
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
const selectInviteActiveForUserInRoomSQL = "" +
|
const selectInviteActiveForUserInRoomSQL = "" +
|
||||||
"SELECT sender_nid FROM roomserver_invites" +
|
"SELECT invite_event_id, sender_nid FROM roomserver_invites" +
|
||||||
" WHERE target_nid = $1 AND room_nid = $2" +
|
" WHERE target_nid = $1 AND room_nid = $2" +
|
||||||
" AND NOT retired"
|
" AND NOT retired"
|
||||||
|
|
||||||
|
@ -141,21 +141,24 @@ func (s *inviteStatements) UpdateInviteRetired(
|
||||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||||
) ([]types.EventStateKeyNID, error) {
|
) ([]types.EventStateKeyNID, []string, error) {
|
||||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||||
ctx, targetUserNID, roomNID,
|
ctx, targetUserNID, roomNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
||||||
var result []types.EventStateKeyNID
|
var result []types.EventStateKeyNID
|
||||||
|
var eventIDs []string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
var inviteEventID string
|
||||||
var senderUserNID int64
|
var senderUserNID int64
|
||||||
if err := rows.Scan(&senderUserNID); err != nil {
|
if err := rows.Scan(&inviteEventID, &senderUserNID); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
result = append(result, types.EventStateKeyNID(senderUserNID))
|
result = append(result, types.EventStateKeyNID(senderUserNID))
|
||||||
|
eventIDs = append(eventIDs, inviteEventID)
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, eventIDs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
@ -265,7 +265,7 @@ func (d *Database) GetInvitesForUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
targetUserNID types.EventStateKeyNID,
|
targetUserNID types.EventStateKeyNID,
|
||||||
) (senderUserIDs []types.EventStateKeyNID, err error) {
|
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
|
||||||
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
|
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ const insertInviteEventSQL = "" +
|
||||||
" ON CONFLICT DO NOTHING"
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
const selectInviteActiveForUserInRoomSQL = "" +
|
const selectInviteActiveForUserInRoomSQL = "" +
|
||||||
"SELECT sender_nid FROM roomserver_invites" +
|
"SELECT invite_event_id, sender_nid FROM roomserver_invites" +
|
||||||
" WHERE target_nid = $1 AND room_nid = $2" +
|
" WHERE target_nid = $1 AND room_nid = $2" +
|
||||||
" AND NOT retired"
|
" AND NOT retired"
|
||||||
|
|
||||||
|
@ -133,21 +133,24 @@ func (s *inviteStatements) UpdateInviteRetired(
|
||||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||||
) ([]types.EventStateKeyNID, error) {
|
) ([]types.EventStateKeyNID, []string, error) {
|
||||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||||
ctx, targetUserNID, roomNID,
|
ctx, targetUserNID, roomNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
||||||
var result []types.EventStateKeyNID
|
var result []types.EventStateKeyNID
|
||||||
|
var eventIDs []string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
var eventID string
|
||||||
var senderUserNID int64
|
var senderUserNID int64
|
||||||
if err := rows.Scan(&senderUserNID); err != nil {
|
if err := rows.Scan(&eventID, &senderUserNID); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
result = append(result, types.EventStateKeyNID(senderUserNID))
|
result = append(result, types.EventStateKeyNID(senderUserNID))
|
||||||
|
eventIDs = append(eventIDs, eventID)
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, eventIDs, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,8 +100,8 @@ type PreviousEvents interface {
|
||||||
type Invites interface {
|
type Invites interface {
|
||||||
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
|
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
|
||||||
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
|
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
|
||||||
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
|
||||||
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error)
|
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MembershipState int64
|
type MembershipState int64
|
||||||
|
|
|
@ -157,7 +157,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
|
||||||
func (s *OutputRoomEventConsumer) onRetireInviteEvent(
|
func (s *OutputRoomEventConsumer) onRetireInviteEvent(
|
||||||
ctx context.Context, msg api.OutputRetireInviteEvent,
|
ctx context.Context, msg api.OutputRetireInviteEvent,
|
||||||
) error {
|
) error {
|
||||||
err := s.db.RetireInviteEvent(ctx, msg.EventID)
|
sp, err := s.db.RetireInviteEvent(ctx, msg.EventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// panic rather than continue with an inconsistent database
|
// panic rather than continue with an inconsistent database
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
|
@ -166,8 +166,9 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
|
||||||
}).Panicf("roomserver output log: remove invite failure")
|
}).Panicf("roomserver output log: remove invite failure")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// TODO: Notify any active sync requests that the invite has been retired.
|
// Notify any active sync requests that the invite has been retired.
|
||||||
// s.notifier.OnNewEvent(nil, msg.TargetUserID, syncStreamPos)
|
// Invites share the same stream counter as PDUs
|
||||||
|
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -78,9 +78,9 @@ type Database interface {
|
||||||
// If the invite was successfully stored this returns the stream ID it was stored at.
|
// If the invite was successfully stored this returns the stream ID it was stored at.
|
||||||
// Returns an error if there was a problem communicating with the database.
|
// Returns an error if there was a problem communicating with the database.
|
||||||
AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error)
|
AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error)
|
||||||
// RetireInviteEvent removes an old invite event from the database.
|
// RetireInviteEvent removes an old invite event from the database. Returns the new position of the retired invite.
|
||||||
// Returns an error if there was a problem communicating with the database.
|
// Returns an error if there was a problem communicating with the database.
|
||||||
RetireInviteEvent(ctx context.Context, inviteEventID string) error
|
RetireInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error)
|
||||||
// SetTypingTimeoutCallback sets a callback function that is called right after
|
// SetTypingTimeoutCallback sets a callback function that is called right after
|
||||||
// a user is removed from the typing user list due to timeout.
|
// a user is removed from the typing user list due to timeout.
|
||||||
SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn)
|
SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn)
|
||||||
|
|
|
@ -33,7 +33,8 @@ CREATE TABLE IF NOT EXISTS syncapi_invite_events (
|
||||||
event_id TEXT NOT NULL,
|
event_id TEXT NOT NULL,
|
||||||
room_id TEXT NOT NULL,
|
room_id TEXT NOT NULL,
|
||||||
target_user_id TEXT NOT NULL,
|
target_user_id TEXT NOT NULL,
|
||||||
headered_event_json TEXT NOT NULL
|
headered_event_json TEXT NOT NULL,
|
||||||
|
deleted BOOL NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
-- For looking up the invites for a given user.
|
-- For looking up the invites for a given user.
|
||||||
|
@ -47,14 +48,14 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx
|
||||||
|
|
||||||
const insertInviteEventSQL = "" +
|
const insertInviteEventSQL = "" +
|
||||||
"INSERT INTO syncapi_invite_events (" +
|
"INSERT INTO syncapi_invite_events (" +
|
||||||
" room_id, event_id, target_user_id, headered_event_json" +
|
" room_id, event_id, target_user_id, headered_event_json, deleted" +
|
||||||
") VALUES ($1, $2, $3, $4) RETURNING id"
|
") VALUES ($1, $2, $3, $4, FALSE) RETURNING id"
|
||||||
|
|
||||||
const deleteInviteEventSQL = "" +
|
const deleteInviteEventSQL = "" +
|
||||||
"DELETE FROM syncapi_invite_events WHERE event_id = $1"
|
"UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 RETURNING id"
|
||||||
|
|
||||||
const selectInviteEventsInRangeSQL = "" +
|
const selectInviteEventsInRangeSQL = "" +
|
||||||
"SELECT room_id, headered_event_json FROM syncapi_invite_events" +
|
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
|
||||||
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
||||||
" ORDER BY id DESC"
|
" ORDER BY id DESC"
|
||||||
|
|
||||||
|
@ -110,40 +111,46 @@ func (s *inviteEventsStatements) InsertInviteEvent(
|
||||||
|
|
||||||
func (s *inviteEventsStatements) DeleteInviteEvent(
|
func (s *inviteEventsStatements) DeleteInviteEvent(
|
||||||
ctx context.Context, inviteEventID string,
|
ctx context.Context, inviteEventID string,
|
||||||
) error {
|
) (sp types.StreamPosition, err error) {
|
||||||
_, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID)
|
err = s.deleteInviteEventStmt.QueryRowContext(ctx, inviteEventID).Scan(&sp)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectInviteEventsInRange returns a map of room ID to invite event for the
|
// selectInviteEventsInRange returns a map of room ID to invite event for the
|
||||||
// active invites for the target user ID in the supplied range.
|
// active invites for the target user ID in the supplied range.
|
||||||
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
||||||
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
||||||
) (map[string]gomatrixserverlib.HeaderedEvent, error) {
|
) (map[string]gomatrixserverlib.HeaderedEvent, map[string]gomatrixserverlib.HeaderedEvent, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
||||||
result := map[string]gomatrixserverlib.HeaderedEvent{}
|
result := map[string]gomatrixserverlib.HeaderedEvent{}
|
||||||
|
retired := map[string]gomatrixserverlib.HeaderedEvent{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
roomID string
|
roomID string
|
||||||
eventJSON []byte
|
eventJSON []byte
|
||||||
|
deleted bool
|
||||||
)
|
)
|
||||||
if err = rows.Scan(&roomID, &eventJSON); err != nil {
|
if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var event gomatrixserverlib.HeaderedEvent
|
var event gomatrixserverlib.HeaderedEvent
|
||||||
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
result[roomID] = event
|
if deleted {
|
||||||
|
retired[roomID] = event
|
||||||
|
} else {
|
||||||
|
result[roomID] = event
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, retired, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteEventsStatements) SelectMaxInviteID(
|
func (s *inviteEventsStatements) SelectMaxInviteID(
|
||||||
|
|
|
@ -180,11 +180,8 @@ func (d *Database) AddInviteEvent(
|
||||||
// Returns an error if there was a problem communicating with the database.
|
// Returns an error if there was a problem communicating with the database.
|
||||||
func (d *Database) RetireInviteEvent(
|
func (d *Database) RetireInviteEvent(
|
||||||
ctx context.Context, inviteEventID string,
|
ctx context.Context, inviteEventID string,
|
||||||
) error {
|
) (types.StreamPosition, error) {
|
||||||
// TODO: Record that invite has been retired in a stream so that we can
|
return d.Invites.DeleteInviteEvent(ctx, inviteEventID)
|
||||||
// notify the user in an incremental sync.
|
|
||||||
err := d.Invites.DeleteInviteEvent(ctx, inviteEventID)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountDataInRange returns all account data for a given user inserted or
|
// GetAccountDataInRange returns all account data for a given user inserted or
|
||||||
|
@ -724,7 +721,7 @@ func (d *Database) addInvitesToResponse(
|
||||||
r types.Range,
|
r types.Range,
|
||||||
res *types.Response,
|
res *types.Response,
|
||||||
) error {
|
) error {
|
||||||
invites, err := d.Invites.SelectInviteEventsInRange(
|
invites, retiredInvites, err := d.Invites.SelectInviteEventsInRange(
|
||||||
ctx, txn, userID, r,
|
ctx, txn, userID, r,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -734,6 +731,10 @@ func (d *Database) addInvitesToResponse(
|
||||||
ir := types.NewInviteResponse(inviteEvent)
|
ir := types.NewInviteResponse(inviteEvent)
|
||||||
res.Rooms.Invite[roomID] = *ir
|
res.Rooms.Invite[roomID] = *ir
|
||||||
}
|
}
|
||||||
|
for roomID := range retiredInvites {
|
||||||
|
lr := types.NewLeaveResponse()
|
||||||
|
res.Rooms.Leave[roomID] = *lr
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,8 @@ CREATE TABLE IF NOT EXISTS syncapi_invite_events (
|
||||||
event_id TEXT NOT NULL,
|
event_id TEXT NOT NULL,
|
||||||
room_id TEXT NOT NULL,
|
room_id TEXT NOT NULL,
|
||||||
target_user_id TEXT NOT NULL,
|
target_user_id TEXT NOT NULL,
|
||||||
headered_event_json TEXT NOT NULL
|
headered_event_json TEXT NOT NULL,
|
||||||
|
deleted BOOL NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id);
|
CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id);
|
||||||
|
@ -42,14 +43,14 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events
|
||||||
|
|
||||||
const insertInviteEventSQL = "" +
|
const insertInviteEventSQL = "" +
|
||||||
"INSERT INTO syncapi_invite_events" +
|
"INSERT INTO syncapi_invite_events" +
|
||||||
" (id, room_id, event_id, target_user_id, headered_event_json)" +
|
" (id, room_id, event_id, target_user_id, headered_event_json, deleted)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5)"
|
" VALUES ($1, $2, $3, $4, $5, false)"
|
||||||
|
|
||||||
const deleteInviteEventSQL = "" +
|
const deleteInviteEventSQL = "" +
|
||||||
"DELETE FROM syncapi_invite_events WHERE event_id = $1"
|
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2"
|
||||||
|
|
||||||
const selectInviteEventsInRangeSQL = "" +
|
const selectInviteEventsInRangeSQL = "" +
|
||||||
"SELECT room_id, headered_event_json FROM syncapi_invite_events" +
|
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
|
||||||
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
|
||||||
" ORDER BY id DESC"
|
" ORDER BY id DESC"
|
||||||
|
|
||||||
|
@ -114,40 +115,49 @@ func (s *inviteEventsStatements) InsertInviteEvent(
|
||||||
|
|
||||||
func (s *inviteEventsStatements) DeleteInviteEvent(
|
func (s *inviteEventsStatements) DeleteInviteEvent(
|
||||||
ctx context.Context, inviteEventID string,
|
ctx context.Context, inviteEventID string,
|
||||||
) error {
|
) (types.StreamPosition, error) {
|
||||||
_, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID)
|
streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
|
||||||
return err
|
if err != nil {
|
||||||
|
return streamPos, err
|
||||||
|
}
|
||||||
|
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
|
||||||
|
return streamPos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectInviteEventsInRange returns a map of room ID to invite event for the
|
// selectInviteEventsInRange returns a map of room ID to invite event for the
|
||||||
// active invites for the target user ID in the supplied range.
|
// active invites for the target user ID in the supplied range.
|
||||||
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
func (s *inviteEventsStatements) SelectInviteEventsInRange(
|
||||||
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
|
||||||
) (map[string]gomatrixserverlib.HeaderedEvent, error) {
|
) (map[string]gomatrixserverlib.HeaderedEvent, map[string]gomatrixserverlib.HeaderedEvent, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
|
||||||
result := map[string]gomatrixserverlib.HeaderedEvent{}
|
result := map[string]gomatrixserverlib.HeaderedEvent{}
|
||||||
|
retired := map[string]gomatrixserverlib.HeaderedEvent{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
roomID string
|
roomID string
|
||||||
eventJSON []byte
|
eventJSON []byte
|
||||||
|
deleted bool
|
||||||
)
|
)
|
||||||
if err = rows.Scan(&roomID, &eventJSON); err != nil {
|
if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var event gomatrixserverlib.HeaderedEvent
|
var event gomatrixserverlib.HeaderedEvent
|
||||||
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
if err := json.Unmarshal(eventJSON, &event); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if deleted {
|
||||||
|
retired[roomID] = event
|
||||||
|
} else {
|
||||||
|
result[roomID] = event
|
||||||
}
|
}
|
||||||
|
|
||||||
result[roomID] = event
|
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, retired, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteEventsStatements) SelectMaxInviteID(
|
func (s *inviteEventsStatements) SelectMaxInviteID(
|
||||||
|
|
|
@ -601,6 +601,83 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInviteBehaviour(t *testing.T) {
|
||||||
|
db := MustCreateDatabase(t)
|
||||||
|
inviteRoom1 := "!inviteRoom1:somewhere"
|
||||||
|
inviteEvent1 := MustCreateEvent(t, inviteRoom1, nil, &gomatrixserverlib.EventBuilder{
|
||||||
|
Content: []byte(fmt.Sprintf(`{"membership":"invite"}`)),
|
||||||
|
Type: "m.room.member",
|
||||||
|
StateKey: &testUserIDA,
|
||||||
|
Sender: "@inviteUser1:somewhere",
|
||||||
|
})
|
||||||
|
inviteRoom2 := "!inviteRoom2:somewhere"
|
||||||
|
inviteEvent2 := MustCreateEvent(t, inviteRoom2, nil, &gomatrixserverlib.EventBuilder{
|
||||||
|
Content: []byte(fmt.Sprintf(`{"membership":"invite"}`)),
|
||||||
|
Type: "m.room.member",
|
||||||
|
StateKey: &testUserIDA,
|
||||||
|
Sender: "@inviteUser2:somewhere",
|
||||||
|
})
|
||||||
|
for _, ev := range []gomatrixserverlib.HeaderedEvent{inviteEvent1, inviteEvent2} {
|
||||||
|
_, err := db.AddInviteEvent(ctx, ev)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to AddInviteEvent: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
latest, err := db.SyncPosition(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||||
|
}
|
||||||
|
// both invite events should appear in a new sync
|
||||||
|
beforeRetireRes := types.NewResponse()
|
||||||
|
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IncrementalSync failed: %s", err)
|
||||||
|
}
|
||||||
|
assertInvitedToRooms(t, beforeRetireRes, []string{inviteRoom1, inviteRoom2})
|
||||||
|
|
||||||
|
// retire one event: a fresh sync should just return 1 invite room
|
||||||
|
if _, err = db.RetireInviteEvent(ctx, inviteEvent1.EventID()); err != nil {
|
||||||
|
t.Fatalf("Failed to RetireInviteEvent: %s", err)
|
||||||
|
}
|
||||||
|
latest, err = db.SyncPosition(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||||
|
}
|
||||||
|
res := types.NewResponse()
|
||||||
|
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IncrementalSync failed: %s", err)
|
||||||
|
}
|
||||||
|
assertInvitedToRooms(t, res, []string{inviteRoom2})
|
||||||
|
|
||||||
|
// a sync after we have received both invites should result in a leave for the retired room
|
||||||
|
beforeRetireTok, err := types.NewStreamTokenFromString(beforeRetireRes.NextBatch)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamTokenFromString cannot parse next batch '%s' : %s", beforeRetireRes.NextBatch, err)
|
||||||
|
}
|
||||||
|
res = types.NewResponse()
|
||||||
|
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireTok, latest, 0, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IncrementalSync failed: %s", err)
|
||||||
|
}
|
||||||
|
assertInvitedToRooms(t, res, []string{})
|
||||||
|
if _, ok := res.Rooms.Leave[inviteRoom1]; !ok {
|
||||||
|
t.Fatalf("IncrementalSync: expected to see room left after it was retired but it wasn't")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertInvitedToRooms(t *testing.T, res *types.Response, roomIDs []string) {
|
||||||
|
t.Helper()
|
||||||
|
if len(res.Rooms.Invite) != len(roomIDs) {
|
||||||
|
t.Fatalf("got %d invited rooms, want %d", len(res.Rooms.Invite), len(roomIDs))
|
||||||
|
}
|
||||||
|
for _, roomID := range roomIDs {
|
||||||
|
if _, ok := res.Rooms.Invite[roomID]; !ok {
|
||||||
|
t.Fatalf("missing room ID %s", roomID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {
|
func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {
|
||||||
if len(gots) != len(wants) {
|
if len(gots) != len(wants) {
|
||||||
t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants))
|
t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants))
|
||||||
|
|
|
@ -32,9 +32,9 @@ type AccountData interface {
|
||||||
|
|
||||||
type Invites interface {
|
type Invites interface {
|
||||||
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error)
|
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error)
|
||||||
DeleteInviteEvent(ctx context.Context, inviteEventID string) error
|
DeleteInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error)
|
||||||
// SelectInviteEventsInRange returns a map of room ID to invite events.
|
// SelectInviteEventsInRange returns a map of room ID to invite events.
|
||||||
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (map[string]gomatrixserverlib.HeaderedEvent, error)
|
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]gomatrixserverlib.HeaderedEvent, retired map[string]gomatrixserverlib.HeaderedEvent, err error)
|
||||||
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -290,10 +290,10 @@ type Response struct {
|
||||||
NextBatch string `json:"next_batch"`
|
NextBatch string `json:"next_batch"`
|
||||||
AccountData struct {
|
AccountData struct {
|
||||||
Events []gomatrixserverlib.ClientEvent `json:"events"`
|
Events []gomatrixserverlib.ClientEvent `json:"events"`
|
||||||
} `json:"account_data"`
|
} `json:"account_data,omitempty"`
|
||||||
Presence struct {
|
Presence struct {
|
||||||
Events []gomatrixserverlib.ClientEvent `json:"events"`
|
Events []gomatrixserverlib.ClientEvent `json:"events"`
|
||||||
} `json:"presence"`
|
} `json:"presence,omitempty"`
|
||||||
Rooms struct {
|
Rooms struct {
|
||||||
Join map[string]JoinResponse `json:"join"`
|
Join map[string]JoinResponse `json:"join"`
|
||||||
Invite map[string]InviteResponse `json:"invite"`
|
Invite map[string]InviteResponse `json:"invite"`
|
||||||
|
|
|
@ -378,3 +378,11 @@ Outbound federation correctly handles unsupported room versions
|
||||||
Remote users may not join unfederated rooms
|
Remote users may not join unfederated rooms
|
||||||
Guest users denied access over federation if guest access prohibited
|
Guest users denied access over federation if guest access prohibited
|
||||||
Non-numeric ports in server names are rejected
|
Non-numeric ports in server names are rejected
|
||||||
|
Invited user can reject invite over federation
|
||||||
|
Invited user can reject invite over federation for empty room
|
||||||
|
Can reject invites over federation for rooms with version 1
|
||||||
|
Can reject invites over federation for rooms with version 2
|
||||||
|
Can reject invites over federation for rooms with version 3
|
||||||
|
Can reject invites over federation for rooms with version 4
|
||||||
|
Can reject invites over federation for rooms with version 5
|
||||||
|
Can reject invites over federation for rooms with version 6
|
||||||
|
|
Loading…
Reference in a new issue