diff --git a/clientapi/routing/state_test.go b/clientapi/routing/state_test.go new file mode 100644 index 000000000..fef8017e3 --- /dev/null +++ b/clientapi/routing/state_test.go @@ -0,0 +1,199 @@ +package routing + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "gotest.tools/v3/assert" +) + +var () + +type testRoomserverAPI struct { + rsapi.RoomserverInternalAPI + t *testing.T + roomState map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent + roomIDStr string + roomVersion gomatrixserverlib.RoomVersion + userIDStr string + // userID -> senderID + senderMapping map[string]string +} + +func (s testRoomserverAPI) QueryLatestEventsAndState( + ctx context.Context, + req *rsapi.QueryLatestEventsAndStateRequest, + res *rsapi.QueryLatestEventsAndStateResponse, +) error { + res.RoomExists = req.RoomID == s.roomIDStr + if !res.RoomExists { + return nil + } + + res.StateEvents = []*types.HeaderedEvent{} + for _, stateKeyTuple := range req.StateToFetch { + val, ok := s.roomState[stateKeyTuple] + if ok && val != nil { + res.StateEvents = append(res.StateEvents, val) + } + } + + return nil +} + +func (s testRoomserverAPI) QueryMembershipForUser( + ctx context.Context, + req *rsapi.QueryMembershipForUserRequest, + res *rsapi.QueryMembershipForUserResponse, +) error { + if req.UserID.String() == s.userIDStr { + res.HasBeenInRoom = true + res.IsInRoom = true + res.RoomExists = true + res.Membership = spec.Join + } + + return nil +} + +func (s testRoomserverAPI) QuerySenderIDForUser( + ctx context.Context, + roomID spec.RoomID, + userID spec.UserID, +) (*spec.SenderID, error) { + sID, ok := s.senderMapping[userID.String()] + if ok { + sender := spec.SenderID(sID) + return &sender, nil + } else { + return nil, nil + } +} + +func (s testRoomserverAPI) QueryUserIDForSender( + ctx context.Context, + roomID spec.RoomID, + senderID spec.SenderID, +) (*spec.UserID, error) { + for uID, sID := range s.senderMapping { + if sID == string(senderID) { + parsedUserID, err := spec.NewUserID(uID, true) + if err != nil { + s.t.Fatalf("Mock QueryUserIDForSender failed: %s", err) + } + return parsedUserID, nil + } + } + return nil, nil +} + +func (s testRoomserverAPI) QueryStateAfterEvents( + ctx context.Context, + req *rsapi.QueryStateAfterEventsRequest, + res *rsapi.QueryStateAfterEventsResponse, +) error { + return nil +} + +func Test_OnIncomingStateTypeRequest(t *testing.T) { + var tempRoomServerCfg config.RoomServer + tempRoomServerCfg.Defaults(config.DefaultOpts{}) + defaultRoomVersion := tempRoomServerCfg.DefaultRoomVersion + pseudoIDRoomVersion := gomatrixserverlib.RoomVersionPseudoIDs + + userIDStr := "@testuser:domain" + eventType := "com.example.test" + stateKey := "testStateKey" + roomIDStr := "!id:domain" + + device := &uapi.Device{ + UserID: userIDStr, + } + + t.Run("request simple state key", func(t *testing.T) { + ctx := context.Background() + + rsAPI := testRoomserverAPI{ + roomVersion: defaultRoomVersion, + roomIDStr: roomIDStr, + roomState: map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent{ + { + EventType: eventType, + StateKey: stateKey, + }: mustCreateStatePDU(t, defaultRoomVersion, roomIDStr, eventType, stateKey, map[string]interface{}{ + "foo": "bar", + }), + }, + userIDStr: userIDStr, + } + + jsonResp := OnIncomingStateTypeRequest(ctx, device, rsAPI, roomIDStr, eventType, stateKey, false) + + assert.DeepEqual(t, jsonResp, util.JSONResponse{ + Code: http.StatusOK, + JSON: spec.RawJSON(`{"foo":"bar"}`), + }) + }) + + t.Run("user ID state keys are translated to pseudo IDs in pseudo ID rooms", func(t *testing.T) { + ctx := context.Background() + + stateSenderUserID := "@sender:domain" + stateSenderRoomKey := "testsenderkey" + + rsAPI := testRoomserverAPI{ + roomVersion: pseudoIDRoomVersion, + roomIDStr: roomIDStr, + roomState: map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent{ + { + EventType: eventType, + StateKey: stateSenderRoomKey, + }: mustCreateStatePDU(t, pseudoIDRoomVersion, roomIDStr, eventType, stateSenderRoomKey, map[string]interface{}{ + "foo": "bar", + }), + }, + userIDStr: userIDStr, + senderMapping: map[string]string{ + stateSenderUserID: stateSenderRoomKey, + }, + } + + jsonResp := OnIncomingStateTypeRequest(ctx, device, rsAPI, roomIDStr, eventType, stateSenderUserID, false) + + assert.DeepEqual(t, jsonResp, util.JSONResponse{ + Code: http.StatusOK, + JSON: spec.RawJSON(`{"foo":"bar"}`), + }) + }) +} + +func mustCreateStatePDU(t *testing.T, roomVer gomatrixserverlib.RoomVersion, roomID string, stateType string, stateKey string, stateContent map[string]interface{}) *types.HeaderedEvent { + t.Helper() + roomVerImpl := gomatrixserverlib.MustGetRoomVersion(roomVer) + + evBytes, err := json.Marshal(map[string]interface{}{ + "room_id": roomID, + "type": stateType, + "state_key": stateKey, + "content": stateContent, + }) + if err != nil { + t.Fatalf("failed to create event: %v", err) + } + + ev, err := roomVerImpl.NewEventFromTrustedJSON(evBytes, false) + if err != nil { + t.Fatalf("failed to create event: %v", err) + } + + return &types.HeaderedEvent{PDU: ev} +}