Refactor pseudoID client event hotswap to occur in a single location

This commit is contained in:
Devon Hudson 2023-09-13 12:47:30 -06:00
parent bea73c765a
commit b593f29fce
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
14 changed files with 478 additions and 235 deletions

View file

@ -191,9 +191,16 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
sk = &skString sk = &skString
} }
} }
clientEvent, err := synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, sender.String(), sk)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed converting to ClientEvent")
continue
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender.String(), sk, ev.Unsigned()), *clientEvent,
) )
} }
} }

View file

@ -184,7 +184,13 @@ func RedactEvent(ctx context.Context, redactionEvent, redactedEvent gomatrixserv
if err != nil { if err != nil {
return err return err
} }
redactedBecause := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, senderID.String(), redactionEvent.StateKey(), redactionEvent.Unsigned()) clientEvent, err := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return querier.QueryUserIDForSender(ctx, roomID, senderID)
}, senderID.String(), redactionEvent.StateKey())
if err != nil {
return err
}
redactedBecause := clientEvent
if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil { if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil {
return err return err
} }

View file

@ -142,8 +142,20 @@ func GetEvent(
sk = &skString sk = &skString
} }
} }
clientEvent, err := synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, senderUserID.String(), sk)
if err != nil {
util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, senderUserID.String(), sk, events[0].Unsigned()), JSON: *clientEvent,
} }
} }

View file

@ -416,6 +416,17 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
start = *r.from start = *r.from
for _, ev := range filteredEvents {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue
}
if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
continue
}
// TODO: update power levels
// TODO: same thing for /event endpoint?
}
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), start, end, nil }), start, end, nil

View file

@ -144,9 +144,17 @@ func Relations(
sk = &skString sk = &skString
} }
} }
clientEvent, err := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}, sender.String(), sk)
if err != nil {
util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent")
continue
}
res.Chunk = append( res.Chunk = append(
res.Chunk, res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender.String(), sk, event.Unsigned()), *clientEvent,
) )
} }

View file

@ -254,6 +254,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
sk = &skString sk = &skString
} }
} }
clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, sender.String(), sk)
if err != nil {
util.GetLogger(req.Context()).WithError(err).WithField("senderID", event.SenderID()).WithField("roomID", *validRoomID).Error("Failed converting to ClientEvent")
continue
}
results = append(results, Result{ results = append(results, Result{
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
@ -267,7 +276,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos, ProfileInfo: profileInfos,
}, },
Rank: eventScore[event.EventID()].Score, Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender.String(), sk, event.Unsigned()), Result: *clientEvent,
}) })
roomGroup := groups[event.RoomID()] roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID()) roomGroup.Results = append(roomGroup.Results, event.EventID())

View file

@ -92,7 +92,11 @@ func (p *InviteStreamProvider) IncrementalSync(
if _, ok := req.IgnoredUsers.List[user.String()]; ok { if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue continue
} }
ir := types.NewInviteResponse(inviteEvent, user, sk, eventFormat) ir, err := types.NewInviteResponse(ctx, p.rsAPI, inviteEvent, user, sk, eventFormat)
if err != nil {
req.Log.WithError(err).Error("failed creating invite response")
continue
}
req.Response.Rooms.Invite[roomID] = ir req.Response.Rooms.Invite[roomID] = ir
} }

View file

