ToClientEvent now directly uses the provided userID for sender field

This commit is contained in:
Devon Hudson 2023-06-06 14:11:21 -06:00
parent c2aac0f19e
commit 34edfff85c
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
13 changed files with 112 additions and 77 deletions

View file

@ -140,11 +140,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
// use the result of the previous QueryLatestEventsAndState response // use the result of the previous QueryLatestEventsAndState response
// to find the state event, if provided. // to find the state event, if provided.
for _, ev := range stateRes.StateEvents { for _, ev := range stateRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
) )
} }
} else { } else {
@ -164,11 +167,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
} }
} }
for _, ev := range stateAfterRes.StateEvents { for _, ev := range stateAfterRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
) )
} }
} }
@ -338,10 +344,13 @@ func OnIncomingStateTypeRequest(
} }
} }
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
} }
var res interface{} var res interface{}

View file

@ -388,9 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { sender := spec.UserID{}
return r.DB.GetUserIDForSender(ctx, roomID, senderID) userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
}) if queryErr == nil && userID != nil {
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }
return nil return nil
@ -439,9 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { sender := spec.UserID{}
return r.DB.GetUserIDForSender(ctx, roomID, senderID) userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
}) if err == nil && userID != nil {
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }

View file

@ -217,9 +217,12 @@ func Context(
} }
} }
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { sender := spec.UserID{}
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
}) if err == nil && userID != nil {
sender = *userID
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
response := ContextRespsonse{ response := ContextRespsonse{
Event: &ev, Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,

View file

@ -101,10 +101,13 @@ func GetEvent(
} }
} }
sender := spec.UserID{}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID())
if err == nil && senderUserID != nil {
sender = *senderUserID
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
} }
} }

View file

@ -114,11 +114,14 @@ func Relations(
// type if it was specified. // type if it was specified.
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents { for _, event := range filteredEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
res.Chunk = append( res.Chunk = append(
res.Chunk, res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
) )
} }

View file

@ -205,17 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos := make(map[string]ProfileInfoResponse) profileInfos := make(map[string]ProfileInfoResponse)
for _, ev := range append(eventsBefore, eventsAfter...) { for _, ev := range append(eventsBefore, eventsAfter...) {
userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
if err != nil { if queryErr != nil {
logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
continue continue
} }
profile, ok := knownUsersProfiles[userID.String()] profile, ok := knownUsersProfiles[userID.String()]
if !ok { if !ok {
stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID())
if err != nil { if stateErr != nil {
logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
continue continue
} }
if stateEvent == nil { if stateEvent == nil {
@ -230,6 +230,11 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos[userID.String()] = profile profileInfos[userID.String()] = profile
} }
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
results = append(results, Result{ results = append(results, Result{
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
@ -243,9 +248,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos, ProfileInfo: profileInfos,
}, },
Rank: eventScore[event.EventID()].Score, Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}),
}) })
roomGroup := groups[event.RoomID()] roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID()) roomGroup.Results = append(roomGroup.Results, event.EventID())

View file

@ -64,19 +64,17 @@ func (p *InviteStreamProvider) IncrementalSync(
} }
for roomID, inviteEvent := range invites { for roomID, inviteEvent := range invites {
user := "" user := spec.UserID{}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID())
if err == nil { if err == nil && sender != nil {
user = sender.String() user = *sender
} }
// skip ignored user events // skip ignored user events
if _, ok := req.IgnoredUsers.List[user]; ok { if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue continue
} }
ir := types.NewInviteResponse(inviteEvent, func(roomID, senderID string) (*spec.UserID, error) { ir := types.NewInviteResponse(inviteEvent, user)
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
req.Response.Rooms.Invite[roomID] = ir req.Response.Rooms.Invite[roomID] = ir
} }

View file

