Fix issue with DMs shown as normal rooms (#2776)

Fixes #2121, test added in
https://github.com/matrix-org/complement/pull/494
This commit is contained in:
Till 2022-10-07 16:00:12 +02:00 committed by GitHub
parent 8e231130e9
commit 1ca3f3efb5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 82 additions and 47 deletions

View file

@ -159,6 +159,7 @@ type PerformJoinRequest struct {
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication. // The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
ServerNames types.ServerNames `json:"server_names"` ServerNames types.ServerNames `json:"server_names"`
Content map[string]interface{} `json:"content"` Content map[string]interface{} `json:"content"`
Unsigned map[string]interface{} `json:"unsigned"`
} }
type PerformJoinResponse struct { type PerformJoinResponse struct {

View file

@ -7,14 +7,15 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
) )
// PerformLeaveRequest implements api.FederationInternalAPI // PerformLeaveRequest implements api.FederationInternalAPI
@ -95,6 +96,7 @@ func (r *FederationInternalAPI) PerformJoin(
request.Content, request.Content,
serverName, serverName,
supportedVersions, supportedVersions,
request.Unsigned,
); err != nil { ); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{ logrus.WithError(err).WithFields(logrus.Fields{
"server_name": serverName, "server_name": serverName,
@ -139,6 +141,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
content map[string]interface{}, content map[string]interface{},
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
supportedVersions []gomatrixserverlib.RoomVersion, supportedVersions []gomatrixserverlib.RoomVersion,
unsigned map[string]interface{},
) error { ) error {
// Try to perform a make_join using the information supplied in the // Try to perform a make_join using the information supplied in the
// request. // request.
@ -267,6 +270,14 @@ func (r *FederationInternalAPI) performJoinUsingServer(
// If we successfully performed a send_join above then the other // If we successfully performed a send_join above then the other
// server now thinks we're a part of the room. Send the newly // server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view. // returned state to the roomserver to update our local view.
if unsigned != nil {
event, err = event.SetUnsigned(unsigned)
if err != nil {
// non-fatal, log and continue
logrus.WithError(err).Errorf("Failed to set unsigned content")
}
}
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
context.Background(), context.Background(),
r.rsAPI, r.rsAPI,

View file

@ -80,6 +80,7 @@ type PerformJoinRequest struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
Content map[string]interface{} `json:"content"` Content map[string]interface{} `json:"content"`
ServerNames []gomatrixserverlib.ServerName `json:"server_names"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
Unsigned map[string]interface{} `json:"unsigned"`
} }
type PerformJoinResponse struct { type PerformJoinResponse struct {

View file

@ -7,6 +7,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
@ -14,8 +17,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
// TODO: temporary package which has helper functions used by both internal/perform packages. // TODO: temporary package which has helper functions used by both internal/perform packages.
@ -97,35 +98,35 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
func IsInvitePending( func IsInvitePending(
ctx context.Context, db storage.Database, ctx context.Context, db storage.Database,
roomID, userID string, roomID, userID string,
) (bool, string, string, error) { ) (bool, string, string, *gomatrixserverlib.Event, error) {
// Look up the room NID for the supplied room ID. // Look up the room NID for the supplied room ID.
info, err := db.RoomInfo(ctx, roomID) info, err := db.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err) return false, "", "", nil, fmt.Errorf("r.DB.RoomInfo: %w", err)
} }
if info == nil { if info == nil {
return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID) return false, "", "", nil, fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
} }
// Look up the state key NID for the supplied user ID. // Look up the state key NID for the supplied user ID.
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
if err != nil { if err != nil {
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
} }
targetUserNID, targetUserFound := targetUserNIDs[userID] targetUserNID, targetUserFound := targetUserNIDs[userID]
if !targetUserFound { if !targetUserFound {
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
} }
// Let's see if we have an event active for the user in the room. If // Let's see if we have an event active for the user in the room. If
// we do then it will contain a server name that we can direct the // we do then it will contain a server name that we can direct the
// send_leave to. // send_leave to.
senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID) senderUserNIDs, eventIDs, eventJSON, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
if err != nil { if err != nil {
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err) return false, "", "", nil, fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
} }
if len(senderUserNIDs) == 0 { if len(senderUserNIDs) == 0 {
return false, "", "", nil return false, "", "", nil, nil
} }
userNIDToEventID := make(map[types.EventStateKeyNID]string) userNIDToEventID := make(map[types.EventStateKeyNID]string)
for i, nid := range senderUserNIDs { for i, nid := range senderUserNIDs {
@ -135,18 +136,20 @@ func IsInvitePending(
// Look up the user ID from the NID. // Look up the user ID from the NID.
senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs) senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
if err != nil { if err != nil {
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err) return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeys: %w", err)
} }
if len(senderUsers) == 0 { if len(senderUsers) == 0 {
return false, "", "", fmt.Errorf("no senderUsers") return false, "", "", nil, fmt.Errorf("no senderUsers")
} }
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]] senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
if !senderUserFound { if !senderUserFound {
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) return false, "", "", nil, fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
} }
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false, info.RoomVersion)
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err
} }
// GetMembershipsAtState filters the state events to // GetMembershipsAtState filters the state events to

View file

@ -22,6 +22,10 @@ import (
"time" "time"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
fsAPI "github.com/matrix-org/dendrite/federationapi/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -32,8 +36,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
) )
type Joiner struct { type Joiner struct {
@ -236,7 +238,7 @@ func (r *Joiner) performJoinRoomByID(
// Force a federated join if we're dealing with a pending invite // Force a federated join if we're dealing with a pending invite
// and we aren't in the room. // and we aren't in the room.
isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID)
if err == nil && !serverInRoom && isInvitePending { if err == nil && !serverInRoom && isInvitePending {
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
if ierr != nil { if ierr != nil {
@ -248,6 +250,17 @@ func (r *Joiner) performJoinRoomByID(
if inviterDomain != r.Cfg.Matrix.ServerName { if inviterDomain != r.Cfg.Matrix.ServerName {
req.ServerNames = append(req.ServerNames, inviterDomain) req.ServerNames = append(req.ServerNames, inviterDomain)
forceFederatedJoin = true forceFederatedJoin = true
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
// only set unsigned if we've got a content.membership, which we _should_
if memberEvent.Get("content.membership").Exists() {
req.Unsigned = map[string]interface{}{
"prev_sender": memberEvent.Get("sender").Str,
"prev_content": map[string]interface{}{
"is_direct": memberEvent.Get("content.is_direct").Bool(),
"membership": memberEvent.Get("content.membership").Str,
},
}
}
} }
} }
@ -348,6 +361,7 @@ func (r *Joiner) performFederatedJoinRoomByID(
UserID: req.UserID, // the user ID joining the room UserID: req.UserID, // the user ID joining the room
ServerNames: req.ServerNames, // the server to try joining with ServerNames: req.ServerNames, // the server to try joining with
Content: req.Content, // the membership event content Content: req.Content, // the membership event content
Unsigned: req.Unsigned, // the unsigned event content, if any
} }
fedRes := fsAPI.PerformJoinResponse{} fedRes := fsAPI.PerformJoinResponse{}
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)

View file

@ -79,7 +79,7 @@ func (r *Leaver) performLeaveRoomByID(
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// If there's an invite outstanding for the room then respond to // If there's an invite outstanding for the room then respond to
// that. // that.
isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
if err == nil && isInvitePending { if err == nil && isInvitePending {
_, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser) _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser)
if serr != nil { if serr != nil {

View file

@ -872,7 +872,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// but we don't specify an authorised via user, since the event auth // but we don't specify an authorised via user, since the event auth
// will allow the join anyway. // will allow the join anyway.
var pending bool var pending bool
if pending, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil { if pending, _, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil {
return fmt.Errorf("helpers.IsInvitePending: %w", err) return fmt.Errorf("helpers.IsInvitePending: %w", err)
} else if pending { } else if pending {
res.Allowed = true res.Allowed = true

View file

@ -17,10 +17,11 @@ package storage
import ( import (
"context" "context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
) )
type Database interface { type Database interface {
@ -104,7 +105,7 @@ type Database interface {
// Look up the active invites targeting a user in a room and return the // Look up the active invites targeting a user in a room and return the
// numeric state key IDs for the user IDs who sent them along with the event IDs for the invites. // numeric state key IDs for the user IDs who sent them along with the event IDs for the invites.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, inviteEventJSON []byte, err error)
// Save a given room alias with the room ID it refers to. // Save a given room alias with the room ID it refers to.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error

View file

@ -61,7 +61,7 @@ const insertInviteEventSQL = "" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectInviteActiveForUserInRoomSQL = "" + const selectInviteActiveForUserInRoomSQL = "" +
"SELECT invite_event_id, sender_nid FROM roomserver_invites" + "SELECT invite_event_id, sender_nid, invite_event_json FROM roomserver_invites" +
" WHERE target_nid = $1 AND room_nid = $2" + " WHERE target_nid = $1 AND room_nid = $2" +
" AND NOT retired" " AND NOT retired"
@ -141,25 +141,26 @@ func (s *inviteStatements) UpdateInviteRetired(
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, []byte, error) {
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {
return nil, nil, err return nil, nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
var result []types.EventStateKeyNID var result []types.EventStateKeyNID
var eventIDs []string var eventIDs []string
var inviteEventID string var inviteEventID string
var senderUserNID int64 var senderUserNID int64
var eventJSON []byte
for rows.Next() { for rows.Next() {
if err := rows.Scan(&inviteEventID, &senderUserNID); err != nil { if err := rows.Scan(&inviteEventID, &senderUserNID, &eventJSON); err != nil {
return nil, nil, err return nil, nil, nil, err
} }
result = append(result, types.EventStateKeyNID(senderUserNID)) result = append(result, types.EventStateKeyNID(senderUserNID))
eventIDs = append(eventIDs, inviteEventID) eventIDs = append(eventIDs, inviteEventID)
} }
return result, eventIDs, rows.Err() return result, eventIDs, eventJSON, rows.Err()
} }

View file

@ -7,13 +7,14 @@ import (
"fmt" "fmt"
"sort" "sort"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
) )
// Ideally, when we have both events we should redact the event JSON and forget about the redaction, but we currently // Ideally, when we have both events we should redact the event JSON and forget about the redaction, but we currently
@ -445,7 +446,7 @@ func (d *Database) GetInvitesForUser(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) { ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, inviteEventJSON []byte, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
} }

View file

@ -44,7 +44,7 @@ const insertInviteEventSQL = "" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectInviteActiveForUserInRoomSQL = "" + const selectInviteActiveForUserInRoomSQL = "" +
"SELECT invite_event_id, sender_nid FROM roomserver_invites" + "SELECT invite_event_id, sender_nid, invite_event_json FROM roomserver_invites" +
" WHERE target_nid = $1 AND room_nid = $2" + " WHERE target_nid = $1 AND room_nid = $2" +
" AND NOT retired" " AND NOT retired"
@ -136,25 +136,26 @@ func (s *inviteStatements) UpdateInviteRetired(
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, []byte, error) {
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {
return nil, nil, err return nil, nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
var result []types.EventStateKeyNID var result []types.EventStateKeyNID
var eventIDs []string var eventIDs []string
var eventID string var eventID string
var senderUserNID int64 var senderUserNID int64
var eventJSON []byte
for rows.Next() { for rows.Next() {
if err := rows.Scan(&eventID, &senderUserNID); err != nil { if err := rows.Scan(&eventID, &senderUserNID, &eventJSON); err != nil {
return nil, nil, err return nil, nil, nil, err
} }
result = append(result, types.EventStateKeyNID(senderUserNID)) result = append(result, types.EventStateKeyNID(senderUserNID))
eventIDs = append(eventIDs, eventID) eventIDs = append(eventIDs, eventID)
} }
return result, eventIDs, nil return result, eventIDs, eventJSON, nil
} }

View file

@ -116,7 +116,7 @@ type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids. // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, []byte, error)
} }
type MembershipState int64 type MembershipState int64

View file

@ -4,6 +4,9 @@ import (
"context" "context"
"testing" "testing"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
@ -11,8 +14,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
) )
func mustCreateInviteTable(t *testing.T, dbType test.DBType) (tables.Invites, func()) { func mustCreateInviteTable(t *testing.T, dbType test.DBType) (tables.Invites, func()) {
@ -67,7 +68,7 @@ func TestInviteTable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, newInvite) assert.True(t, newInvite)
stateKeyNIDs, eventIDs, err := tab.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) stateKeyNIDs, eventIDs, _, err := tab.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []string{eventID1, eventID2}, eventIDs) assert.Equal(t, []string{eventID1, eventID2}, eventIDs)
assert.Equal(t, []types.EventStateKeyNID{2, 2}, stateKeyNIDs) assert.Equal(t, []types.EventStateKeyNID{2, 2}, stateKeyNIDs)
@ -78,13 +79,13 @@ func TestInviteTable(t *testing.T) {
assert.Equal(t, []string{eventID1, eventID2}, retiredEventIDs) assert.Equal(t, []string{eventID1, eventID2}, retiredEventIDs)
// This should now be empty // This should now be empty
stateKeyNIDs, eventIDs, err = tab.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) stateKeyNIDs, eventIDs, _, err = tab.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, eventIDs) assert.Empty(t, eventIDs)
assert.Empty(t, stateKeyNIDs) assert.Empty(t, stateKeyNIDs)
// Non-existent targetUserNID // Non-existent targetUserNID
stateKeyNIDs, eventIDs, err = tab.SelectInviteActiveForUserInRoom(ctx, nil, types.EventStateKeyNID(10), roomNID) stateKeyNIDs, eventIDs, _, err = tab.SelectInviteActiveForUserInRoom(ctx, nil, types.EventStateKeyNID(10), roomNID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, stateKeyNIDs) assert.Empty(t, stateKeyNIDs)
assert.Empty(t, eventIDs) assert.Empty(t, eventIDs)