@ -3,7 +3,6 @@ package streams
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"time" "time"
@ -16,8 +15,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -359,23 +356,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// Now that we've filtered the timeline, work out which state events are still // Now that we've filtered the timeline, work out which state events are still
// left. Anything that appears in the filtered timeline will be removed from the // left. Anything that appears in the filtered timeline will be removed from the
// "state" section and kept in "timeline". // "state" section and kept in "timeline".
// update the powerlevel event for timeline events
for i, ev := range events {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue
}
if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
continue
}
var newEvent gomatrixserverlib.PDU
newEvent, err = p.updatePowerLevelEvent(ctx, ev, eventFormat)
if err != nil {
return r.From, err
}
events[i] = &rstypes.HeaderedEvent{PDU: newEvent}
}
sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)), gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)),
gomatrixserverlib.TopologicalOrderByAuthEvents, gomatrixserverlib.TopologicalOrderByAuthEvents,
@ -390,15 +370,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
continue continue
} }
delta.StateEvents[i-skipped] = he delta.StateEvents[i-skipped] = he
// update the powerlevel event for state events
if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") {
var newEvent gomatrixserverlib.PDU
newEvent, err = p.updatePowerLevelEvent(ctx, he, eventFormat)
if err != nil {
return r.From, err
}
delta.StateEvents[i-skipped] = &rstypes.HeaderedEvent{PDU: newEvent}
}
} }
delta.StateEvents = delta.StateEvents[:len(sEvents)-skipped] delta.StateEvents = delta.StateEvents[:len(sEvents)-skipped]
@ -468,81 +439,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
return latestPosition, nil return latestPosition, nil
} }
func (p *PDUStreamProvider) updatePowerLevelEvent(ctx context.Context, ev *rstypes.HeaderedEvent, eventFormat synctypes.ClientEventFormat) (gomatrixserverlib.PDU, error) {
pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev)
if err != nil {
return nil, err
}
newPls := make(map[string]int64)
var userID *spec.UserID
for user, level := range pls.Users {
validRoomID, _ := spec.NewRoomID(ev.RoomID())
if eventFormat != synctypes.FormatSyncFederation {
userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user))
if err != nil {
return nil, err
}
user = userID.String()
}
newPls[user] = level
}
var newPlBytes, newEv []byte
newPlBytes, err = json.Marshal(newPls)
if err != nil {
return nil, err
}
newEv, err = sjson.SetRawBytes(ev.JSON(), "content.users", newPlBytes)
if err != nil {
return nil, err
}
// do the same for prev content
prevContent := gjson.GetBytes(ev.JSON(), "unsigned.prev_content")
if !prevContent.Exists() {
var evNew gomatrixserverlib.PDU
evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSON(newEv, false)
if err != nil {
return nil, err
}
return evNew, err
}
pls = gomatrixserverlib.PowerLevelContent{}
err = json.Unmarshal([]byte(prevContent.Raw), &pls)
if err != nil {
return nil, err
}
newPls = make(map[string]int64)
for user, level := range pls.Users {
validRoomID, _ := spec.NewRoomID(ev.RoomID())
if eventFormat != synctypes.FormatSyncFederation {
userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user))
if err != nil {
return nil, err
}
user = userID.String()
}
newPls[user] = level
}
newPlBytes, err = json.Marshal(newPls)
if err != nil {
return nil, err
}
newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes)
if err != nil {
return nil, err
}
var evNew gomatrixserverlib.PDU
evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false)
if err != nil {
return nil, err
}
return evNew, err
}
// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make // applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make
// sure we always return the required events in the timeline. // sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter( func applyHistoryVisibilityFilter(
@ -692,35 +588,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement() prevBatch.Decrement()
} }
// Update powerlevel events for timeline events
for i, ev := range events {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue
}
if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
continue
}
newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat)
if err != nil {
return nil, err
}
events[i] = &rstypes.HeaderedEvent{PDU: newEvent}
}
// Update powerlevel events for state events
for i, ev := range stateEvents {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue
}
if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
continue
}
newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat)
if err != nil {
return nil, err
}
stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent}
}
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)

View file

