Inject UserID into state_key field of ClientEvent

This commit is contained in:
Devon Hudson 2023-06-07 16:24:58 -06:00
parent 8ea1a11105
commit cd308012df
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
15 changed files with 150 additions and 24 deletions

View file

@ -145,9 +145,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := ev.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk),
) )
} }
} else { } else {
@ -172,9 +181,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := ev.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk),
) )
} }
} }
@ -349,8 +367,17 @@ func OnIncomingStateTypeRequest(
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
} }
var res interface{} var res interface{}

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16

4
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8= github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85 h1:WkmHgdQvpPBxGo/huuCQVR9FWmdkKOcAuQHZBmcHK3g=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230607195007-f30c42b17b85/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=

View file

@ -393,7 +393,16 @@ func (r *Queryer) QueryMembershipsForRoom(
if queryErr == nil && userID != nil { if queryErr == nil && userID != nil {
sender = *userID sender = *userID
} }
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := r.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }
return nil return nil
@ -447,7 +456,16 @@ func (r *Queryer) QueryMembershipsForRoom(
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := r.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
} }

View file

@ -222,7 +222,16 @@ func Context(
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
sk := requestedEvent.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), spec.SenderID(*requestedEvent.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender, sk)
response := ContextRespsonse{ response := ContextRespsonse{
Event: &ev, Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,

View file

@ -106,8 +106,17 @@ func GetEvent(
if err == nil && senderUserID != nil { if err == nil && senderUserID != nil {
sender = *senderUserID sender = *senderUserID
} }
sk := events[0].StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
} }
} }

View file

@ -119,9 +119,18 @@ func Relations(
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
res.Chunk = append( res.Chunk = append(
res.Chunk, res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk),
) )
} }

View file

@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
results = append(results, Result{ results = append(results, Result{
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos, ProfileInfo: profileInfos,
}, },
Rank: eventScore[event.EventID()].Score, Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
}) })
roomGroup := groups[event.RoomID()] roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID()) roomGroup.Results = append(roomGroup.Results, event.EventID())

View file

@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync(
user = *sender user = *sender
} }
sk := inviteEvent.StateKey()
if sk != nil && *sk != "" {
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
// skip ignored user events // skip ignored user events
if _, ok := req.IgnoredUsers.List[user.String()]; ok { if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue continue
} }
ir := types.NewInviteResponse(inviteEvent, user) ir := types.NewInviteResponse(inviteEvent, user, sk)
req.Response.Rooms.Invite[roomID] = ir req.Response.Rooms.Invite[roomID] = ir
} }

View file

@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
evs = append(evs, ToClientEvent(se, format, sender))
sk := se.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
evs = append(evs, ToClientEvent(se, format, sender, sk))
} }
return evs return evs
} }
// ToClientEvent converts a single server event to a client event. // ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent {
ce := ClientEvent{ ce := ClientEvent{
Content: spec.RawJSON(se.Content()), Content: spec.RawJSON(se.Content()),
Sender: sender.String(), Sender: sender.String(),
Type: se.Type(), Type: se.Type(),
StateKey: se.StateKey(), StateKey: stateKey,
Unsigned: spec.RawJSON(se.Unsigned()), Unsigned: spec.RawJSON(se.Unsigned()),
OriginServerTS: se.OriginServerTS(), OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(), EventID: se.EventID(),

View file

@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil { if err != nil {
t.Fatalf("failed to create userID: %s", err) t.Fatalf("failed to create userID: %s", err)
} }
ce := ToClientEvent(ev, FormatAll, *userID) sk := ""
ce := ToClientEvent(ev, FormatAll, *userID, &sk)
if ce.EventID != ev.EventID() { if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
} }
@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create userID: %s", err) t.Fatalf("failed to create userID: %s", err)
} }
ce := ToClientEvent(ev, FormatSync, *userID) sk := ""
ce := ToClientEvent(ev, FormatSync, *userID, &sk)
if ce.RoomID != "" { if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
} }

View file

@ -539,7 +539,7 @@ type InviteResponse struct {
} }
// NewInviteResponse creates an empty response with initialised arrays. // NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse {
res := InviteResponse{} res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{} res.InviteState.Events = []json.RawMessage{}
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe
// Then we'll see if we can create a partial of the invite event itself. // Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite. // This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey)
inviteEvent.Unsigned = nil inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil { if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev) res.InviteState.Events = append(res.InviteState.Events, ev)

View file

@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
if err != nil {
t.Fatal(err)
}
skString := skUserID.String()
sk := &skString
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk)
j, err := json.Marshal(res) j, err := json.Marshal(res)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -306,7 +306,16 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
if queryErr == nil && userID != nil { if queryErr == nil && userID != nil {
sender = *userID sender = *userID
} }
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk)
var member *localMembership var member *localMembership
member, err = newLocalMembership(&cevent) member, err = newLocalMembership(&cevent)
if err != nil { if err != nil {
@ -539,12 +548,21 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
n := &api.Notification{ n := &api.Notification{
Actions: actions, Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the // UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which // fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync. // matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk),
// TODO: this is per-device, but it's not part of the primary // TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't // key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it // make sense. What is this supposed to be? Sytests require it

View file

@ -104,8 +104,9 @@ func TestNotifyUserCountsAsync(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
sk := ""
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk),
}); err != nil { }); err != nil {
t.Error(err) t.Error(err)
} }