mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-12-01 18:51:56 -06:00
Fix issue with DMs shown as normal rooms (#2776)
Fixes #2121, test added in https://github.com/matrix-org/complement/pull/494
This commit is contained in:
parent
8e231130e9
commit
1ca3f3efb5
|
@ -159,6 +159,7 @@ type PerformJoinRequest struct {
|
||||||
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
|
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
|
||||||
ServerNames types.ServerNames `json:"server_names"`
|
ServerNames types.ServerNames `json:"server_names"`
|
||||||
Content map[string]interface{} `json:"content"`
|
Content map[string]interface{} `json:"content"`
|
||||||
|
Unsigned map[string]interface{} `json:"unsigned"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformJoinResponse struct {
|
type PerformJoinResponse struct {
|
||||||
|
|
|
@ -7,14 +7,15 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/api"
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/consumers"
|
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/version"
|
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/consumers"
|
||||||
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PerformLeaveRequest implements api.FederationInternalAPI
|
// PerformLeaveRequest implements api.FederationInternalAPI
|
||||||
|
@ -95,6 +96,7 @@ func (r *FederationInternalAPI) PerformJoin(
|
||||||
request.Content,
|
request.Content,
|
||||||
serverName,
|
serverName,
|
||||||
supportedVersions,
|
supportedVersions,
|
||||||
|
request.Unsigned,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
logrus.WithError(err).WithFields(logrus.Fields{
|
logrus.WithError(err).WithFields(logrus.Fields{
|
||||||
"server_name": serverName,
|
"server_name": serverName,
|
||||||
|
@ -139,6 +141,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
content map[string]interface{},
|
content map[string]interface{},
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
supportedVersions []gomatrixserverlib.RoomVersion,
|
supportedVersions []gomatrixserverlib.RoomVersion,
|
||||||
|
unsigned map[string]interface{},
|
||||||
) error {
|
) error {
|
||||||
// Try to perform a make_join using the information supplied in the
|
// Try to perform a make_join using the information supplied in the
|
||||||
// request.
|
// request.
|
||||||
|
@ -267,6 +270,14 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
// If we successfully performed a send_join above then the other
|
// If we successfully performed a send_join above then the other
|
||||||
// server now thinks we're a part of the room. Send the newly
|
// server now thinks we're a part of the room. Send the newly
|
||||||
// returned state to the roomserver to update our local view.
|
// returned state to the roomserver to update our local view.
|
||||||
|
if unsigned != nil {
|
||||||
|
event, err = event.SetUnsigned(unsigned)
|
||||||
|
if err != nil {
|
||||||
|
// non-fatal, log and continue
|
||||||
|
logrus.WithError(err).Errorf("Failed to set unsigned content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err = roomserverAPI.SendEventWithState(
|
if err = roomserverAPI.SendEventWithState(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
r.rsAPI,
|
r.rsAPI,
|
||||||
|
|
|
@ -80,6 +80,7 @@ type PerformJoinRequest struct {
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
Content map[string]interface{} `json:"content"`
|
Content map[string]interface{} `json:"content"`
|
||||||
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
||||||
|
Unsigned map[string]interface{} `json:"unsigned"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformJoinResponse struct {
|
type PerformJoinResponse struct {
|
||||||
|
|
|
@ -7,6 +7,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/auth"
|
"github.com/matrix-org/dendrite/roomserver/auth"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
|
@ -14,8 +17,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/matrix-org/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: temporary package which has helper functions used by both internal/perform packages.
|
// TODO: temporary package which has helper functions used by both internal/perform packages.
|
||||||
|
@ -97,35 +98,35 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
|
||||||
func IsInvitePending(
|
func IsInvitePending(
|
||||||
ctx context.Context, db storage.Database,
|
ctx context.Context, db storage.Database,
|
||||||
roomID, userID string,
|
roomID, userID string,
|
||||||
) (bool, string, string, error) {
|
) (bool, string, string, *gomatrixserverlib.Event, error) {
|
||||||
// Look up the room NID for the supplied room ID.
|
// Look up the room NID for the supplied room ID.
|
||||||
info, err := db.RoomInfo(ctx, roomID)
|
info, err := db.RoomInfo(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
|
return false, "", "", nil, fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||||
}
|
}
|
||||||
if info == nil {
|
if info == nil {
|
||||||
return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
|
return false, "", "", nil, fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the state key NID for the supplied user ID.
|
// Look up the state key NID for the supplied user ID.
|
||||||
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
|
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
||||||
}
|
}
|
||||||
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
||||||
if !targetUserFound {
|
if !targetUserFound {
|
||||||
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let's see if we have an event active for the user in the room. If
|
// Let's see if we have an event active for the user in the room. If
|
||||||
// we do then it will contain a server name that we can direct the
|
// we do then it will contain a server name that we can direct the
|
||||||
// send_leave to.
|
// send_leave to.
|
||||||
senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
|
senderUserNIDs, eventIDs, eventJSON, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
|
return false, "", "", nil, fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
|
||||||
}
|
}
|
||||||
if len(senderUserNIDs) == 0 {
|
if len(senderUserNIDs) == 0 {
|
||||||
return false, "", "", nil
|
return false, "", "", nil, nil
|
||||||
}
|
}
|
||||||
userNIDToEventID := make(map[types.EventStateKeyNID]string)
|
userNIDToEventID := make(map[types.EventStateKeyNID]string)
|
||||||
for i, nid := range senderUserNIDs {
|
for i, nid := range senderUserNIDs {
|
||||||
|
@ -135,18 +136,20 @@ func IsInvitePending(
|
||||||
// Look up the user ID from the NID.
|
// Look up the user ID from the NID.
|
||||||
senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
|
senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
|
return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeys: %w", err)
|
||||||
}
|
}
|
||||||
if len(senderUsers) == 0 {
|
if len(senderUsers) == 0 {
|
||||||
return false, "", "", fmt.Errorf("no senderUsers")
|
return false, "", "", nil, fmt.Errorf("no senderUsers")
|
||||||
}
|
}
|
||||||
|
|
||||||
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
|
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
|
||||||
if !senderUserFound {
|
if !senderUserFound {
|
||||||
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
|
return false, "", "", nil, fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
|
event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false, info.RoomVersion)
|
||||||
|
|
||||||
|
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMembershipsAtState filters the state events to
|
// GetMembershipsAtState filters the state events to
|
||||||
|
|
|
@ -22,6 +22,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -32,8 +36,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Joiner struct {
|
type Joiner struct {
|
||||||
|
@ -236,7 +238,7 @@ func (r *Joiner) performJoinRoomByID(
|
||||||
|
|
||||||
// Force a federated join if we're dealing with a pending invite
|
// Force a federated join if we're dealing with a pending invite
|
||||||
// and we aren't in the room.
|
// and we aren't in the room.
|
||||||
isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID)
|
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID)
|
||||||
if err == nil && !serverInRoom && isInvitePending {
|
if err == nil && !serverInRoom && isInvitePending {
|
||||||
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
|
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
|
||||||
if ierr != nil {
|
if ierr != nil {
|
||||||
|
@ -248,6 +250,17 @@ func (r *Joiner) performJoinRoomByID(
|
||||||
if inviterDomain != r.Cfg.Matrix.ServerName {
|
if inviterDomain != r.Cfg.Matrix.ServerName {
|
||||||
req.ServerNames = append(req.ServerNames, inviterDomain)
|
req.ServerNames = append(req.ServerNames, inviterDomain)
|
||||||
forceFederatedJoin = true
|
forceFederatedJoin = true
|
||||||
|
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
|
||||||
|
// only set unsigned if we've got a content.membership, which we _should_
|
||||||
|
if memberEvent.Get("content.membership").Exists() {
|
||||||
|
req.Unsigned = map[string]interface{}{
|
||||||
|
"prev_sender": memberEvent.Get("sender").Str,
|
||||||
|
"prev_content": map[string]interface{}{
|
||||||
|
"is_direct": memberEvent.Get("content.is_direct").Bool(),
|
||||||
|
"membership": memberEvent.Get("content.membership").Str,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,6 +361,7 @@ func (r *Joiner) performFederatedJoinRoomByID(
|
||||||
UserID: req.UserID, // the user ID joining the room
|
UserID: req.UserID, // the user ID joining the room
|
||||||
ServerNames: req.ServerNames, // the server to try joining with
|
ServerNames: req.ServerNames, // the server to try joining with
|
||||||
Content: req.Content, // the membership event content
|
Content: req.Content, // the membership event content
|
||||||
|
Unsigned: req.Unsigned, // the unsigned event content, if any
|
||||||
}
|
}
|
||||||
fedRes := fsAPI.PerformJoinResponse{}
|
fedRes := fsAPI.PerformJoinResponse{}
|
||||||
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
|
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
|
||||||
|
|
|
@ -79,7 +79,7 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
||||||
// If there's an invite outstanding for the room then respond to
|
// If there's an invite outstanding for the room then respond to
|
||||||
// that.
|
// that.
|
||||||
isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
|
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
|
||||||
if err == nil && isInvitePending {
|
if err == nil && isInvitePending {
|
||||||
_, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser)
|
_, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser)
|
||||||
if serr != nil {
|
if serr != nil {
|
||||||
|
|
|
@ -872,7 +872,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
|
||||||
// but we don't specify an authorised via user, since the event auth
|
// but we don't specify an authorised via user, since the event auth
|
||||||
// will allow the join anyway.
|
// will allow the join anyway.
|
||||||
var pending bool
|
var pending bool
|
||||||
if pending, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil {
|
if pending, _, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil {
|
||||||
return fmt.Errorf("helpers.IsInvitePending: %w", err)
|
return fmt.Errorf("helpers.IsInvitePending: %w", err)
|
||||||
} else if pending {
|
} else if pending {
|
||||||
res.Allowed = true
|
res.Allowed = true
|
||||||
|
|
|
@ -17,10 +17,11 @@ package storage
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
|
@ -104,7 +105,7 @@ type Database interface {
|
||||||
// Look up the active invites targeting a user in a room and return the
|
// Look up the active invites targeting a user in a room and return the
|
||||||
// numeric state key IDs for the user IDs who sent them along with the event IDs for the invites.
|
// numeric state key IDs for the user IDs who sent them along with the event IDs for the invites.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error)
|
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, inviteEventJSON []byte, err error)
|
||||||
// Save a given room alias with the room ID it refers to.
|
// Save a given room alias with the room ID it refers to.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
|
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
|
||||||
|
|
|
@ -61,7 +61,7 @@ const insertInviteEventSQL = "" +
|
||||||
" ON CONFLICT DO NOTHING"
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
const selectInviteActiveForUserInRoomSQL = "" +
|
const selectInviteActiveForUserInRoomSQL = "" +
|
||||||
"SELECT invite_event_id, sender_nid FROM roomserver_invites" +
|
"SELECT invite_event_id, sender_nid, invite_event_json FROM roomserver_invites" +
|
||||||
" WHERE target_nid = $1 AND room_nid = $2" +
|
" WHERE target_nid = $1 AND room_nid = $2" +
|
||||||
" AND NOT retired"
|
" AND NOT retired"
|
||||||
|
|
||||||
|
@ -141,25 +141,26 @@ func (s *inviteStatements) UpdateInviteRetired(
|
||||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||||
) ([]types.EventStateKeyNID, []string, error) {
|
) ([]types.EventStateKeyNID, []string, []byte, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
||||||
rows, err := stmt.QueryContext(
|
rows, err := stmt.QueryContext(
|
||||||
ctx, targetUserNID, roomNID,
|
ctx, targetUserNID, roomNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
||||||
var result []types.EventStateKeyNID
|
var result []types.EventStateKeyNID
|
||||||
var eventIDs []string
|
var eventIDs []string
|
||||||
var inviteEventID string
|
var inviteEventID string
|
||||||
var senderUserNID int64
|
var senderUserNID int64
|
||||||
|
var eventJSON []byte
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&inviteEventID, &senderUserNID); err != nil {
|
if err := rows.Scan(&inviteEventID, &senderUserNID, &eventJSON); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
result = append(result, types.EventStateKeyNID(senderUserNID))
|
result = append(result, types.EventStateKeyNID(senderUserNID))
|
||||||
eventIDs = append(eventIDs, inviteEventID)
|
eventIDs = append(eventIDs, inviteEventID)
|
||||||
}
|
}
|
||||||
return result, eventIDs, rows.Err()
|
return result, eventIDs, eventJSON, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,13 +7,14 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/matrix-org/util"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Ideally, when we have both events we should redact the event JSON and forget about the redaction, but we currently
|
// Ideally, when we have both events we should redact the event JSON and forget about the redaction, but we currently
|
||||||
|
@ -445,7 +446,7 @@ func (d *Database) GetInvitesForUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
targetUserNID types.EventStateKeyNID,
|
targetUserNID types.EventStateKeyNID,
|
||||||
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
|
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, inviteEventJSON []byte, err error) {
|
||||||
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ const insertInviteEventSQL = "" +
|
||||||
" ON CONFLICT DO NOTHING"
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
const selectInviteActiveForUserInRoomSQL = "" +
|
const selectInviteActiveForUserInRoomSQL = "" +
|
||||||
"SELECT invite_event_id, sender_nid FROM roomserver_invites" +
|
"SELECT invite_event_id, sender_nid, invite_event_json FROM roomserver_invites" +
|
||||||
" WHERE target_nid = $1 AND room_nid = $2" +
|
" WHERE target_nid = $1 AND room_nid = $2" +
|
||||||
" AND NOT retired"
|
" AND NOT retired"
|
||||||
|
|
||||||
|
@ -136,25 +136,26 @@ func (s *inviteStatements) UpdateInviteRetired(
|
||||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||||
) ([]types.EventStateKeyNID, []string, error) {
|
) ([]types.EventStateKeyNID, []string, []byte, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
||||||
rows, err := stmt.QueryContext(
|
rows, err := stmt.QueryContext(
|
||||||
ctx, targetUserNID, roomNID,
|
ctx, targetUserNID, roomNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteActiveForUserInRoom: rows.close() failed")
|
||||||
var result []types.EventStateKeyNID
|
var result []types.EventStateKeyNID
|
||||||
var eventIDs []string
|
var eventIDs []string
|
||||||
var eventID string
|
var eventID string
|
||||||
var senderUserNID int64
|
var senderUserNID int64
|
||||||
|
var eventJSON []byte
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&eventID, &senderUserNID); err != nil {
|
if err := rows.Scan(&eventID, &senderUserNID, &eventJSON); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
result = append(result, types.EventStateKeyNID(senderUserNID))
|
result = append(result, types.EventStateKeyNID(senderUserNID))
|
||||||
eventIDs = append(eventIDs, eventID)
|
eventIDs = append(eventIDs, eventID)
|
||||||
}
|
}
|
||||||
return result, eventIDs, nil
|
return result, eventIDs, eventJSON, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,7 +116,7 @@ type Invites interface {
|
||||||
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
|
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
|
||||||
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
|
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
|
||||||
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
|
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
|
||||||
SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
|
SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, []byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MembershipState int64
|
type MembershipState int64
|
||||||
|
|
|
@ -4,6 +4,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||||
|
@ -11,8 +14,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func mustCreateInviteTable(t *testing.T, dbType test.DBType) (tables.Invites, func()) {
|
func mustCreateInviteTable(t *testing.T, dbType test.DBType) (tables.Invites, func()) {
|
||||||
|
@ -67,7 +68,7 @@ func TestInviteTable(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, newInvite)
|
assert.True(t, newInvite)
|
||||||
|
|
||||||
stateKeyNIDs, eventIDs, err := tab.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
stateKeyNIDs, eventIDs, _, err := tab.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, []string{eventID1, eventID2}, eventIDs)
|
assert.Equal(t, []string{eventID1, eventID2}, eventIDs)
|
||||||
assert.Equal(t, []types.EventStateKeyNID{2, 2}, stateKeyNIDs)
|
assert.Equal(t, []types.EventStateKeyNID{2, 2}, stateKeyNIDs)
|
||||||
|
@ -78,13 +79,13 @@ func TestInviteTable(t *testing.T) {
|
||||||
assert.Equal(t, []string{eventID1, eventID2}, retiredEventIDs)
|
assert.Equal(t, []string{eventID1, eventID2}, retiredEventIDs)
|
||||||
|
|
||||||
// This should now be empty
|
// This should now be empty
|
||||||
stateKeyNIDs, eventIDs, err = tab.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
stateKeyNIDs, eventIDs, _, err = tab.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, eventIDs)
|
assert.Empty(t, eventIDs)
|
||||||
assert.Empty(t, stateKeyNIDs)
|
assert.Empty(t, stateKeyNIDs)
|
||||||
|
|
||||||
// Non-existent targetUserNID
|
// Non-existent targetUserNID
|
||||||
stateKeyNIDs, eventIDs, err = tab.SelectInviteActiveForUserInRoom(ctx, nil, types.EventStateKeyNID(10), roomNID)
|
stateKeyNIDs, eventIDs, _, err = tab.SelectInviteActiveForUserInRoom(ctx, nil, types.EventStateKeyNID(10), roomNID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Empty(t, stateKeyNIDs)
|
assert.Empty(t, stateKeyNIDs)
|
||||||
assert.Empty(t, eventIDs)
|
assert.Empty(t, eventIDs)
|
||||||
|
|
Loading…
Reference in a new issue