diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 359a3cca3..d74002928 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -140,23 +140,11 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // use the result of the previous QueryLatestEventsAndState response // to find the state event, if provided. for _, ev := range stateRes.StateEvents { - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := ev.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk), + synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, ev), ) } } else { @@ -362,22 +350,10 @@ func OnIncomingStateTypeRequest( } } - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk), + ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, event), } var res interface{} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index a1dde341b..08af41c1e 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -388,21 +388,9 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - sender := spec.UserID{} - userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if queryErr == nil && userID != nil { - sender = *userID - } - - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, err := r.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk) + clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -451,21 +439,9 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - sender := spec.UserID{} - userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, err := r.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk) + clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index d402e468c..144baff1e 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -217,21 +217,9 @@ func Context( } } - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := requestedEvent.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), spec.SenderID(*requestedEvent.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender, sk) + ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, requestedEvent) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 06f52175b..358a0c971 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -86,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp } return ce } + +// ToClientEvent converts a single server event to a client event. +// It provides default logic for event.SenderID & event.StateKey -> userID conversions. +func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { + sender := spec.UserID{} + userID, err := userIDQuery(event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + return ToClientEvent(event, FormatAll, sender, sk) +} diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index bf63c69dc..a293b0850 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -309,8 +309,8 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if queryErr == nil && skUserID != nil { skString := skUserID.String() sk = &skString } @@ -551,8 +551,8 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if queryErr == nil && skUserID != nil { skString := skUserID.String() sk = &skString }