@ -22,6 +22,8 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
// PrevEventRef represents a reference to a previous event in a state event upgrade // PrevEventRef represents a reference to a previous event in a state event upgrade
@ -78,88 +80,16 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if se == nil { if se == nil {
continue // TODO: shouldn't happen? continue // TODO: shouldn't happen?
} }
if format == FormatSyncFederation { ev, err := ToClientEvent(se, format, userIDForSender, string(se.SenderID()), se.StateKey())
evs = append(evs, ToClientEvent(se, format, string(se.SenderID()), se.StateKey(), spec.RawJSON(se.Unsigned())))
continue
}
sender := spec.UserID{}
validRoomID, err := spec.NewRoomID(se.RoomID())
if err != nil { if err != nil {
logrus.Errorf("Failed converting event to ClientEvent: %s", err.Error())
continue continue
} }
userID, err := userIDForSender(*validRoomID, se.SenderID()) evs = append(evs, *ev)
if err == nil && userID != nil {
sender = *userID
}
sk := se.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
unsigned := se.Unsigned()
var prev PrevEventRef
if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" {
prevUserID, err := userIDForSender(*validRoomID, spec.SenderID(prev.PrevSenderID))
if err == nil && userID != nil {
prev.PrevSenderID = prevUserID.String()
} else {
errString := "userID unknown"
if err != nil {
errString = err.Error()
}
logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString)
// NOTE: Not much can be done here, so leave the previous value in place.
}
unsigned, err = json.Marshal(prev)
if err != nil {
logrus.Errorf("Failed to marshal unsigned content for ClientEvent: %s", err.Error())
continue
}
}
evs = append(evs, ToClientEvent(se, format, sender.String(), sk, spec.RawJSON(unsigned)))
} }
return evs return evs
} }
// ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender string, stateKey *string, unsigned spec.RawJSON) ClientEvent {
ce := ClientEvent{
Content: spec.RawJSON(se.Content()),
Sender: sender,
Type: se.Type(),
StateKey: stateKey,
Unsigned: unsigned,
OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(),
Redacts: se.Redacts(),
}
switch format {
case FormatAll:
ce.RoomID = se.RoomID()
case FormatSync:
case FormatSyncFederation:
ce.RoomID = se.RoomID()
ce.AuthEvents = se.AuthEventIDs()
ce.PrevEvents = se.PrevEventIDs()
ce.Depth = se.Depth()
// TODO: Set Signatures & Hashes fields
}
if format != FormatSyncFederation {
if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
ce.SenderKey = se.SenderID()
}
}
return ce
}
// ToClientEvent converts a single server event to a client event. // ToClientEvent converts a single server event to a client event.
// It provides default logic for event.SenderID & event.StateKey -> userID conversions. // It provides default logic for event.SenderID & event.StateKey -> userID conversions.
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
@ -181,7 +111,11 @@ func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserver
sk = &skString sk = &skString
} }
} }
return ToClientEvent(event, FormatAll, sender.String(), sk, event.Unsigned()) ev, err := ToClientEvent(event, FormatAll, userIDQuery, sender.String(), sk)
if err != nil {
return ClientEvent{}
}
return *ev
} }
// If provided state key is a user ID (state keys beginning with @ are reserved for this purpose) // If provided state key is a user ID (state keys beginning with @ are reserved for this purpose)
@ -211,3 +145,299 @@ func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec.
return &stateKey, nil return &stateKey, nil
} }
} }
// ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender, sender string, stateKey *string) (*ClientEvent, error) {
ce := ClientEvent{
Content: se.Content(),
Sender: sender,
Type: se.Type(),
StateKey: stateKey,
Unsigned: se.Unsigned(),
OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(),
Redacts: se.Redacts(),
}
switch format {
case FormatAll:
ce.RoomID = se.RoomID()
case FormatSync:
case FormatSyncFederation:
ce.RoomID = se.RoomID()
ce.AuthEvents = se.AuthEventIDs()
ce.PrevEvents = se.PrevEventIDs()
ce.Depth = se.Depth()
// TODO: Set Signatures & Hashes fields
}
if format != FormatSyncFederation {
if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
ce.SenderKey = se.SenderID()
validRoomID, err := spec.NewRoomID(se.RoomID())
if err != nil {
return nil, err
}
userID, err := userIDForSender(*validRoomID, se.SenderID())
if err == nil && userID != nil {
ce.Sender = userID.String()
}
sk := se.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk))
if err == nil && skUserID != nil {
skString := skUserID.String()
ce.StateKey = &skString
}
}
var prev PrevEventRef
if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" {
prevUserID, err := userIDForSender(*validRoomID, spec.SenderID(prev.PrevSenderID))
if err == nil && userID != nil {
prev.PrevSenderID = prevUserID.String()
} else {
errString := "userID unknown"
if err != nil {
errString = err.Error()
}
logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString)
// NOTE: Not much can be done here, so leave the previous value in place.
}
ce.Unsigned, err = json.Marshal(prev)
if err != nil {
err = fmt.Errorf("Failed to marshal unsigned content for ClientEvent: %s", err.Error())
return nil, err
}
}
// TODO: Refactor all this stuff
if se.Type() == spec.MRoomCreate {
if creator := gjson.GetBytes(se.Content(), "creator"); creator.Exists() {
oldCreator := creator.Str
userID, err := userIDForSender(*validRoomID, spec.SenderID(oldCreator))
if err != nil {
err = fmt.Errorf("Failed to find userID for creator in ClientEvent: %s", err.Error())
return nil, err
}
if userID != nil {
var newCreatorBytes, newContent []byte
newCreatorBytes, err = json.Marshal(userID.String())
if err != nil {
err = fmt.Errorf("Failed to marshal new creator for ClientEvent: %s", err.Error())
return nil, err
}
newContent, err = sjson.SetRawBytes([]byte(se.Content()), "creator", newCreatorBytes)
if err != nil {
err = fmt.Errorf("Failed to set new creator for ClientEvent: %s", err.Error())
return nil, err
}
ce.Content = newContent
}
}
}
if se.Type() == spec.MRoomMember {
updatedEvent, err := updateInviteRoomState(userIDForSender, se, format)
if err != nil {
err = fmt.Errorf("Failed to update m.room.member event for ClientEvent: %s", err.Error())
return nil, err
}
ce.Unsigned = updatedEvent.Unsigned()
}
if se.Type() == spec.MRoomPowerLevels && se.StateKeyEquals("") {
se, err = updatePowerLevelEvent(userIDForSender, se, format)
if err != nil {
err = fmt.Errorf("Failed update power levels for ClientEvent: %s", err.Error())
return nil, err
}
ce.Content = se.Content()
ce.Unsigned = se.Unsigned()
}
}
}
return &ce, nil
}
func updateInviteRoomState(userIDForSender spec.UserIDForSender, ev gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) {
if inviteRoomState := gjson.GetBytes(ev.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return nil, err
}
userID, err := userIDForSender(*validRoomID, ev.SenderID())
if err != nil || userID == nil {
if err != nil {
logrus.WithError(err).Error("userID is invalid")
}
return nil, err
}
newState, err := getUpdatedInviteRoomState(userIDForSender, inviteRoomState, ev, *userID, ev.StateKey(), eventFormat)
if err != nil {
return nil, err
}
var newEv []byte
newEv, err = sjson.SetRawBytes(ev.JSON(), "unsigned.invite_room_state", newState)
if err != nil {
return nil, err
}
return gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSON(newEv, false)
}
return ev, nil
}
type InviteRoomStateEvent struct {
Content spec.RawJSON `json:"content"`
SenderID string `json:"sender"`
StateKey *string `json:"state_key"`
Type string `json:"type"`
}
func getUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomState gjson.Result, event gomatrixserverlib.PDU, inviterUserID spec.UserID, stateKey *string, eventFormat ClientEventFormat) (spec.RawJSON, error) {
var res spec.RawJSON
inviteStateEvents := []InviteRoomStateEvent{}
err := json.Unmarshal([]byte(inviteRoomState.Raw), &inviteStateEvents)
if err != nil {
return nil, err
}
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation {
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
for i, ev := range inviteStateEvents {
userID, err := userIDForSender(*validRoomID, spec.SenderID(ev.SenderID))
if err != nil {
return nil, err
}
inviteStateEvents[i].SenderID = userID.String()
if ev.StateKey != nil && *ev.StateKey != "" {
userID, err := userIDForSender(*validRoomID, spec.SenderID(*ev.StateKey))
if err != nil {
return nil, err
}
if userID != nil {
user := userID.String()
inviteStateEvents[i].StateKey = &user
}
}
if creator := gjson.GetBytes(ev.Content, "creator"); creator.Exists() {
oldCreator := creator.Str
userID, err := userIDForSender(*validRoomID, spec.SenderID(oldCreator))
if err != nil {
return nil, err
}
if userID != nil {
var newCreatorBytes, newContent []byte
newCreatorBytes, err = json.Marshal(userID.String())
if err != nil {
return nil, err
}
newContent, err = sjson.SetRawBytes([]byte(ev.Content), "creator", newCreatorBytes)
if err != nil {
return nil, err
}
inviteStateEvents[i].Content = newContent
}
}
}
}
res, err = json.Marshal(inviteStateEvents)
if err != nil {
return nil, err
}
return res, nil
}
func updatePowerLevelEvent(userIDForSender spec.UserIDForSender, se gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) {
pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(se)
if err != nil {
return nil, err
}
newPls := make(map[string]int64)
var userID *spec.UserID
for user, level := range pls.Users {
validRoomID, _ := spec.NewRoomID(se.RoomID())
if eventFormat != FormatSyncFederation {
userID, err = userIDForSender(*validRoomID, spec.SenderID(user))
if err != nil {
return nil, err
}
user = userID.String()
}
newPls[user] = level
}
var newPlBytes, newEv []byte
newPlBytes, err = json.Marshal(newPls)
if err != nil {
return nil, err
}
newEv, err = sjson.SetRawBytes(se.JSON(), "content.users", newPlBytes)
if err != nil {
return nil, err
}
// do the same for prev content
prevContent := gjson.GetBytes(se.JSON(), "unsigned.prev_content")
if !prevContent.Exists() {
var evNew gomatrixserverlib.PDU
evNew, err = gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSON(newEv, false)
if err != nil {
return nil, err
}
return evNew, err
}
pls = gomatrixserverlib.PowerLevelContent{}
err = json.Unmarshal([]byte(prevContent.Raw), &pls)
if err != nil {
return nil, err
}
newPls = make(map[string]int64)
for user, level := range pls.Users {
validRoomID, _ := spec.NewRoomID(se.RoomID())
if eventFormat != FormatSyncFederation {
userID, err = userIDForSender(*validRoomID, spec.SenderID(user))
if err != nil {
return nil, err
}
user = userID.String()
}
newPls[user] = level
}
newPlBytes, err = json.Marshal(newPls)
if err != nil {
return nil, err
}
newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes)
if err != nil {
return nil, err
}
var evNew gomatrixserverlib.PDU
evNew, err = gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSONWithEventID(se.EventID(), newEv, false)
if err != nil {
return nil, err
}
return evNew, err
}

