diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 8567e617d..13f308998 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -140,11 +140,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // use the result of the previous QueryLatestEventsAndState response // to find the state event, if provided. for _, ev := range stateRes.StateEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), ) } } else { @@ -164,11 +167,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } } for _, ev := range stateAfterRes.StateEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), ) } } @@ -338,10 +344,13 @@ func OnIncomingStateTypeRequest( } } + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), + ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), } var res interface{} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 7329504db..707e95b2a 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -388,9 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) - }) + sender := spec.UserID{} + userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -439,9 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) - }) + sender := spec.UserID{} + userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index a646622fc..27e99a357 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -217,9 +217,12 @@ func Context( } } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }) + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 0d47477aa..63df7e837 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -101,10 +101,13 @@ func GetEvent( } } + sender := spec.UserID{} + senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID()) + if err == nil && senderUserID != nil { + sender = *senderUserID + } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) - }), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), } } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 30c02d293..f21c684c8 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -114,11 +114,14 @@ func Relations( // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) - }), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index e4e173df2..9cf3eabe2 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -205,17 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) - if err != nil { - logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + if queryErr != nil { + logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue } profile, ok := knownUsersProfiles[userID.String()] if !ok { - stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) - if err != nil { - logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) + if stateErr != nil { + logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue } if stateEvent == nil { @@ -230,6 +230,11 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos[userID.String()] = profile } + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } results = append(results, Result{ Context: SearchContextResponse{ Start: startToken.String(), @@ -242,10 +247,8 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts }), ProfileInfo: profileInfos, }, - Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) - }), + Rank: eventScore[event.EventID()].Score, + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 0e62f70bf..a8b0a7b66 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -64,19 +64,17 @@ func (p *InviteStreamProvider) IncrementalSync( } for roomID, inviteEvent := range invites { - user := "" + user := spec.UserID{} sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) - if err == nil { - user = sender.String() + if err == nil && sender != nil { + user = *sender } // skip ignored user events - if _, ok := req.IgnoredUsers.List[user]; ok { + if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, func(roomID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }) + ir := types.NewInviteResponse(inviteEvent, user) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 6ab12102d..66fb1d01f 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -50,21 +50,21 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, if se == nil { continue // TODO: shouldn't happen? } - evs = append(evs, ToClientEvent(se, format, userIDForSender)) + sender := spec.UserID{} + userID, err := userIDForSender(se.RoomID(), se.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + evs = append(evs, ToClientEvent(se, format, sender)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) ClientEvent { - user := "" - userID, err := userIDForSender(se.RoomID(), se.SenderID()) - if err == nil { - user = userID.String() - } +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { ce := ClientEvent{ Content: spec.RawJSON(se.Content()), - Sender: user, + Sender: sender.String(), Type: se.Type(), StateKey: se.StateKey(), Unsigned: spec.RawJSON(se.Unsigned()), diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index b76b38c17..341795081 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -24,10 +24,6 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) -} - func TestToClientEvent(t *testing.T) { // nolint: gocyclo ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ "type": "m.room.name", @@ -48,7 +44,11 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatAll, UserIDForSender) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + ce := ToClientEvent(ev, FormatAll, *userID) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -67,13 +67,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) } - user := "" - userID, err := UserIDForSender("", ev.SenderID()) - if err == nil { - user = userID.String() - } - if ce.Sender != user { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", user, ce.Sender) + if ce.Sender != userID.String() { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender) } j, err := json.Marshal(ce) if err != nil { @@ -108,7 +103,11 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatSync, UserIDForSender) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + ce := ToClientEvent(ev, FormatSync, *userID) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 3b3bd7987..526a120d0 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -539,7 +539,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDForSender) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDFo // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userIDForSender) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index a93e9122f..a79ce5417 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -61,7 +61,12 @@ func TestNewInviteResponse(t *testing.T) { t.Fatal(err) } - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, UserIDForSender) + sender, err := spec.NewUserID("@neilalexander:matrix.org", true) + if err != nil { + t.Fatal(err) + } + + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3cb44f0fb..c025deee0 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -301,9 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }) + sender := spec.UserID{} + userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { @@ -531,14 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype return fmt.Errorf("s.localPushDevices: %w", err) } + sender := spec.UserID{} + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID string, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 86446b2db..27dd373c2 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -23,10 +23,6 @@ import ( userUtil "github.com/matrix-org/dendrite/userapi/util" ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) -} - func TestNotifyUserCountsAsync(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -92,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Prepare pusher with our test server URL - if err := db.UpsertPusher(ctx, api.Pusher{ + if err = db.UpsertPusher(ctx, api.Pusher{ Kind: api.HTTPKind, AppID: appID, PushKey: pushKey, @@ -104,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Insert a dummy event + sender, err := spec.NewUserID(alice.ID, true) + if err != nil { + t.Error(err) + } if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, UserIDForSender), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), }); err != nil { t.Error(err) }