From 96e5690b44d5531acf3e065d6f63d3ad08c40e4f Mon Sep 17 00:00:00 2001 From: Sam Wedgwood Date: Wed, 16 Aug 2023 15:45:30 +0100 Subject: [PATCH] move translation logic to synctypes + another test --- clientapi/routing/state.go | 58 ++++++++++---------------------- clientapi/routing/state_test.go | 46 ++++++++++++++++++++++++- syncapi/synctypes/clientevent.go | 43 +++++++++++++++++++++++ 3 files changed, 105 insertions(+), 42 deletions(-) diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index b0e781e8a..04c0bb482 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -226,9 +226,23 @@ func OnIncomingStateTypeRequest( } // Handle user ID state keys appropriately - newStateKey, errorResp := translateUserStateKey(ctx, rsAPI, stateKey, evType, *parsedRoomID) - if errorResp != nil { - return *errorResp + newStateKey, invalidUserIDOrNoSender, err := synctypes.FromClientStateKey(*parsedRoomID, stateKey, func(roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) { + return rsAPI.QuerySenderIDForUser(ctx, roomID, userID) + }) + if err != nil { + if invalidUserIDOrNoSender { + // Currently treat this as no state found - see comment for FromClientStateKey + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("Cannot find state event for %q", evType)), + } + } else { + util.GetLogger(ctx).WithError(err).Error("synctypes.FromClientStateKey failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("internal server error"), + } + } } stateKey = *newStateKey @@ -404,41 +418,3 @@ func OnIncomingStateTypeRequest( JSON: res, } } - -// If provided state key is a user ID (state keys beginning with @ are reserved for this purpose) -// fetch it's associated sender ID and use that instead. Otherwise returns the same state key back. -// -// This function either returns the state key that should be used, or a JSON error response that should be returned to the user. -// -// TODO: if any step of this process fails, should we fail with 404, or silently continue without using sender ID? -func translateUserStateKey(ctx context.Context, rsAPI api.ClientRoomserverAPI, stateKey string, evType string, roomID spec.RoomID) (*string, *util.JSONResponse) { - if len(stateKey) >= 1 && stateKey[0] == '@' { - parsedStateKey, err := spec.NewUserID(stateKey, true) - if err != nil { - // If invalid user ID, then there is no associated state event. - return nil, &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound(fmt.Sprintf("Cannot find state event for %q", evType)), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *parsedStateKey) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("QuerySenderIDForUser failed") - return nil, &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.Unknown("internal server error"), - } - } - if senderID == nil { - // If no sender ID, then there is no associated state event. - return nil, &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound(fmt.Sprintf("Cannot find state event for %q", evType)), - } - } - newStateKey := string(*senderID) - return &newStateKey, nil - } else { - return &stateKey, nil - } -} diff --git a/clientapi/routing/state_test.go b/clientapi/routing/state_test.go index fef8017e3..cc4af9660 100644 --- a/clientapi/routing/state_test.go +++ b/clientapi/routing/state_test.go @@ -109,6 +109,7 @@ func Test_OnIncomingStateTypeRequest(t *testing.T) { tempRoomServerCfg.Defaults(config.DefaultOpts{}) defaultRoomVersion := tempRoomServerCfg.DefaultRoomVersion pseudoIDRoomVersion := gomatrixserverlib.RoomVersionPseudoIDs + nonPseudoIDRoomVersion := gomatrixserverlib.RoomVersionV10 userIDStr := "@testuser:domain" eventType := "com.example.test" @@ -144,7 +145,7 @@ func Test_OnIncomingStateTypeRequest(t *testing.T) { }) }) - t.Run("user ID state keys are translated to pseudo IDs in pseudo ID rooms", func(t *testing.T) { + t.Run("user ID key translated to room key in pseudo ID rooms", func(t *testing.T) { ctx := context.Background() stateSenderUserID := "@sender:domain" @@ -160,6 +161,12 @@ func Test_OnIncomingStateTypeRequest(t *testing.T) { }: mustCreateStatePDU(t, pseudoIDRoomVersion, roomIDStr, eventType, stateSenderRoomKey, map[string]interface{}{ "foo": "bar", }), + { + EventType: eventType, + StateKey: stateSenderUserID, + }: mustCreateStatePDU(t, pseudoIDRoomVersion, roomIDStr, eventType, stateSenderUserID, map[string]interface{}{ + "not": "thisone", + }), }, userIDStr: userIDStr, senderMapping: map[string]string{ @@ -174,6 +181,43 @@ func Test_OnIncomingStateTypeRequest(t *testing.T) { JSON: spec.RawJSON(`{"foo":"bar"}`), }) }) + + t.Run("user ID key not translated to room key in non-pseudo ID rooms", func(t *testing.T) { + ctx := context.Background() + + stateSenderUserID := "@sender:domain" + stateSenderRoomKey := "testsenderkey" + + rsAPI := testRoomserverAPI{ + roomVersion: nonPseudoIDRoomVersion, + roomIDStr: roomIDStr, + roomState: map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent{ + { + EventType: eventType, + StateKey: stateSenderRoomKey, + }: mustCreateStatePDU(t, nonPseudoIDRoomVersion, roomIDStr, eventType, stateSenderRoomKey, map[string]interface{}{ + "not": "thisone", + }), + { + EventType: eventType, + StateKey: stateSenderUserID, + }: mustCreateStatePDU(t, nonPseudoIDRoomVersion, roomIDStr, eventType, stateSenderUserID, map[string]interface{}{ + "foo": "bar", + }), + }, + userIDStr: userIDStr, + senderMapping: map[string]string{ + stateSenderUserID: stateSenderUserID, + }, + } + + jsonResp := OnIncomingStateTypeRequest(ctx, device, rsAPI, roomIDStr, eventType, stateSenderUserID, false) + + assert.DeepEqual(t, jsonResp, util.JSONResponse{ + Code: http.StatusOK, + JSON: spec.RawJSON(`{"foo":"bar"}`), + }) + }) } func mustCreateStatePDU(t *testing.T, roomVer gomatrixserverlib.RoomVersion, roomID string, stateType string, stateKey string, stateContent map[string]interface{}) *types.HeaderedEvent { diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 6f03d9ff0..e991c4ed4 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -16,6 +16,8 @@ package synctypes import ( + "fmt" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) @@ -118,3 +120,44 @@ func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserver } return ToClientEvent(event, FormatAll, sender, sk) } + +// If provided state key is a user ID (state keys beginning with @ are reserved for this purpose) +// fetch it's associated sender ID and use that instead. Otherwise returns the same state key back. +// +// This function either returns the state key that should be used, or a JSON error response that should be returned to the user. +// A boolean is also returned, which, if true, means the err is a result of either: +// - State key begins with @, but does not contain a valid user ID. +// - State key contains a valid user ID, but this user ID does not have a sender ID. +// +// TODO: it's currently unclear how translation logic should behave in the above two cases - two options for each case: +// - Starts with @ but invalid user ID: +// -- Reject request (e.g. as 404), but people may have state keys, that, against the spec, +// begin with @ and dont contain a valid user ID, which they would be unable to interact with. +// -- Silently ignore, and let them use the invalid user ID without any translation attempt. This will probably work +// but could cause issues down the line with user ID grammar. +// - Valid user ID, but does not have a sender ID +// -- Reject reuquest (e.g. as 404), but people may wish to set a state event with a key containing +// a user ID for someone who has not yet joined the room - this prevents that. +// -- Silently ignore and don't translate - could cause issues where a state event is set with a user ID, then that user joins +// and now querying that same state key returns different state event (as the user now has a pseudo ID) +func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec.SenderIDForUser) (*string, bool, error) { + if len(stateKey) >= 1 && stateKey[0] == '@' { + parsedStateKey, err := spec.NewUserID(stateKey, true) + if err != nil { + // If invalid user ID, then there is no associated state event. + return nil, true, fmt.Errorf("Provided state key begins with @ but is not a valid user ID: %s", err.Error()) + } + senderID, err := senderIDQuery(roomID, *parsedStateKey) + if err != nil { + return nil, false, fmt.Errorf("Failed to query sender ID: %s", err.Error()) + } + if senderID == nil { + // If no sender ID, then there is no associated state event. + return nil, true, fmt.Errorf("No associated sender ID found.") + } + newStateKey := string(*senderID) + return &newStateKey, false, nil + } else { + return &stateKey, false, nil + } +}