ToClientEvent now directly uses the provided userID for sender field

This commit is contained in:
Devon Hudson 2023-06-06 14:11:21 -06:00
parent c2aac0f19e
commit 34edfff85c
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
13 changed files with 112 additions and 77 deletions

View file

@ -140,11 +140,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
// use the result of the previous QueryLatestEventsAndState response
// to find the state event, if provided.
for _, ev := range stateRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvents = append(
stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
)
}
} else {
@ -164,11 +167,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
}
}
for _, ev := range stateAfterRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvents = append(
stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
)
}
}
@ -338,10 +344,13 @@ func OnIncomingStateTypeRequest(
}
}
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
}
var res interface{}

View file

@ -388,9 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
})
sender := spec.UserID{}
userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if queryErr == nil && userID != nil {
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
@ -439,9 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
})
sender := spec.UserID{}
userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}

View file

@ -217,9 +217,12 @@ func Context(
}
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
if err == nil && userID != nil {
sender = *userID
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
response := ContextRespsonse{
Event: &ev,
EventsAfter: eventsAfterClient,

View file

@ -101,10 +101,13 @@ func GetEvent(
}
}
sender := spec.UserID{}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID())
if err == nil && senderUserID != nil {
sender = *senderUserID
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
}
}

View file

@ -114,11 +114,14 @@ func Relations(
// type if it was specified.
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
res.Chunk = append(
res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
)
}

View file

@ -205,17 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos := make(map[string]ProfileInfoResponse)
for _, ev := range append(eventsBefore, eventsAfter...) {
userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
if err != nil {
logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
if queryErr != nil {
logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
continue
}
profile, ok := knownUsersProfiles[userID.String()]
if !ok {
stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID())
if err != nil {
logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID())
if stateErr != nil {
logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
continue
}
if stateEvent == nil {
@ -230,6 +230,11 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos[userID.String()] = profile
}
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
@ -242,10 +247,8 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
}),
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
})
roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID())

View file

@ -64,19 +64,17 @@ func (p *InviteStreamProvider) IncrementalSync(
}
for roomID, inviteEvent := range invites {
user := ""
user := spec.UserID{}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID())
if err == nil {
user = sender.String()
if err == nil && sender != nil {
user = *sender
}
// skip ignored user events
if _, ok := req.IgnoredUsers.List[user]; ok {
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue
}
ir := types.NewInviteResponse(inviteEvent, func(roomID, senderID string) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
ir := types.NewInviteResponse(inviteEvent, user)
req.Response.Rooms.Invite[roomID] = ir
}

View file

@ -50,21 +50,21 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if se == nil {
continue // TODO: shouldn't happen?
}
evs = append(evs, ToClientEvent(se, format, userIDForSender))
sender := spec.UserID{}
userID, err := userIDForSender(se.RoomID(), se.SenderID())
if err == nil && userID != nil {
sender = *userID
}
evs = append(evs, ToClientEvent(se, format, sender))
}
return evs
}
// ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) ClientEvent {
user := ""
userID, err := userIDForSender(se.RoomID(), se.SenderID())
if err == nil {
user = userID.String()
}
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
ce := ClientEvent{
Content: spec.RawJSON(se.Content()),
Sender: user,
Sender: sender.String(),
Type: se.Type(),
StateKey: se.StateKey(),
Unsigned: spec.RawJSON(se.Unsigned()),

View file

@ -24,10 +24,6 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec"
)
func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestToClientEvent(t *testing.T) { // nolint: gocyclo
ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{
"type": "m.room.name",
@ -48,7 +44,11 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil {
t.Fatalf("failed to create Event: %s", err)
}
ce := ToClientEvent(ev, FormatAll, UserIDForSender)
userID, err := spec.NewUserID("@test:localhost", true)
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatAll, *userID)
if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
}
@ -67,13 +67,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if !bytes.Equal(ce.Unsigned, ev.Unsigned()) {
t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned))
}
user := ""
userID, err := UserIDForSender("", ev.SenderID())
if err == nil {
user = userID.String()
}
if ce.Sender != user {
t.Errorf("ClientEvent.Sender: wanted %s, got %s", user, ce.Sender)
if ce.Sender != userID.String() {
t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender)
}
j, err := json.Marshal(ce)
if err != nil {
@ -108,7 +103,11 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil {
t.Fatalf("failed to create Event: %s", err)
}
ce := ToClientEvent(ev, FormatSync, UserIDForSender)
userID, err := spec.NewUserID("@test:localhost", true)
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatSync, *userID)
if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
}

View file

@ -539,7 +539,7 @@ type InviteResponse struct {
}
// NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDForSender) *InviteResponse {
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{}
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDFo
// Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userIDForSender)
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev)

View file

@ -61,7 +61,12 @@ func TestNewInviteResponse(t *testing.T) {
t.Fatal(err)
}
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, UserIDForSender)
sender, err := spec.NewUserID("@neilalexander:matrix.org", true)
if err != nil {
t.Fatal(err)
}
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
j, err := json.Marshal(res)
if err != nil {
t.Fatal(err)

View file

@ -301,9 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch {
case event.Type() == spec.MRoomMember:
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
sender := spec.UserID{}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if queryErr == nil && userID != nil {
sender = *userID
}
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
var member *localMembership
member, err = newLocalMembership(&cevent)
if err != nil {
@ -531,14 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
return fmt.Errorf("s.localPushDevices: %w", err)
}
sender := spec.UserID{}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
n := &api.Notification{
Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID string, senderID string) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
// TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it

View file

@ -23,10 +23,6 @@ import (
userUtil "github.com/matrix-org/dendrite/userapi/util"
)
func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestNotifyUserCountsAsync(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID)
@ -92,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
}
// Prepare pusher with our test server URL
if err := db.UpsertPusher(ctx, api.Pusher{
if err = db.UpsertPusher(ctx, api.Pusher{
Kind: api.HTTPKind,
AppID: appID,
PushKey: pushKey,
@ -104,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) {
}
// Insert a dummy event
sender, err := spec.NewUserID(alice.ID, true)
if err != nil {
t.Error(err)
}
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, UserIDForSender),
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
}); err != nil {
t.Error(err)
}