From cd308012df7a8ea2723cd2f74c4c84591708a25d Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 7 Jun 2023 16:24:58 -0600 Subject: [PATCH] Inject UserID into state_key field of ClientEvent --- clientapi/routing/state.go | 33 ++++++++++++++++++++++++--- go.mod | 2 +- go.sum | 4 ++-- roomserver/internal/query/query.go | 22 ++++++++++++++++-- syncapi/routing/context.go | 11 ++++++++- syncapi/routing/getevent.go | 11 ++++++++- syncapi/routing/relations.go | 11 ++++++++- syncapi/routing/search.go | 11 ++++++++- syncapi/streams/stream_invite.go | 11 ++++++++- syncapi/synctypes/clientevent.go | 15 +++++++++--- syncapi/synctypes/clientevent_test.go | 6 +++-- syncapi/types/types.go | 4 ++-- syncapi/types/types_test.go | 8 ++++++- userapi/consumers/roomserver.go | 22 ++++++++++++++++-- userapi/util/notify_test.go | 3 ++- 15 files changed, 150 insertions(+), 24 deletions(-) diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 13f308998..359a3cca3 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -145,9 +145,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a if err == nil && userID != nil { sender = *userID } + + sk := ev.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk), ) } } else { @@ -172,9 +181,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a if err == nil && userID != nil { sender = *userID } + + sk := ev.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk), ) } } @@ -349,8 +367,17 @@ func OnIncomingStateTypeRequest( if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), + ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk), } var res interface{} diff --git a/go.mod b/go.mod index 3621428c3..23da31ec1 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d + github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 1ee0261f6..2e8d258e5 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85 h1:WkmHgdQvpPBxGo/huuCQVR9FWmdkKOcAuQHZBmcHK3g= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ae2b7cf57..a1dde341b 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -393,7 +393,16 @@ func (r *Queryer) QueryMembershipsForRoom( if queryErr == nil && userID != nil { sender = *userID } - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := r.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -447,7 +456,16 @@ func (r *Queryer) QueryMembershipsForRoom( if err == nil && userID != nil { sender = *userID } - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := r.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk) response.JoinEvents = append(response.JoinEvents, clientEvent) } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 7fb88faaa..d402e468c 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -222,7 +222,16 @@ func Context( if err == nil && userID != nil { sender = *userID } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender) + + sk := requestedEvent.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), spec.SenderID(*requestedEvent.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender, sk) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 63df7e837..de790e5cd 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -106,8 +106,17 @@ func GetEvent( if err == nil && senderUserID != nil { sender = *senderUserID } + + sk := events[0].StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk), } } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index f21c684c8..6efa065a9 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -119,9 +119,18 @@ func Relations( if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk), ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index add50b181..7d9182f47 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } results = append(results, Result{ Context: SearchContextResponse{ Start: startToken.String(), @@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index a8b0a7b66..3a5badd92 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync( user = *sender } + sk := inviteEvent.StateKey() + if sk != nil && *sk != "" { + skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + // skip ignored user events if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, user) + ir := types.NewInviteResponse(inviteEvent, user, sk) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 66fb1d01f..06f52175b 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, if err == nil && userID != nil { sender = *userID } - evs = append(evs, ToClientEvent(se, format, sender)) + + sk := se.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + evs = append(evs, ToClientEvent(se, format, sender, sk)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent { ce := ClientEvent{ Content: spec.RawJSON(se.Content()), Sender: sender.String(), Type: se.Type(), - StateKey: se.StateKey(), + StateKey: stateKey, Unsigned: spec.RawJSON(se.Unsigned()), OriginServerTS: se.OriginServerTS(), EventID: se.EventID(), diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 341795081..63c65b2af 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if err != nil { t.Fatalf("failed to create userID: %s", err) } - ce := ToClientEvent(ev, FormatAll, *userID) + sk := "" + ce := ToClientEvent(ev, FormatAll, *userID, &sk) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create userID: %s", err) } - ce := ToClientEvent(ev, FormatSync, *userID) + sk := "" + ce := ToClientEvent(ev, FormatSync, *userID, &sk) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index a3dc7f54b..cb3c362d5 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -539,7 +539,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index a79ce5417..c1b7f70bd 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) { if err != nil { t.Fatal(err) } + skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true) + if err != nil { + t.Fatal(err) + } + skString := skUserID.String() + sk := &skString - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index df507eb26..bf63c69dc 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -306,7 +306,16 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst if queryErr == nil && userID != nil { sender = *userID } - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender) + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { @@ -539,12 +548,21 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 27dd373c2..3017069bc 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -104,8 +104,9 @@ func TestNotifyUserCountsAsync(t *testing.T) { if err != nil { t.Error(err) } + sk := "" if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk), }); err != nil { t.Error(err) }