diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index d7f0b40f8..c8b0e921b 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -191,9 +191,16 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a sk = &skString } } + clientEvent, err := synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, sender.String(), sk) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed converting to ClientEvent") + continue + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender.String(), sk, ev.Unsigned()), + *clientEvent, ) } } diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index aa99e5860..5a0d53a6e 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -184,7 +184,13 @@ func RedactEvent(ctx context.Context, redactionEvent, redactedEvent gomatrixserv if err != nil { return err } - redactedBecause := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, senderID.String(), redactionEvent.StateKey(), redactionEvent.Unsigned()) + clientEvent, err := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return querier.QueryUserIDForSender(ctx, roomID, senderID) + }, senderID.String(), redactionEvent.StateKey()) + if err != nil { + return err + } + redactedBecause := clientEvent if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil { return err } diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index bf0f9bf8c..e2879dd25 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -142,8 +142,20 @@ func GetEvent( sk = &skString } } + + clientEvent, err := synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, senderUserID.String(), sk) + if err != nil { + util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("internal server error"), + } + } + return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, senderUserID.String(), sk, events[0].Unsigned()), + JSON: *clientEvent, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 3333cb54d..4c1824ff2 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -416,6 +416,17 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv start = *r.from + for _, ev := range filteredEvents { + if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { + continue + } + if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { + continue + } + // TODO: update power levels + // TODO: same thing for /event endpoint? + } + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), start, end, nil diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index b451a7e2e..418796f7c 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -144,9 +144,17 @@ func Relations( sk = &skString } } + + clientEvent, err := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }, sender.String(), sk) + if err != nil { + util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent") + continue + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender.String(), sk, event.Unsigned()), + *clientEvent, ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 7d5c061b7..257d8ca06 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -254,6 +254,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts sk = &skString } } + + clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, sender.String(), sk) + if err != nil { + util.GetLogger(req.Context()).WithError(err).WithField("senderID", event.SenderID()).WithField("roomID", *validRoomID).Error("Failed converting to ClientEvent") + continue + } + results = append(results, Result{ Context: SearchContextResponse{ Start: startToken.String(), @@ -267,7 +276,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender.String(), sk, event.Unsigned()), + Result: *clientEvent, }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 1ce3346f4..b5d5f280a 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -92,7 +92,11 @@ func (p *InviteStreamProvider) IncrementalSync( if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, user, sk, eventFormat) + ir, err := types.NewInviteResponse(ctx, p.rsAPI, inviteEvent, user, sk, eventFormat) + if err != nil { + req.Log.WithError(err).Error("failed creating invite response") + continue + } req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index ee524f726..3abb0b3c6 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -3,7 +3,6 @@ package streams import ( "context" "database/sql" - "encoding/json" "fmt" "time" @@ -16,8 +15,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib/spec" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" @@ -359,23 +356,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // Now that we've filtered the timeline, work out which state events are still // left. Anything that appears in the filtered timeline will be removed from the // "state" section and kept in "timeline". - - // update the powerlevel event for timeline events - for i, ev := range events { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - var newEvent gomatrixserverlib.PDU - newEvent, err = p.updatePowerLevelEvent(ctx, ev, eventFormat) - if err != nil { - return r.From, err - } - events[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)), gomatrixserverlib.TopologicalOrderByAuthEvents, @@ -390,15 +370,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( continue } delta.StateEvents[i-skipped] = he - // update the powerlevel event for state events - if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") { - var newEvent gomatrixserverlib.PDU - newEvent, err = p.updatePowerLevelEvent(ctx, he, eventFormat) - if err != nil { - return r.From, err - } - delta.StateEvents[i-skipped] = &rstypes.HeaderedEvent{PDU: newEvent} - } } delta.StateEvents = delta.StateEvents[:len(sEvents)-skipped] @@ -468,81 +439,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( return latestPosition, nil } -func (p *PDUStreamProvider) updatePowerLevelEvent(ctx context.Context, ev *rstypes.HeaderedEvent, eventFormat synctypes.ClientEventFormat) (gomatrixserverlib.PDU, error) { - pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev) - if err != nil { - return nil, err - } - newPls := make(map[string]int64) - var userID *spec.UserID - for user, level := range pls.Users { - validRoomID, _ := spec.NewRoomID(ev.RoomID()) - if eventFormat != synctypes.FormatSyncFederation { - userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user)) - if err != nil { - return nil, err - } - user = userID.String() - } - newPls[user] = level - } - var newPlBytes, newEv []byte - newPlBytes, err = json.Marshal(newPls) - if err != nil { - return nil, err - } - newEv, err = sjson.SetRawBytes(ev.JSON(), "content.users", newPlBytes) - if err != nil { - return nil, err - } - - // do the same for prev content - prevContent := gjson.GetBytes(ev.JSON(), "unsigned.prev_content") - if !prevContent.Exists() { - var evNew gomatrixserverlib.PDU - evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSON(newEv, false) - if err != nil { - return nil, err - } - - return evNew, err - } - pls = gomatrixserverlib.PowerLevelContent{} - err = json.Unmarshal([]byte(prevContent.Raw), &pls) - if err != nil { - return nil, err - } - - newPls = make(map[string]int64) - for user, level := range pls.Users { - validRoomID, _ := spec.NewRoomID(ev.RoomID()) - if eventFormat != synctypes.FormatSyncFederation { - userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user)) - if err != nil { - return nil, err - } - user = userID.String() - } - newPls[user] = level - } - newPlBytes, err = json.Marshal(newPls) - if err != nil { - return nil, err - } - newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes) - if err != nil { - return nil, err - } - - var evNew gomatrixserverlib.PDU - evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false) - if err != nil { - return nil, err - } - - return evNew, err -} - // applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make // sure we always return the required events in the timeline. func applyHistoryVisibilityFilter( @@ -692,35 +588,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } - // Update powerlevel events for timeline events - for i, ev := range events { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat) - if err != nil { - return nil, err - } - events[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - // Update powerlevel events for state events - for i, ev := range stateEvents { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat) - if err != nil { - return nil, err - } - stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - jr.Timeline.PrevBatch = prevBatch jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 7e5b1c1bc..0c332fb72 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -22,6 +22,8 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // PrevEventRef represents a reference to a previous event in a state event upgrade @@ -78,88 +80,16 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, if se == nil { continue // TODO: shouldn't happen? } - if format == FormatSyncFederation { - evs = append(evs, ToClientEvent(se, format, string(se.SenderID()), se.StateKey(), spec.RawJSON(se.Unsigned()))) - continue - } - - sender := spec.UserID{} - validRoomID, err := spec.NewRoomID(se.RoomID()) + ev, err := ToClientEvent(se, format, userIDForSender, string(se.SenderID()), se.StateKey()) if err != nil { + logrus.Errorf("Failed converting event to ClientEvent: %s", err.Error()) continue } - userID, err := userIDForSender(*validRoomID, se.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := se.StateKey() - if sk != nil && *sk != "" { - skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk)) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } - - unsigned := se.Unsigned() - var prev PrevEventRef - if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" { - prevUserID, err := userIDForSender(*validRoomID, spec.SenderID(prev.PrevSenderID)) - if err == nil && userID != nil { - prev.PrevSenderID = prevUserID.String() - } else { - errString := "userID unknown" - if err != nil { - errString = err.Error() - } - logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString) - // NOTE: Not much can be done here, so leave the previous value in place. - } - unsigned, err = json.Marshal(prev) - if err != nil { - logrus.Errorf("Failed to marshal unsigned content for ClientEvent: %s", err.Error()) - continue - } - } - evs = append(evs, ToClientEvent(se, format, sender.String(), sk, spec.RawJSON(unsigned))) + evs = append(evs, *ev) } return evs } -// ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender string, stateKey *string, unsigned spec.RawJSON) ClientEvent { - ce := ClientEvent{ - Content: spec.RawJSON(se.Content()), - Sender: sender, - Type: se.Type(), - StateKey: stateKey, - Unsigned: unsigned, - OriginServerTS: se.OriginServerTS(), - EventID: se.EventID(), - Redacts: se.Redacts(), - } - - switch format { - case FormatAll: - ce.RoomID = se.RoomID() - case FormatSync: - case FormatSyncFederation: - ce.RoomID = se.RoomID() - ce.AuthEvents = se.AuthEventIDs() - ce.PrevEvents = se.PrevEventIDs() - ce.Depth = se.Depth() - // TODO: Set Signatures & Hashes fields - } - - if format != FormatSyncFederation { - if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { - ce.SenderKey = se.SenderID() - } - } - return ce -} - // ToClientEvent converts a single server event to a client event. // It provides default logic for event.SenderID & event.StateKey -> userID conversions. func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { @@ -181,7 +111,11 @@ func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserver sk = &skString } } - return ToClientEvent(event, FormatAll, sender.String(), sk, event.Unsigned()) + ev, err := ToClientEvent(event, FormatAll, userIDQuery, sender.String(), sk) + if err != nil { + return ClientEvent{} + } + return *ev } // If provided state key is a user ID (state keys beginning with @ are reserved for this purpose) @@ -211,3 +145,299 @@ func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec. return &stateKey, nil } } + +// ToClientEvent converts a single server event to a client event. +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender, sender string, stateKey *string) (*ClientEvent, error) { + ce := ClientEvent{ + Content: se.Content(), + Sender: sender, + Type: se.Type(), + StateKey: stateKey, + Unsigned: se.Unsigned(), + OriginServerTS: se.OriginServerTS(), + EventID: se.EventID(), + Redacts: se.Redacts(), + } + + switch format { + case FormatAll: + ce.RoomID = se.RoomID() + case FormatSync: + case FormatSyncFederation: + ce.RoomID = se.RoomID() + ce.AuthEvents = se.AuthEventIDs() + ce.PrevEvents = se.PrevEventIDs() + ce.Depth = se.Depth() + // TODO: Set Signatures & Hashes fields + } + + if format != FormatSyncFederation { + if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + ce.SenderKey = se.SenderID() + + validRoomID, err := spec.NewRoomID(se.RoomID()) + if err != nil { + return nil, err + } + userID, err := userIDForSender(*validRoomID, se.SenderID()) + if err == nil && userID != nil { + ce.Sender = userID.String() + } + + sk := se.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk)) + if err == nil && skUserID != nil { + skString := skUserID.String() + ce.StateKey = &skString + } + } + + var prev PrevEventRef + if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" { + prevUserID, err := userIDForSender(*validRoomID, spec.SenderID(prev.PrevSenderID)) + if err == nil && userID != nil { + prev.PrevSenderID = prevUserID.String() + } else { + errString := "userID unknown" + if err != nil { + errString = err.Error() + } + logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString) + // NOTE: Not much can be done here, so leave the previous value in place. + } + ce.Unsigned, err = json.Marshal(prev) + if err != nil { + err = fmt.Errorf("Failed to marshal unsigned content for ClientEvent: %s", err.Error()) + return nil, err + } + } + + // TODO: Refactor all this stuff + if se.Type() == spec.MRoomCreate { + if creator := gjson.GetBytes(se.Content(), "creator"); creator.Exists() { + oldCreator := creator.Str + userID, err := userIDForSender(*validRoomID, spec.SenderID(oldCreator)) + if err != nil { + err = fmt.Errorf("Failed to find userID for creator in ClientEvent: %s", err.Error()) + return nil, err + } + + if userID != nil { + var newCreatorBytes, newContent []byte + newCreatorBytes, err = json.Marshal(userID.String()) + if err != nil { + err = fmt.Errorf("Failed to marshal new creator for ClientEvent: %s", err.Error()) + return nil, err + } + newContent, err = sjson.SetRawBytes([]byte(se.Content()), "creator", newCreatorBytes) + if err != nil { + err = fmt.Errorf("Failed to set new creator for ClientEvent: %s", err.Error()) + return nil, err + } + + ce.Content = newContent + } + } + } + + if se.Type() == spec.MRoomMember { + updatedEvent, err := updateInviteRoomState(userIDForSender, se, format) + if err != nil { + err = fmt.Errorf("Failed to update m.room.member event for ClientEvent: %s", err.Error()) + return nil, err + } + ce.Unsigned = updatedEvent.Unsigned() + } + + if se.Type() == spec.MRoomPowerLevels && se.StateKeyEquals("") { + se, err = updatePowerLevelEvent(userIDForSender, se, format) + if err != nil { + err = fmt.Errorf("Failed update power levels for ClientEvent: %s", err.Error()) + return nil, err + } + ce.Content = se.Content() + ce.Unsigned = se.Unsigned() + } + } + } + + return &ce, nil +} + +func updateInviteRoomState(userIDForSender spec.UserIDForSender, ev gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) { + if inviteRoomState := gjson.GetBytes(ev.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return nil, err + } + userID, err := userIDForSender(*validRoomID, ev.SenderID()) + if err != nil || userID == nil { + if err != nil { + logrus.WithError(err).Error("userID is invalid") + } + return nil, err + } + + newState, err := getUpdatedInviteRoomState(userIDForSender, inviteRoomState, ev, *userID, ev.StateKey(), eventFormat) + if err != nil { + return nil, err + } + + var newEv []byte + newEv, err = sjson.SetRawBytes(ev.JSON(), "unsigned.invite_room_state", newState) + if err != nil { + return nil, err + } + + return gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSON(newEv, false) + } + + return ev, nil +} + +type InviteRoomStateEvent struct { + Content spec.RawJSON `json:"content"` + SenderID string `json:"sender"` + StateKey *string `json:"state_key"` + Type string `json:"type"` +} + +func getUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomState gjson.Result, event gomatrixserverlib.PDU, inviterUserID spec.UserID, stateKey *string, eventFormat ClientEventFormat) (spec.RawJSON, error) { + var res spec.RawJSON + inviteStateEvents := []InviteRoomStateEvent{} + err := json.Unmarshal([]byte(inviteRoomState.Raw), &inviteStateEvents) + if err != nil { + return nil, err + } + + if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation { + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + + for i, ev := range inviteStateEvents { + userID, err := userIDForSender(*validRoomID, spec.SenderID(ev.SenderID)) + if err != nil { + return nil, err + } + inviteStateEvents[i].SenderID = userID.String() + + if ev.StateKey != nil && *ev.StateKey != "" { + userID, err := userIDForSender(*validRoomID, spec.SenderID(*ev.StateKey)) + if err != nil { + return nil, err + } + if userID != nil { + user := userID.String() + inviteStateEvents[i].StateKey = &user + } + } + + if creator := gjson.GetBytes(ev.Content, "creator"); creator.Exists() { + oldCreator := creator.Str + userID, err := userIDForSender(*validRoomID, spec.SenderID(oldCreator)) + if err != nil { + return nil, err + } + + if userID != nil { + var newCreatorBytes, newContent []byte + newCreatorBytes, err = json.Marshal(userID.String()) + if err != nil { + return nil, err + } + newContent, err = sjson.SetRawBytes([]byte(ev.Content), "creator", newCreatorBytes) + if err != nil { + return nil, err + } + + inviteStateEvents[i].Content = newContent + } + } + } + } + + res, err = json.Marshal(inviteStateEvents) + if err != nil { + return nil, err + } + + return res, nil +} + +func updatePowerLevelEvent(userIDForSender spec.UserIDForSender, se gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) { + pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(se) + if err != nil { + return nil, err + } + newPls := make(map[string]int64) + var userID *spec.UserID + for user, level := range pls.Users { + validRoomID, _ := spec.NewRoomID(se.RoomID()) + if eventFormat != FormatSyncFederation { + userID, err = userIDForSender(*validRoomID, spec.SenderID(user)) + if err != nil { + return nil, err + } + user = userID.String() + } + newPls[user] = level + } + var newPlBytes, newEv []byte + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(se.JSON(), "content.users", newPlBytes) + if err != nil { + return nil, err + } + + // do the same for prev content + prevContent := gjson.GetBytes(se.JSON(), "unsigned.prev_content") + if !prevContent.Exists() { + var evNew gomatrixserverlib.PDU + evNew, err = gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSON(newEv, false) + if err != nil { + return nil, err + } + + return evNew, err + } + pls = gomatrixserverlib.PowerLevelContent{} + err = json.Unmarshal([]byte(prevContent.Raw), &pls) + if err != nil { + return nil, err + } + + newPls = make(map[string]int64) + for user, level := range pls.Users { + validRoomID, _ := spec.NewRoomID(se.RoomID()) + if eventFormat != FormatSyncFederation { + userID, err = userIDForSender(*validRoomID, spec.SenderID(user)) + if err != nil { + return nil, err + } + user = userID.String() + } + newPls[user] = level + } + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes) + if err != nil { + return nil, err + } + + var evNew gomatrixserverlib.PDU + evNew, err = gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSONWithEventID(se.EventID(), newEv, false) + if err != nil { + return nil, err + } + + return evNew, err +} diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 202c185f1..de4fd2888 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -17,6 +17,7 @@ package synctypes import ( "bytes" + "context" "encoding/json" "fmt" "reflect" @@ -26,6 +27,14 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) +func queryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + const testSenderID = "testSenderID" const testUserID = "@test:localhost" @@ -106,7 +115,12 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo t.Fatalf("failed to create userID: %s", err) } sk := "" - ce := ToClientEvent(ev, FormatAll, userID.String(), &sk, ev.Unsigned()) + ce, err := ToClientEvent(ev, FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(context.Background(), roomID, senderID) + }, userID.String(), &sk) + if err != nil { + t.Fatalf("failed to create ClientEvent: %s", err) + } verifyEventFields(t, EventFieldsToVerify{ @@ -166,7 +180,12 @@ func TestToClientFormatSync(t *testing.T) { t.Fatalf("failed to create userID: %s", err) } sk := "" - ce := ToClientEvent(ev, FormatSync, userID.String(), &sk, ev.Unsigned()) + ce, err := ToClientEvent(ev, FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(context.Background(), roomID, senderID) + }, userID.String(), &sk) + if err != nil { + t.Fatalf("failed to create ClientEvent: %s", err) + } if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } @@ -206,7 +225,12 @@ func TestToClientEventFormatSyncFederation(t *testing.T) { // nolint: gocyclo t.Fatalf("failed to create userID: %s", err) } sk := "" - ce := ToClientEvent(ev, FormatSyncFederation, userID.String(), &sk, ev.Unsigned()) + ce, err := ToClientEvent(ev, FormatSyncFederation, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(context.Background(), roomID, senderID) + }, userID.String(), &sk) + if err != nil { + t.Fatalf("failed to create ClientEvent: %s", err) + } verifyEventFields(t, EventFieldsToVerify{ diff --git a/syncapi/types/types.go b/syncapi/types/types.go index b90c128c3..6298128d8 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -15,6 +15,7 @@ package types import ( + "context" "encoding/json" "errors" "fmt" @@ -532,26 +533,52 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string, eventFormat synctypes.ClientEventFormat) *InviteResponse { - res := InviteResponse{} - res.InviteState.Events = []json.RawMessage{} +func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *types.HeaderedEvent, userID spec.UserID, stateKey *string, eventFormat synctypes.ClientEventFormat) (*InviteResponse, error) { + res, err := updateInviteRoomState(ctx, rsAPI, event, userID, stateKey, eventFormat) + if err != nil { + return nil, err + } + + return res, nil +} + +func updateInviteRoomState(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *types.HeaderedEvent, inviterUserID spec.UserID, stateKey *string, eventFormat synctypes.ClientEventFormat) (*InviteResponse, error) { + inv := InviteResponse{} + inv.InviteState.Events = []json.RawMessage{} // First see if there's invite_room_state in the unsigned key of the invite. // If there is then unmarshal it into the response. This will contain the // partial room state such as join rules, room name etc. if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { - _ = json.Unmarshal([]byte(inviteRoomState.Raw), &res.InviteState.Events) + err := json.Unmarshal([]byte(inviteRoomState.Raw), &inv.InviteState.Events) + if err != nil { + return nil, err + } + } + + // Clear unsigned so it doesn't have pseudoIDs converted during ToClientEvent + eventNoUnsigned, err := event.SetUnsigned(nil) + if err != nil { + return nil, err } // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, eventFormat, userID.String(), stateKey, event.Unsigned()) - inviteEvent.Unsigned = nil - if ev, err := json.Marshal(inviteEvent); err == nil { - res.InviteState.Events = append(res.InviteState.Events, ev) + inviteEvent, err := synctypes.ToClientEvent(eventNoUnsigned, eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, inviterUserID.String(), stateKey) + if err != nil { + return nil, err } - return &res + // Ensure unsigned field is empty so it isn't marshalled into the final JSON + inviteEvent.Unsigned = nil + + if ev, err := json.Marshal(*inviteEvent); err == nil { + inv.InviteState.Events = append(inv.InviteState.Events, ev) + } + + return &inv, nil } // LeaveResponse represents a /sync response for a room which is under the 'leave' key. diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index a79b9fc5d..0d8ecce35 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -1,6 +1,7 @@ package types import ( + "context" "encoding/json" "reflect" "testing" @@ -11,8 +12,19 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +type FakeRoomserverAPI struct{} + +func (f *FakeRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + +func (f *FakeRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) { + sender := spec.SenderID(userID.String()) + return &sender, nil } func TestSyncTokens(t *testing.T) { @@ -72,14 +84,18 @@ func TestNewInviteResponse(t *testing.T) { skString := skUserID.String() sk := &skString - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk, synctypes.FormatSync) + rsAPI := FakeRoomserverAPI{} + res, err := NewInviteResponse(context.Background(), &rsAPI, &types.HeaderedEvent{PDU: ev}, *sender, sk, synctypes.FormatSync) + if err != nil { + t.Fatal(err) + } j, err := json.Marshal(res) if err != nil { t.Fatal(err) } if string(j) != expected { - t.Fatalf("Invite response didn't contain correct info") + t.Fatalf("Invite response didn't contain correct info, \nexpected: %s \ngot: %s", expected, string(j)) } } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 8863d258a..c74b76173 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -321,9 +321,14 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst return fmt.Errorf("queryUserIDForSender: userID unknown for %s", *sk) } } - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender.String(), sk, event.Unsigned()) + cevent, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, sender.String(), sk) + if err != nil { + return err + } var member *localMembership - member, err = newLocalMembership(&cevent) + member, err = newLocalMembership(cevent) if err != nil { return fmt.Errorf("newLocalMembership: %w", err) } @@ -561,12 +566,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype sk = &skString } } + clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, sender.String(), sk) + if err != nil { + return err + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender.String(), sk, event.Unsigned()), + Event: *clientEvent, // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 27e77cf7a..8daf10ef0 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -23,6 +23,14 @@ import ( userUtil "github.com/matrix-org/dendrite/userapi/util" ) +func queryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + func TestNotifyUserCountsAsync(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -105,8 +113,11 @@ func TestNotifyUserCountsAsync(t *testing.T) { t.Error(err) } sk := "" + ev, err := synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(context.Background(), roomID, senderID) + }, sender.String(), &sk) if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, sender.String(), &sk, dummyEvent.Unsigned()), + Event: *ev, }); err != nil { t.Error(err) }