View file

@ -17,6 +17,7 @@ package synctypes
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
@ -26,6 +27,14 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
func queryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
if senderID == "" {
return nil, nil
}
return spec.NewUserID(string(senderID), true)
}
const testSenderID = "testSenderID" const testSenderID = "testSenderID"
const testUserID = "@test:localhost" const testUserID = "@test:localhost"
@ -106,7 +115,12 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
t.Fatalf("failed to create userID: %s", err) t.Fatalf("failed to create userID: %s", err)
} }
sk := "" sk := ""
ce := ToClientEvent(ev, FormatAll, userID.String(), &sk, ev.Unsigned()) ce, err := ToClientEvent(ev, FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return queryUserIDForSender(context.Background(), roomID, senderID)
}, userID.String(), &sk)
if err != nil {
t.Fatalf("failed to create ClientEvent: %s", err)
}
verifyEventFields(t, verifyEventFields(t,
EventFieldsToVerify{ EventFieldsToVerify{
@ -166,7 +180,12 @@ func TestToClientFormatSync(t *testing.T) {
t.Fatalf("failed to create userID: %s", err) t.Fatalf("failed to create userID: %s", err)
} }
sk := "" sk := ""
ce := ToClientEvent(ev, FormatSync, userID.String(), &sk, ev.Unsigned()) ce, err := ToClientEvent(ev, FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return queryUserIDForSender(context.Background(), roomID, senderID)
}, userID.String(), &sk)
if err != nil {
t.Fatalf("failed to create ClientEvent: %s", err)
}
if ce.RoomID != "" { if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
} }
@ -206,7 +225,12 @@ func TestToClientEventFormatSyncFederation(t *testing.T) { // nolint: gocyclo
t.Fatalf("failed to create userID: %s", err) t.Fatalf("failed to create userID: %s", err)
} }
sk := "" sk := ""
ce := ToClientEvent(ev, FormatSyncFederation, userID.String(), &sk, ev.Unsigned()) ce, err := ToClientEvent(ev, FormatSyncFederation, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return queryUserIDForSender(context.Background(), roomID, senderID)
}, userID.String(), &sk)
if err != nil {
t.Fatalf("failed to create ClientEvent: %s", err)
}
verifyEventFields(t, verifyEventFields(t,
EventFieldsToVerify{ EventFieldsToVerify{

View file

@ -15,6 +15,7 @@
package types package types
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -532,26 +533,52 @@ type InviteResponse struct {
} }
// NewInviteResponse creates an empty response with initialised arrays. // NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string, eventFormat synctypes.ClientEventFormat) *InviteResponse { func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *types.HeaderedEvent, userID spec.UserID, stateKey *string, eventFormat synctypes.ClientEventFormat) (*InviteResponse, error) {
res := InviteResponse{} res, err := updateInviteRoomState(ctx, rsAPI, event, userID, stateKey, eventFormat)
res.InviteState.Events = []json.RawMessage{} if err != nil {
return nil, err
}
return res, nil
}
func updateInviteRoomState(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *types.HeaderedEvent, inviterUserID spec.UserID, stateKey *string, eventFormat synctypes.ClientEventFormat) (*InviteResponse, error) {
inv := InviteResponse{}
inv.InviteState.Events = []json.RawMessage{}
// First see if there's invite_room_state in the unsigned key of the invite. // First see if there's invite_room_state in the unsigned key of the invite.
// If there is then unmarshal it into the response. This will contain the // If there is then unmarshal it into the response. This will contain the
// partial room state such as join rules, room name etc. // partial room state such as join rules, room name etc.
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
_ = json.Unmarshal([]byte(inviteRoomState.Raw), &res.InviteState.Events) err := json.Unmarshal([]byte(inviteRoomState.Raw), &inv.InviteState.Events)
if err != nil {
return nil, err
}
}
// Clear unsigned so it doesn't have pseudoIDs converted during ToClientEvent
eventNoUnsigned, err := event.SetUnsigned(nil)
if err != nil {
return nil, err
} }
// Then we'll see if we can create a partial of the invite event itself. // Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite. // This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, eventFormat, userID.String(), stateKey, event.Unsigned()) inviteEvent, err := synctypes.ToClientEvent(eventNoUnsigned, eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
inviteEvent.Unsigned = nil return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
if ev, err := json.Marshal(inviteEvent); err == nil { }, inviterUserID.String(), stateKey)
res.InviteState.Events = append(res.InviteState.Events, ev) if err != nil {
return nil, err
} }
return &res // Ensure unsigned field is empty so it isn't marshalled into the final JSON
inviteEvent.Unsigned = nil
if ev, err := json.Marshal(*inviteEvent); err == nil {
inv.InviteState.Events = append(inv.InviteState.Events, ev)
}
return &inv, nil
} }
// LeaveResponse represents a /sync response for a room which is under the 'leave' key. // LeaveResponse represents a /sync response for a room which is under the 'leave' key.

