diff --git a/appservice/appservice.go b/appservice/appservice.go index bd292767b..c5ae9ceb2 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -32,7 +32,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -55,7 +54,7 @@ func NewInternalAPI( gomatrixserverlib.WithSkipVerify(base.Cfg.AppServiceAPI.DisableTLSValidation), ) - js, _ := jetstream.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) // Create a connection to the appservice postgres DB appserviceDB, err := storage.NewDatabase(base, &base.Cfg.AppServiceAPI.Database) diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index c1e86114b..f550c29bb 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -44,7 +44,7 @@ func AddPublicRoutes( ) { cfg := &base.Cfg.ClientAPI mscCfg := &base.Cfg.MSCs - js, natsClient := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, natsClient := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) syncProducer := &producers.SyncAPIProducer{ JetStream: js, diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index e52377c94..bec9ac777 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -56,7 +56,7 @@ func AddPublicRoutes( ) { cfg := &base.Cfg.FederationAPI mscCfg := &base.Cfg.MSCs - js, _ := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) producer := &producers.SyncAPIProducer{ JetStream: js, TopicReceiptEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), @@ -115,7 +115,7 @@ func NewInternalAPI( FailuresUntilBlacklist: cfg.FederationMaxRetries, } - js, _ := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, diff --git a/go.mod b/go.mod index 38d28e11e..d14ced5b7 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220506114534-f3951a683bad + github.com/matrix-org/gomatrixserverlib v0.0.0-20220509120958-8d818048c34c github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 @@ -48,17 +48,17 @@ require ( github.com/prometheus/client_golang v1.12.1 github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.7.0 - github.com/tidwall/gjson v1.14.0 + github.com/tidwall/gjson v1.14.1 github.com/tidwall/sjson v1.2.4 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.3 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 + golang.org/x/crypto v0.0.0-20220507011949-2cf3adece122 golang.org/x/image v0.0.0-20220321031419-a8550c1d254a golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2 golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 - golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect + golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 96ac72637..8b518935c 100644 --- a/go.sum +++ b/go.sum @@ -795,8 +795,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220506114534-f3951a683bad h1:QhPE4f2fijOSW3QvT/Ucwqz9KJwafKaiKwfhwgIpx3M= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220506114534-f3951a683bad/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220509120958-8d818048c34c h1:KqzqFWxvs90pcDaW9QEveW+Q5JcEYuNnKyaqXc+ohno= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220509120958-8d818048c34c/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1130,8 +1130,8 @@ github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.0 h1:6aeJ0bzojgWLa82gDQHcx3S0Lr/O51I9bJ5nv6JFx5w= -github.com/tidwall/gjson v1.14.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.1 h1:iymTbGkQBhveq21bEvAQ81I0LEBork8BFe1CUZXdyuo= +github.com/tidwall/gjson v1.14.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= @@ -1281,8 +1281,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 h1:tkVvjkPTB7pnW3jnid7kNyAMPVWllTNOf/qKDze4p9o= -golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220507011949-2cf3adece122 h1:NvGWuYG8dkDHFSKksI1P9faiVJ9rayE6l0+ouWVIDs8= +golang.org/x/crypto v0.0.0-20220507011949-2cf3adece122/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1543,8 +1543,8 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM= -golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 h1:nonptSpoQ4vQjyraW20DXPAglgQfVnM9ZC6MmNLMR60= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/internal/version.go b/internal/version.go index 2477bc9ac..e74548831 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,8 +17,8 @@ var build string const ( VersionMajor = 0 VersionMinor = 8 - VersionPatch = 2 - VersionTag = "" // example: "rc1" + VersionPatch = 3 + VersionTag = "rc1" // example: "rc1" ) func VersionString() string { diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 007a48a55..47d7f57f9 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -39,7 +39,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, ) api.KeyInternalAPI { - js, _ := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) db, err := storage.NewDatabase(base, &cfg.Database) if err != nil { diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index 0b143a1aa..d8c76b49b 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -39,6 +39,8 @@ CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( -- Clobber based on 4-uple of user/device/key/algorithm. CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm) ); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); ` const upsertKeysSQL = "" + diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 897839aca..d2c0b7b20 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -38,6 +38,8 @@ CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( -- Clobber based on 4-uple of user/device/key/algorithm. UNIQUE (user_id, device_id, key_id, algorithm) ); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); ` const upsertKeysSQL = "" + diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 7e58ef9d0..9ad8b0422 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -316,40 +316,30 @@ func (u *latestEventsUpdater) calculateLatest( // Then let's see if any of the existing forward extremities now // have entries in the previous events table. If they do then we // will no longer include them as forward extremities. - existingPrevs := make(map[string]struct{}) - for _, l := range existingRefs { + for k, l := range existingRefs { referenced, err := u.updater.IsReferenced(l.EventReference) if err != nil { return false, fmt.Errorf("u.updater.IsReferenced: %w", err) } else if referenced { - existingPrevs[l.EventID] = struct{}{} + delete(existingRefs, k) } } - // Include our new event in the extremities. - newLatest := []types.StateAtEventAndReference{newStateAndRef} + // Start off with our new unreferenced event. We're reusing the backing + // array here rather than allocating a new one. + u.latest = append(u.latest[:0], newStateAndRef) - // Then run through and see if the other extremities are still valid. - // If our new event references them then they are no longer good - // candidates. + // If our new event references any of the existing forward extremities + // then they are no longer forward extremities, so remove them. for _, prevEventID := range newEvent.PrevEventIDs() { delete(existingRefs, prevEventID) } - // Ensure that we don't add any candidate forward extremities from - // the old set that are, themselves, referenced by the old set of - // forward extremities. This shouldn't happen but guards against - // the possibility anyway. - for prevEventID := range existingPrevs { - delete(existingRefs, prevEventID) - } - // Then re-add any old extremities that are still valid after all. for _, old := range existingRefs { - newLatest = append(newLatest, *old) + u.latest = append(u.latest, *old) } - u.latest = newLatest return true, nil } diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index 5d34842bf..a95c13550 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -10,9 +10,9 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" ) @@ -21,11 +21,11 @@ var js nats.JetStreamContext var jc *nats.Conn func TestMain(m *testing.M) { - var pc *process.ProcessContext - pc, js, jc = jetstream.PrepareForTests() + var b *base.BaseDendrite + b, js, jc = test.Base(nil) code := m.Run() - pc.ShutdownDendrite() - pc.WaitForComponentsToFinish() + b.ShutdownDendrite() + b.WaitForComponentsToFinish() os.Exit(code) } diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 46261eb3e..1480e8942 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -50,7 +50,7 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to room server db") } - js, nc := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, nc := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) return internal.NewRoomserverAPI( base.ProcessContext, cfg, roomserverDB, js, nc, diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index b3220effd..5f069ca10 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -59,12 +59,12 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func createEventJSONTable(db *sql.DB) error { +func CreateEventJSONTable(db *sql.DB) error { _, err := db.Exec(eventJSONSchema) return err } -func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { +func PrepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{} return s, sqlutil.StatementList{ @@ -97,9 +97,9 @@ func (s *eventJSONStatements) BulkSelectEventJSON( // We might get fewer results than NIDs so we adjust the length of the slice before returning it. results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 + var eventNID int64 for ; rows.Next(); i++ { result := &results[i] - var eventNID int64 if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 762b3a1fc..338e11b82 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -76,12 +76,12 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func createEventStateKeysTable(db *sql.DB) error { +func CreateEventStateKeysTable(db *sql.DB) error { _, err := db.Exec(eventStateKeysSchema) return err } -func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { +func PrepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{} return s, sqlutil.StatementList{ @@ -123,9 +123,9 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed") result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) + var stateKey string + var stateKeyNID int64 for rows.Next() { - var stateKey string - var stateKeyNID int64 if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { return nil, err } @@ -149,9 +149,9 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed") result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) + var stateKey string + var stateKeyNID int64 for rows.Next() { - var stateKey string - var stateKeyNID int64 if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 1d5de5822..15ab7fd8e 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -99,12 +99,12 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func createEventTypesTable(db *sql.DB) error { +func CreateEventTypesTable(db *sql.DB) error { _, err := db.Exec(eventTypesSchema) return err } -func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { +func PrepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{} return s, sqlutil.StatementList{ @@ -143,9 +143,9 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") result := make(map[string]types.EventTypeNID, len(eventTypes)) + var eventType string + var eventTypeNID int64 for rows.Next() { - var eventType string - var eventTypeNID int64 if err := rows.Scan(&eventType, &eventTypeNID); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 8012174a0..86d226ce7 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -155,12 +155,12 @@ type eventStatements struct { selectRoomNIDsForEventNIDsStmt *sql.Stmt } -func createEventsTable(db *sql.DB) error { +func CreateEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) return err } -func prepareEventsTable(db *sql.DB) (tables.Events, error) { +func PrepareEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{} return s, sqlutil.StatementList{ @@ -380,15 +380,15 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) i := 0 + var ( + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + stateSnapshotNID int64 + eventID string + eventSHA256 []byte + ) for ; rows.Next(); i++ { - var ( - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - stateSnapshotNID int64 - eventID string - eventSHA256 []byte - ) if err = rows.Scan( &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, ); err != nil { @@ -446,9 +446,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") results := make(map[types.EventNID]string, len(eventNIDs)) i := 0 + var eventNID int64 + var eventID string for ; rows.Next(); i++ { - var eventNID int64 - var eventID string if err = rows.Scan(&eventNID, &eventID); err != nil { return nil, err } @@ -491,9 +491,9 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") results := make(map[string]types.EventNID, len(eventIDs)) + var eventID string + var eventNID int64 for rows.Next() { - var eventID string - var eventNID int64 if err = rows.Scan(&eventID, &eventNID); err != nil { return nil, err } @@ -522,9 +522,9 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( } defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") result := make(map[types.EventNID]types.RoomNID) + var eventNID types.EventNID + var roomNID types.RoomNID for rows.Next() { - var eventNID types.EventNID - var roomNID types.RoomNID if err = rows.Scan(&eventNID, &roomNID); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index da8d25848..34e891490 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -68,16 +68,16 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c } func (d *Database) create(db *sql.DB) error { - if err := createEventStateKeysTable(db); err != nil { + if err := CreateEventStateKeysTable(db); err != nil { return err } - if err := createEventTypesTable(db); err != nil { + if err := CreateEventTypesTable(db); err != nil { return err } - if err := createEventJSONTable(db); err != nil { + if err := CreateEventJSONTable(db); err != nil { return err } - if err := createEventsTable(db); err != nil { + if err := CreateEventsTable(db); err != nil { return err } if err := createRoomsTable(db); err != nil { @@ -112,19 +112,19 @@ func (d *Database) create(db *sql.DB) error { } func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.RoomServerCaches) error { - eventStateKeys, err := prepareEventStateKeysTable(db) + eventStateKeys, err := PrepareEventStateKeysTable(db) if err != nil { return err } - eventTypes, err := prepareEventTypesTable(db) + eventTypes, err := PrepareEventTypesTable(db) if err != nil { return err } - eventJSON, err := prepareEventJSONTable(db) + eventJSON, err := PrepareEventJSONTable(db) if err != nil { return err } - events, err := prepareEventsTable(db) + events, err := PrepareEventsTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index f470ea326..dc26885bb 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -52,12 +52,12 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func createEventJSONTable(db *sql.DB) error { +func CreateEventJSONTable(db *sql.DB) error { _, err := db.Exec(eventJSONSchema) return err } -func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { +func PrepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{ db: db, } @@ -101,9 +101,9 @@ func (s *eventJSONStatements) BulkSelectEventJSON( // We might get fewer results than NIDs so we adjust the length of the slice before returning it. results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 + var eventNID int64 for ; rows.Next(); i++ { result := &results[i] - var eventNID int64 if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index f97541f4a..347524a81 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -71,12 +71,12 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func createEventStateKeysTable(db *sql.DB) error { +func CreateEventStateKeysTable(db *sql.DB) error { _, err := db.Exec(eventStateKeysSchema) return err } -func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { +func PrepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{ db: db, } @@ -128,9 +128,9 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKeyNID: rows.close() failed") result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) + var stateKey string + var stateKeyNID int64 for rows.Next() { - var stateKey string - var stateKeyNID int64 if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { return nil, err } @@ -159,9 +159,9 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey( } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventStateKey: rows.close() failed") result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) + var stateKey string + var stateKeyNID int64 for rows.Next() { - var stateKey string - var stateKeyNID int64 if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index c49cc509a..0581ec194 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -79,12 +79,12 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func createEventTypesTable(db *sql.DB) error { +func CreateEventTypesTable(db *sql.DB) error { _, err := db.Exec(eventTypesSchema) return err } -func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { +func PrepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{ db: db, } @@ -139,9 +139,9 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") result := make(map[string]types.EventTypeNID, len(eventTypes)) + var eventType string + var eventTypeNID int64 for rows.Next() { - var eventType string - var eventTypeNID int64 if err := rows.Scan(&eventType, &eventTypeNID); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 45b49e5cb..feb06150a 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -68,7 +68,8 @@ const bulkSelectStateEventByIDSQL = "" + const bulkSelectStateEventByNIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + " WHERE event_nid IN ($1)" - // Rest of query is built by BulkSelectStateEventByNID + +// Rest of query is built by BulkSelectStateEventByNID const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + @@ -126,12 +127,12 @@ type eventStatements struct { //selectRoomNIDsForEventNIDsStmt *sql.Stmt } -func createEventsTable(db *sql.DB) error { +func CreateEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) return err } -func prepareEventsTable(db *sql.DB) (tables.Events, error) { +func PrepareEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{ db: db, } @@ -404,15 +405,15 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) i := 0 + var ( + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + stateSnapshotNID int64 + eventID string + eventSHA256 []byte + ) for ; rows.Next(); i++ { - var ( - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - stateSnapshotNID int64 - eventID string - eventSHA256 []byte - ) if err = rows.Scan( &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, ); err != nil { @@ -491,9 +492,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") results := make(map[types.EventNID]string, len(eventNIDs)) i := 0 + var eventNID int64 + var eventID string for ; rows.Next(); i++ { - var eventNID int64 - var eventID string if err = rows.Scan(&eventNID, &eventID); err != nil { return nil, err } @@ -545,9 +546,9 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") results := make(map[string]types.EventNID, len(eventIDs)) + var eventID string + var eventNID int64 for rows.Next() { - var eventID string - var eventNID int64 if err = rows.Scan(&eventID, &eventNID); err != nil { return nil, err } @@ -595,9 +596,9 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( } defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") result := make(map[types.EventNID]types.RoomNID) + var eventNID types.EventNID + var roomNID types.RoomNID for rows.Next() { - var eventNID types.EventNID - var roomNID types.RoomNID if err = rows.Scan(&eventNID, &roomNID); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index e6cf1a53f..9522d3058 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -77,16 +77,16 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c } func (d *Database) create(db *sql.DB) error { - if err := createEventStateKeysTable(db); err != nil { + if err := CreateEventStateKeysTable(db); err != nil { return err } - if err := createEventTypesTable(db); err != nil { + if err := CreateEventTypesTable(db); err != nil { return err } - if err := createEventJSONTable(db); err != nil { + if err := CreateEventJSONTable(db); err != nil { return err } - if err := createEventsTable(db); err != nil { + if err := CreateEventsTable(db); err != nil { return err } if err := createRoomsTable(db); err != nil { @@ -121,19 +121,19 @@ func (d *Database) create(db *sql.DB) error { } func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.RoomServerCaches) error { - eventStateKeys, err := prepareEventStateKeysTable(db) + eventStateKeys, err := PrepareEventStateKeysTable(db) if err != nil { return err } - eventTypes, err := prepareEventTypesTable(db) + eventTypes, err := PrepareEventTypesTable(db) if err != nil { return err } - eventJSON, err := prepareEventJSONTable(db) + eventJSON, err := PrepareEventJSONTable(db) if err != nil { return err } - events, err := prepareEventsTable(db) + events, err := PrepareEventsTable(db) if err != nil { return err } diff --git a/roomserver/storage/tables/event_json_table_test.go b/roomserver/storage/tables/event_json_table_test.go new file mode 100644 index 000000000..b490d0fe8 --- /dev/null +++ b/roomserver/storage/tables/event_json_table_test.go @@ -0,0 +1,95 @@ +package tables_test + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateEventJSONTable(t *testing.T, dbType test.DBType) (tables.EventJSON, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.EventJSON + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventJSONTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareEventJSONTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventJSONTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareEventJSONTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func Test_EventJSONTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateEventJSONTable(t, dbType) + defer close() + + // create some dummy data + for i := 0; i < 10; i++ { + err := tab.InsertEventJSON( + context.Background(), nil, types.EventNID(i), + []byte(fmt.Sprintf(`{"value":%d"}`, i)), + ) + assert.NoError(t, err) + } + + tests := []struct { + name string + args []types.EventNID + wantCount int + }{ + { + name: "select subset of existing NIDs", + args: []types.EventNID{1, 2, 3, 4, 5}, + wantCount: 5, + }, + { + name: "select subset of existing/non-existing NIDs", + args: []types.EventNID{1, 2, 12, 50}, + wantCount: 2, + }, + { + name: "select single existing NID", + args: []types.EventNID{1}, + wantCount: 1, + }, + { + name: "select single non-existing NID", + args: []types.EventNID{13}, + wantCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // select a subset of the data + values, err := tab.BulkSelectEventJSON(context.Background(), nil, tc.args) + assert.NoError(t, err) + assert.Equal(t, tc.wantCount, len(values)) + for i, v := range values { + assert.Equal(t, v.EventNID, types.EventNID(i+1)) + assert.Equal(t, []byte(fmt.Sprintf(`{"value":%d"}`, i+1)), v.EventJSON) + } + }) + } + }) +} diff --git a/roomserver/storage/tables/event_state_keys_table_test.go b/roomserver/storage/tables/event_state_keys_table_test.go new file mode 100644 index 000000000..a856fe551 --- /dev/null +++ b/roomserver/storage/tables/event_state_keys_table_test.go @@ -0,0 +1,79 @@ +package tables_test + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateEventStateKeysTable(t *testing.T, dbType test.DBType) (tables.EventStateKeys, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.EventStateKeys + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventStateKeysTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareEventStateKeysTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventStateKeysTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareEventStateKeysTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func Test_EventStateKeysTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateEventStateKeysTable(t, dbType) + defer close() + ctx := context.Background() + var stateKeyNID, gotEventStateKey types.EventStateKeyNID + var err error + // create some dummy data + for i := 0; i < 10; i++ { + stateKey := fmt.Sprintf("@user%d:localhost", i) + stateKeyNID, err = tab.InsertEventStateKeyNID(ctx, nil, stateKey) + assert.NoError(t, err) + gotEventStateKey, err = tab.SelectEventStateKeyNID(ctx, nil, stateKey) + assert.NoError(t, err) + assert.Equal(t, stateKeyNID, gotEventStateKey) + } + // This should fail, since @user0:localhost already exists + stateKey := fmt.Sprintf("@user%d:localhost", 0) + _, err = tab.InsertEventStateKeyNID(ctx, nil, stateKey) + assert.Error(t, err) + + stateKeyNIDsMap, err := tab.BulkSelectEventStateKeyNID(ctx, nil, []string{"@user0:localhost", "@user1:localhost"}) + assert.NoError(t, err) + wantStateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDsMap)) + for _, nid := range stateKeyNIDsMap { + wantStateKeyNIDs = append(wantStateKeyNIDs, nid) + } + stateKeyNIDs, err := tab.BulkSelectEventStateKey(ctx, nil, wantStateKeyNIDs) + assert.NoError(t, err) + // verify that BulkSelectEventStateKeyNID and BulkSelectEventStateKey return the same values + for userID, nid := range stateKeyNIDsMap { + if v, ok := stateKeyNIDs[nid]; ok { + assert.Equal(t, v, userID) + } else { + t.Fatalf("unable to find %d in result set", nid) + } + } + }) +} diff --git a/roomserver/storage/tables/event_types_table_test.go b/roomserver/storage/tables/event_types_table_test.go new file mode 100644 index 000000000..92c57a917 --- /dev/null +++ b/roomserver/storage/tables/event_types_table_test.go @@ -0,0 +1,79 @@ +package tables_test + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateEventTypesTable(t *testing.T, dbType test.DBType) (tables.EventTypes, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.EventTypes + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventTypesTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareEventTypesTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventTypesTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareEventTypesTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func Test_EventTypesTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateEventTypesTable(t, dbType) + defer close() + ctx := context.Background() + var eventTypeNID, gotEventTypeNID types.EventTypeNID + var err error + // create some dummy data + eventTypeMap := make(map[string]types.EventTypeNID) + for i := 0; i < 10; i++ { + eventType := fmt.Sprintf("dummyEventType%d", i) + eventTypeNID, err = tab.InsertEventTypeNID(ctx, nil, eventType) + assert.NoError(t, err) + eventTypeMap[eventType] = eventTypeNID + gotEventTypeNID, err = tab.SelectEventTypeNID(ctx, nil, eventType) + assert.NoError(t, err) + assert.Equal(t, eventTypeNID, gotEventTypeNID) + } + // This should fail, since the dummyEventType0 already exists + eventType := fmt.Sprintf("dummyEventType%d", 0) + _, err = tab.InsertEventTypeNID(ctx, nil, eventType) + assert.Error(t, err) + + // This should return an error, as this eventType does not exist + _, err = tab.SelectEventTypeNID(ctx, nil, "dummyEventType13") + assert.Error(t, err) + + eventTypeNIDs, err := tab.BulkSelectEventTypeNID(ctx, nil, []string{"dummyEventType0", "dummyEventType3"}) + assert.NoError(t, err) + // verify that BulkSelectEventTypeNID and InsertEventTypeNID return the same values + for eventType, nid := range eventTypeNIDs { + if v, ok := eventTypeMap[eventType]; ok { + assert.Equal(t, v, nid) + } else { + t.Fatalf("unable to find %d in result set", nid) + } + } + }) +} diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go new file mode 100644 index 000000000..d5d699c4c --- /dev/null +++ b/roomserver/storage/tables/events_table_test.go @@ -0,0 +1,157 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.Events + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventsTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareEventsTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventsTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareEventsTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func Test_EventsTable(t *testing.T) { + alice := test.NewUser() + room := test.NewRoom(t, alice) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateEventsTable(t, dbType) + defer close() + // create some dummy data + eventIDs := make([]string, 0, len(room.Events())) + wantStateAtEvent := make([]types.StateAtEvent, 0, len(room.Events())) + wantEventReferences := make([]gomatrixserverlib.EventReference, 0, len(room.Events())) + wantStateAtEventAndRefs := make([]types.StateAtEventAndReference, 0, len(room.Events())) + for _, ev := range room.Events() { + eventNID, snapNID, err := tab.InsertEvent(ctx, nil, 1, 1, 1, ev.EventID(), ev.EventReference().EventSHA256, nil, ev.Depth(), false) + assert.NoError(t, err) + gotEventNID, gotSnapNID, err := tab.SelectEvent(ctx, nil, ev.EventID()) + assert.NoError(t, err) + assert.Equal(t, eventNID, gotEventNID) + assert.Equal(t, snapNID, gotSnapNID) + eventID, err := tab.SelectEventID(ctx, nil, eventNID) + assert.NoError(t, err) + assert.Equal(t, eventID, ev.EventID()) + + // The events shouldn't be sent to output yet + sentToOutput, err := tab.SelectEventSentToOutput(ctx, nil, gotEventNID) + assert.NoError(t, err) + assert.False(t, sentToOutput) + + err = tab.UpdateEventSentToOutput(ctx, nil, gotEventNID) + assert.NoError(t, err) + + // Now they should be sent to output + sentToOutput, err = tab.SelectEventSentToOutput(ctx, nil, gotEventNID) + assert.NoError(t, err) + assert.True(t, sentToOutput) + + eventIDs = append(eventIDs, ev.EventID()) + wantEventReferences = append(wantEventReferences, ev.EventReference()) + + // Set the stateSnapshot to 2 for some events to verify they are returned later + stateSnapshot := 0 + if eventNID < 3 { + stateSnapshot = 2 + err = tab.UpdateEventState(ctx, nil, eventNID, 2) + assert.NoError(t, err) + } + stateAtEvent := types.StateAtEvent{ + Overwrite: false, + BeforeStateSnapshotNID: types.StateSnapshotNID(stateSnapshot), + IsRejected: false, + StateEntry: types.StateEntry{ + EventNID: eventNID, + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: 1, + EventStateKeyNID: 1, + }, + }, + } + wantStateAtEvent = append(wantStateAtEvent, stateAtEvent) + wantStateAtEventAndRefs = append(wantStateAtEventAndRefs, types.StateAtEventAndReference{ + StateAtEvent: stateAtEvent, + EventReference: ev.EventReference(), + }) + } + + stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs) + assert.NoError(t, err) + assert.Equal(t, len(stateEvents), len(eventIDs)) + nids := make([]types.EventNID, 0, len(stateEvents)) + for _, ev := range stateEvents { + nids = append(nids, ev.EventNID) + } + stateEvents2, err := tab.BulkSelectStateEventByNID(ctx, nil, nids, nil) + assert.NoError(t, err) + // somehow SQLite doesn't return the values ordered as requested by the query + assert.ElementsMatch(t, stateEvents, stateEvents2) + + roomNIDs, err := tab.SelectRoomNIDsForEventNIDs(ctx, nil, nids) + assert.NoError(t, err) + // We only inserted one room, so the RoomNID should be the same for all evendNIDs + for _, roomNID := range roomNIDs { + assert.Equal(t, types.RoomNID(1), roomNID) + } + + stateAtEvent, err := tab.BulkSelectStateAtEventByID(ctx, nil, eventIDs) + assert.NoError(t, err) + assert.Equal(t, len(eventIDs), len(stateAtEvent)) + + assert.ElementsMatch(t, wantStateAtEvent, stateAtEvent) + + evendNIDMap, err := tab.BulkSelectEventID(ctx, nil, nids) + assert.NoError(t, err) + t.Logf("%+v", evendNIDMap) + assert.Equal(t, len(evendNIDMap), len(nids)) + + nidMap, err := tab.BulkSelectEventNID(ctx, nil, eventIDs) + assert.NoError(t, err) + // check that we got all expected eventNIDs + for _, eventID := range eventIDs { + _, ok := nidMap[eventID] + assert.True(t, ok) + } + + references, err := tab.BulkSelectEventReference(ctx, nil, nids) + assert.NoError(t, err) + assert.Equal(t, wantEventReferences, references) + + stateAndRefs, err := tab.BulkSelectStateAtEventAndReference(ctx, nil, nids) + assert.NoError(t, err) + assert.Equal(t, wantStateAtEventAndRefs, stateAndRefs) + + // check we get the expected event depth + maxDepth, err := tab.SelectMaxEventDepth(ctx, nil, nids) + assert.NoError(t, err) + assert.Equal(t, int64(len(room.Events())+1), maxDepth) + }) +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 97e4afcff..95609787a 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -10,9 +10,8 @@ import ( ) type EventJSONPair struct { - EventNID types.EventNID - RoomVersion gomatrixserverlib.RoomVersion - EventJSON []byte + EventNID types.EventNID + EventJSON []byte } type EventJSON interface { @@ -36,7 +35,8 @@ type EventStateKeys interface { type Events interface { InsertEvent( - ctx context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, + eventStateKeyNID types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) diff --git a/setup/base/base.go b/setup/base/base.go index da122bf8f..5cbd7da9c 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -41,6 +41,7 @@ import ( "golang.org/x/net/http2/h2c" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/gorilla/mux" @@ -77,6 +78,7 @@ type BaseDendrite struct { InternalAPIMux *mux.Router DendriteAdminMux *mux.Router SynapseAdminMux *mux.Router + NATS *jetstream.NATSInstance UseHTTPAPIs bool apiHttpClient *http.Client Cfg *config.Dendrite @@ -241,6 +243,7 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base InternalAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.InternalPathPrefix).Subrouter().UseEncodedPath(), DendriteAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.DendriteAdminPathPrefix).Subrouter().UseEncodedPath(), SynapseAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.SynapseAdminPathPrefix).Subrouter().UseEncodedPath(), + NATS: &jetstream.NATSInstance{}, apiHttpClient: &apiClient, Database: db, // set if monolith with global connection pool only DatabaseWriter: writer, // set if monolith with global connection pool only diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 90f162f82..dabc0f7ec 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -17,16 +17,9 @@ import ( natsclient "github.com/nats-io/nats.go" ) -var natsServer *natsserver.Server -var natsServerMutex sync.Mutex - -func PrepareForTests() (*process.ProcessContext, nats.JetStreamContext, *nats.Conn) { - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.Global.JetStream.InMemory = true - pc := process.NewProcessContext() - js, jc := Prepare(pc, &cfg.Global.JetStream) - return pc, js, jc +type NATSInstance struct { + *natsserver.Server + sync.Mutex } func DeleteAllStreams(js nats.JetStreamContext, cfg *config.JetStream) { @@ -36,15 +29,15 @@ func DeleteAllStreams(js nats.JetStreamContext, cfg *config.JetStream) { } } -func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) { +func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) { // check if we need an in-process NATS Server if len(cfg.Addresses) != 0 { return setupNATS(process, cfg, nil) } - natsServerMutex.Lock() - if natsServer == nil { + s.Lock() + if s.Server == nil { var err error - natsServer, err = natsserver.NewServer(&natsserver.Options{ + s.Server, err = natsserver.NewServer(&natsserver.Options{ ServerName: "monolith", DontListen: true, JetStream: true, @@ -56,23 +49,23 @@ func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient if err != nil { panic(err) } - natsServer.ConfigureLogger() + s.ConfigureLogger() go func() { process.ComponentStarted() - natsServer.Start() + s.Start() }() go func() { <-process.WaitForShutdown() - natsServer.Shutdown() - natsServer.WaitForShutdown() + s.Shutdown() + s.WaitForShutdown() process.ComponentFinished() }() } - natsServerMutex.Unlock() - if !natsServer.ReadyForConnections(time.Second * 10) { + s.Unlock() + if !s.ReadyForConnections(time.Second * 10) { logrus.Fatalln("NATS did not start in time") } - nc, err := natsclient.Connect("", natsclient.InProcessServer(natsServer)) + nc, err := natsclient.Connect("", natsclient.InProcessServer(s)) if err != nil { logrus.Fatalln("Failed to create NATS client") } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index fc43e0bf0..d8bacb2da 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -45,7 +45,7 @@ func AddPublicRoutes( ) { cfg := &base.Cfg.SyncAPI - js, natsClient := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, natsClient := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) syncDB, err := storage.NewSyncServerDatasource(base, &cfg.Database) if err != nil { diff --git a/test/base.go b/test/base.go index 0223f182c..664442c03 100644 --- a/test/base.go +++ b/test/base.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/nats-io/nats.go" ) func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()) { @@ -76,3 +77,14 @@ func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func() } return nil, nil } + +func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nats.Conn) { + if cfg == nil { + cfg = &config.Dendrite{} + cfg.Defaults(true) + } + cfg.Global.JetStream.InMemory = true + base := base.NewBaseDendrite(cfg, "Tests") + js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) + return base, js, jc +} diff --git a/userapi/userapi.go b/userapi/userapi.go index 03a46807f..603b416bf 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -47,7 +47,7 @@ func NewInternalAPI( appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI, rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client, ) api.UserInternalAPI { - js, _ := jetstream.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) db, err := storage.NewUserAPIDatabase( base,