From cbad03fc5e15975d5908c86ea6a3bf04fe1af11a Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 17 Jan 2023 17:22:05 -0700 Subject: [PATCH] Add validation to relay_txn prev entry id --- relayapi/routing/relaytxn.go | 15 +++++++++++++-- relayapi/routing/relaytxn_test.go | 31 ++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go index 1c9e350a5..1b11b0ecd 100644 --- a/relayapi/routing/relaytxn.go +++ b/relayapi/routing/relaytxn.go @@ -18,6 +18,7 @@ import ( "encoding/json" "net/http" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -41,9 +42,19 @@ func GetTransactionFromRelay( logrus.Infof("Handling relay_txn for %s", userID.Raw()) previousEntry := gomatrixserverlib.RelayEntry{} - if err := json.Unmarshal(fedReq.Content(), &previousEntry); err == nil { - logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) + if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("invalid json provided"), + } } + if previousEntry.EntryID < 0 { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."), + } + } + logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) response, err := relayAPI.QueryTransactions(httpReq.Context(), userID, previousEntry) if err != nil { diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go index a0b03a398..a47fdb198 100644 --- a/relayapi/routing/relaytxn_test.go +++ b/relayapi/routing/relaytxn_test.go @@ -60,7 +60,7 @@ func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { &db, nil, nil, nil, nil, false, "", ) - request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) @@ -73,6 +73,31 @@ func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { assert.Zero(t, count) } +func TestGetInvalidPrevEntryFails(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusInternalServerError, response.Code) +} + func TestGetReturnsSavedTransaction(t *testing.T) { testDB := test.NewInMemoryRelayDatabase() db := shared.Database{ @@ -101,7 +126,7 @@ func TestGetReturnsSavedTransaction(t *testing.T) { &db, nil, nil, nil, nil, false, "", ) - request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) @@ -164,7 +189,7 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { &db, nil, nil, nil, nil, false, "", ) - request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code)