View file

@ -1,6 +1,7 @@
package types package types
import ( import (
"context"
"encoding/json" "encoding/json"
"reflect" "reflect"
"testing" "testing"
@ -11,8 +12,19 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { type FakeRoomserverAPI struct{}
return spec.NewUserID(senderID, true)
func (f *FakeRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
if senderID == "" {
return nil, nil
}
return spec.NewUserID(string(senderID), true)
}
func (f *FakeRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
sender := spec.SenderID(userID.String())
return &sender, nil
} }
func TestSyncTokens(t *testing.T) { func TestSyncTokens(t *testing.T) {
@ -72,14 +84,18 @@ func TestNewInviteResponse(t *testing.T) {
skString := skUserID.String() skString := skUserID.String()
sk := &skString sk := &skString
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk, synctypes.FormatSync) rsAPI := FakeRoomserverAPI{}
res, err := NewInviteResponse(context.Background(), &rsAPI, &types.HeaderedEvent{PDU: ev}, *sender, sk, synctypes.FormatSync)
if err != nil {
t.Fatal(err)
}
j, err := json.Marshal(res) j, err := json.Marshal(res)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(j) != expected { if string(j) != expected {
t.Fatalf("Invite response didn't contain correct info") t.Fatalf("Invite response didn't contain correct info, \nexpected: %s \ngot: %s", expected, string(j))
} }
} }

View file

@ -321,9 +321,14 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
return fmt.Errorf("queryUserIDForSender: userID unknown for %s", *sk) return fmt.Errorf("queryUserIDForSender: userID unknown for %s", *sk)
} }
} }
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender.String(), sk, event.Unsigned()) cevent, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, sender.String(), sk)
if err != nil {
return err
}
var member *localMembership var member *localMembership
member, err = newLocalMembership(&cevent) member, err = newLocalMembership(cevent)
if err != nil { if err != nil {
return fmt.Errorf("newLocalMembership: %w", err) return fmt.Errorf("newLocalMembership: %w", err)
} }
@ -561,12 +566,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
sk = &skString sk = &skString
} }
} }
clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, sender.String(), sk)
if err != nil {
return err
}
n := &api.Notification{ n := &api.Notification{
Actions: actions, Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the // UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which // fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync. // matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender.String(), sk, event.Unsigned()), Event: *clientEvent,
// TODO: this is per-device, but it's not part of the primary // TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't // key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it // make sense. What is this supposed to be? Sytests require it

View file

@ -23,6 +23,14 @@ import (
userUtil "github.com/matrix-org/dendrite/userapi/util" userUtil "github.com/matrix-org/dendrite/userapi/util"
) )
func queryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
if senderID == "" {
return nil, nil
}
return spec.NewUserID(string(senderID), true)
}
func TestNotifyUserCountsAsync(t *testing.T) { func TestNotifyUserCountsAsync(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID)
@ -105,8 +113,11 @@ func TestNotifyUserCountsAsync(t *testing.T) {
t.Error(err) t.Error(err)
} }
sk := "" sk := ""
ev, err := synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return queryUserIDForSender(context.Background(), roomID, senderID)
}, sender.String(), &sk)
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, sender.String(), &sk, dummyEvent.Unsigned()), Event: *ev,
}); err != nil { }); err != nil {
t.Error(err) t.Error(err)
} }