Inject UserID into state_key field of ClientEvent

This commit is contained in:
Devon Hudson 2023-06-07 16:24:58 -06:00
parent 8ea1a11105
commit cd308012df
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
15 changed files with 150 additions and 24 deletions

View file

@ -145,9 +145,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if err == nil && userID != nil {
sender = *userID
}
sk := ev.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
stateEvents = append(
stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk),
)
}
} else {
@ -172,9 +181,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if err == nil && userID != nil {
sender = *userID
}
sk := ev.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
stateEvents = append(
stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk),
)
}
}
@ -349,8 +367,17 @@ func OnIncomingStateTypeRequest(
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
}
var res interface{}

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16

4
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85 h1:WkmHgdQvpPBxGo/huuCQVR9FWmdkKOcAuQHZBmcHK3g=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=

View file

@ -393,7 +393,16 @@ func (r *Queryer) QueryMembershipsForRoom(
if queryErr == nil && userID != nil {
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := r.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
@ -447,7 +456,16 @@ func (r *Queryer) QueryMembershipsForRoom(
if err == nil && userID != nil {
sender = *userID
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := r.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}

View file

@ -222,7 +222,16 @@ func Context(
if err == nil && userID != nil {
sender = *userID
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
sk := requestedEvent.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), spec.SenderID(*requestedEvent.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender, sk)
response := ContextRespsonse{
Event: &ev,
EventsAfter: eventsAfterClient,

View file

@ -106,8 +106,17 @@ func GetEvent(
if err == nil && senderUserID != nil {
sender = *senderUserID
}
sk := events[0].StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
}
}

View file

@ -119,9 +119,18 @@ func Relations(
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
res.Chunk = append(
res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk),
)
}

View file

@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
})
roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID())

View file

@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync(
user = *sender
}
sk := inviteEvent.StateKey()
if sk != nil && *sk != "" {
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
// skip ignored user events
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue
}
ir := types.NewInviteResponse(inviteEvent, user)
ir := types.NewInviteResponse(inviteEvent, user, sk)
req.Response.Rooms.Invite[roomID] = ir
}

View file

@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if err == nil && userID != nil {
sender = *userID
}
evs = append(evs, ToClientEvent(se, format, sender))
sk := se.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
evs = append(evs, ToClientEvent(se, format, sender, sk))
}
return evs
}
// ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent {
ce := ClientEvent{
Content: spec.RawJSON(se.Content()),
Sender: sender.String(),
Type: se.Type(),
StateKey: se.StateKey(),
StateKey: stateKey,
Unsigned: spec.RawJSON(se.Unsigned()),
OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(),

View file

@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatAll, *userID)
sk := ""
ce := ToClientEvent(ev, FormatAll, *userID, &sk)
if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
}
@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatSync, *userID)
sk := ""
ce := ToClientEvent(ev, FormatSync, *userID, &sk)
if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
}

View file

@ -539,7 +539,7 @@ type InviteResponse struct {
}
// NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse {
res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{}
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe
// Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey)
inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev)

View file

@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) {
if err != nil {
t.Fatal(err)
}
skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
if err != nil {
t.Fatal(err)
}
skString := skUserID.String()
sk := &skString
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk)
j, err := json.Marshal(res)
if err != nil {
t.Fatal(err)

View file

@ -306,7 +306,16 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
if queryErr == nil && userID != nil {
sender = *userID
}
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk)
var member *localMembership
member, err = newLocalMembership(&cevent)
if err != nil {
@ -539,12 +548,21 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
n := &api.Notification{
Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk),
// TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it

View file

@ -104,8 +104,9 @@ func TestNotifyUserCountsAsync(t *testing.T) {
if err != nil {
t.Error(err)
}
sk := ""
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk),
}); err != nil {
t.Error(err)
}