@ -50,21 +50,21 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if se == nil { if se == nil {
continue // TODO: shouldn't happen? continue // TODO: shouldn't happen?
} }
evs = append(evs, ToClientEvent(se, format, userIDForSender)) sender := spec.UserID{}
userID, err := userIDForSender(se.RoomID(), se.SenderID())
if err == nil && userID != nil {
sender = *userID
}
evs = append(evs, ToClientEvent(se, format, sender))
} }
return evs return evs
} }
// ToClientEvent converts a single server event to a client event. // ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) ClientEvent { func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
user := ""
userID, err := userIDForSender(se.RoomID(), se.SenderID())
if err == nil {
user = userID.String()
}
ce := ClientEvent{ ce := ClientEvent{
Content: spec.RawJSON(se.Content()), Content: spec.RawJSON(se.Content()),
Sender: user, Sender: sender.String(),
Type: se.Type(), Type: se.Type(),
StateKey: se.StateKey(), StateKey: se.StateKey(),
Unsigned: spec.RawJSON(se.Unsigned()), Unsigned: spec.RawJSON(se.Unsigned()),

View file

@ -24,10 +24,6 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestToClientEvent(t *testing.T) { // nolint: gocyclo func TestToClientEvent(t *testing.T) { // nolint: gocyclo
ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{
"type": "m.room.name", "type": "m.room.name",
@ -48,7 +44,11 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil { if err != nil {
t.Fatalf("failed to create Event: %s", err) t.Fatalf("failed to create Event: %s", err)
} }
ce := ToClientEvent(ev, FormatAll, UserIDForSender) userID, err := spec.NewUserID("@test:localhost", true)
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatAll, *userID)
if ce.EventID != ev.EventID() { if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
} }
@ -67,13 +67,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { if !bytes.Equal(ce.Unsigned, ev.Unsigned()) {
t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned))
} }
user := "" if ce.Sender != userID.String() {
userID, err := UserIDForSender("", ev.SenderID()) t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender)
if err == nil {
user = userID.String()
}
if ce.Sender != user {
t.Errorf("ClientEvent.Sender: wanted %s, got %s", user, ce.Sender)
} }
j, err := json.Marshal(ce) j, err := json.Marshal(ce)
if err != nil { if err != nil {
@ -108,7 +103,11 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create Event: %s", err) t.Fatalf("failed to create Event: %s", err)
} }
ce := ToClientEvent(ev, FormatSync, UserIDForSender) userID, err := spec.NewUserID("@test:localhost", true)
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatSync, *userID)
if ce.RoomID != "" { if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
} }

View file

@ -539,7 +539,7 @@ type InviteResponse struct {
} }
// NewInviteResponse creates an empty response with initialised arrays. // NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDForSender) *InviteResponse { func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
res := InviteResponse{} res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{} res.InviteState.Events = []json.RawMessage{}
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDFo
// Then we'll see if we can create a partial of the invite event itself. // Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite. // This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userIDForSender) inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
inviteEvent.Unsigned = nil inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil { if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev) res.InviteState.Events = append(res.InviteState.Events, ev)

View file

@ -61,7 +61,12 @@ func TestNewInviteResponse(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, UserIDForSender) sender, err := spec.NewUserID("@neilalexander:matrix.org", true)
if err != nil {
t.Fatal(err)
}
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
j, err := json.Marshal(res) j, err := json.Marshal(res)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -301,9 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch { switch {
case event.Type() == spec.MRoomMember: case event.Type() == spec.MRoomMember:
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { sender := spec.UserID{}
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
}) if queryErr == nil && userID != nil {
sender = *userID
}
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
var member *localMembership var member *localMembership
member, err = newLocalMembership(&cevent) member, err = newLocalMembership(&cevent)
if err != nil { if err != nil {
@ -531,14 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
return fmt.Errorf("s.localPushDevices: %w", err) return fmt.Errorf("s.localPushDevices: %w", err)
} }
sender := spec.UserID{}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
n := &api.Notification{ n := &api.Notification{
Actions: actions, Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the // UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which // fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync. // matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID string, senderID string) (*spec.UserID, error) { Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
// TODO: this is per-device, but it's not part of the primary // TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't // key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it // make sense. What is this supposed to be? Sytests require it

View file

@ -23,10 +23,6 @@ import (
userUtil "github.com/matrix-org/dendrite/userapi/util" userUtil "github.com/matrix-org/dendrite/userapi/util"
) )
func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func TestNotifyUserCountsAsync(t *testing.T) { func TestNotifyUserCountsAsync(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID)
@ -92,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
} }
// Prepare pusher with our test server URL // Prepare pusher with our test server URL
if err := db.UpsertPusher(ctx, api.Pusher{ if err = db.UpsertPusher(ctx, api.Pusher{
Kind: api.HTTPKind, Kind: api.HTTPKind,
AppID: appID, AppID: appID,
PushKey: pushKey, PushKey: pushKey,
@ -104,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) {
} }
// Insert a dummy event // Insert a dummy event
sender, err := spec.NewUserID(alice.ID, true)
if err != nil {
t.Error(err)
}
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, UserIDForSender), Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
}); err != nil { }); err != nil {
t.Error(err) t.Error(err)
} }