From 6907e421674e9c06f43da96c0819ff778ffcae60 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 8 Apr 2022 13:55:46 +0100 Subject: [PATCH 01/21] Use connection manager in Pinecone demos --- build/gobind-pinecone/monolith.go | 57 +++--------------------------- cmd/dendrite-demo-pinecone/main.go | 40 +++++---------------- cmd/dendritejs-pinecone/main.go | 20 ++--------- 3 files changed, 16 insertions(+), 101 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 608599498..9cc94d650 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -52,6 +52,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeSessions "github.com/matrix-org/pinecone/sessions" @@ -71,11 +72,9 @@ type DendriteMonolith struct { PineconeRouter *pineconeRouter.Router PineconeMulticast *pineconeMulticast.Multicast PineconeQUIC *pineconeSessions.Sessions + PineconeManager *pineconeConnections.ConnectionManager StorageDirectory string CacheDirectory string - staticPeerURI string - staticPeerMutex sync.RWMutex - staticPeerAttempt chan struct{} listener net.Listener httpServer *http.Server processContext *process.ProcessContext @@ -104,15 +103,8 @@ func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) { } func (m *DendriteMonolith) SetStaticPeer(uri string) { - m.staticPeerMutex.Lock() - m.staticPeerURI = strings.TrimSpace(uri) - m.staticPeerMutex.Unlock() - m.DisconnectType(int(pineconeRouter.PeerTypeRemote)) - if uri != "" { - go func() { - m.staticPeerAttempt <- struct{}{} - }() - } + m.PineconeManager.RemovePeers() + m.PineconeManager.AddPeer(strings.TrimSpace(uri)) } func (m *DendriteMonolith) DisconnectType(peertype int) { @@ -210,43 +202,6 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e return loginRes.Device.AccessToken, nil } -func (m *DendriteMonolith) staticPeerConnect() { - connected := map[string]bool{} // URI -> connected? - attempt := func() { - m.staticPeerMutex.RLock() - uri := m.staticPeerURI - m.staticPeerMutex.RUnlock() - if uri == "" { - return - } - for k := range connected { - delete(connected, k) - } - for _, uri := range strings.Split(uri, ",") { - connected[strings.TrimSpace(uri)] = false - } - for _, info := range m.PineconeRouter.Peers() { - connected[info.URI] = true - } - for k, online := range connected { - if !online { - if err := conn.ConnectToPeer(m.PineconeRouter, k); err != nil { - logrus.WithError(err).Error("Failed to connect to static peer") - } - } - } - } - for { - select { - case <-m.processContext.Context().Done(): - case <-m.staticPeerAttempt: - attempt() - case <-time.After(time.Second * 5): - attempt() - } - } -} - // nolint:gocyclo func (m *DendriteMonolith) Start() { var err error @@ -284,6 +239,7 @@ func (m *DendriteMonolith) Start() { m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) + m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter) prefix := hex.EncodeToString(pk) cfg := &config.Dendrite{} @@ -392,9 +348,6 @@ func (m *DendriteMonolith) Start() { m.processContext = base.ProcessContext - m.staticPeerAttempt = make(chan struct{}, 1) - go m.staticPeerConnect() - go func() { m.logger.Info("Listening on ", cfg.Global.ServerName) m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix"))) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index a3d3ed175..dd1ab3697 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -25,7 +25,6 @@ import ( "net" "net/http" "os" - "strings" "time" "github.com/gorilla/mux" @@ -47,6 +46,7 @@ import ( "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" + pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeSessions "github.com/matrix-org/pinecone/sessions" @@ -90,6 +90,13 @@ func main() { } pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) + pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) + pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) + pManager := pineconeConnections.NewConnectionManager(pRouter) + pMulticast.Start() + if instancePeer != nil && *instancePeer != "" { + pManager.AddPeer(*instancePeer) + } go func() { listener, err := net.Listen("tcp", *instanceListen) @@ -119,36 +126,6 @@ func main() { } }() - pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) - pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) - pMulticast.Start() - - connectToStaticPeer := func() { - connected := map[string]bool{} // URI -> connected? - for _, uri := range strings.Split(*instancePeer, ",") { - connected[strings.TrimSpace(uri)] = false - } - attempt := func() { - for k := range connected { - connected[k] = false - } - for _, info := range pRouter.Peers() { - connected[info.URI] = true - } - for k, online := range connected { - if !online { - if err := conn.ConnectToPeer(pRouter, k); err != nil { - logrus.WithError(err).Error("Failed to connect to static peer") - } - } - } - } - for { - attempt() - time.Sleep(time.Second * 5) - } - } - cfg := &config.Dendrite{} cfg.Defaults(true) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) @@ -268,7 +245,6 @@ func main() { Handler: pMux, } - go connectToStaticPeer() go func() { pubkey := pRouter.PublicKey() logrus.Info("Listening on ", hex.EncodeToString(pubkey[:])) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index ba9edf230..211b3e131 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -22,7 +22,6 @@ import ( "encoding/hex" "fmt" "syscall/js" - "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" @@ -44,6 +43,7 @@ import ( _ "github.com/matrix-org/go-sqlite3-js" + pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeSessions "github.com/matrix-org/pinecone/sessions" ) @@ -154,6 +154,8 @@ func startup() { pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) pSessions := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) + pManager := pineconeConnections.NewConnectionManager(pRouter) + pManager.AddPeer("wss://pinecone.matrix.org/public") cfg := &config.Dendrite{} cfg.Defaults(true) @@ -237,20 +239,4 @@ func startup() { } s.ListenAndServe("fetch") }() - - // Connect to the static peer - go func() { - for { - if pRouter.PeerCount(pineconeRouter.PeerTypeRemote) == 0 { - if err := conn.ConnectToPeer(pRouter, publicPeer); err != nil { - logrus.WithError(err).Error("Failed to connect to static peer") - } - } - select { - case <-base.ProcessContext.Context().Done(): - return - case <-time.After(time.Second * 5): - } - } - }() } From 9bd9f2beba9df36fcd473d4d3c3902307b1bb0ee Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 8 Apr 2022 16:23:50 +0100 Subject: [PATCH 02/21] Update to matrix-org/pinecone@9b3963248c9bdc22cf0789bc3ca58e8f274371e6 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index a5b2cffc6..a5f17d06d 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 - github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d + github.com/matrix-org/pinecone v0.0.0-20220408150209-9b3963248c9b github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb diff --git a/go.sum b/go.sum index 6708573a0..df09baf2c 100644 --- a/go.sum +++ b/go.sum @@ -1114,8 +1114,8 @@ github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 h1:Fkennny7+Z/5pygrhjFMZbz1j++P2hhhLoT7NO3p8DQ= github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= -github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d h1:1+T4eOPRsf6cr0lMPW4oO2k8TTHm4mqIh65kpEID5Rk= -github.com/matrix-org/pinecone v0.0.0-20220404141326-e526fa82f79d/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= +github.com/matrix-org/pinecone v0.0.0-20220408150209-9b3963248c9b h1:08XczmTY01e63MYR1fme2S1HfZAVyuAb0GbHzsZJY6k= +github.com/matrix-org/pinecone v0.0.0-20220408150209-9b3963248c9b/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From 986d27a1287e9d86fe16a6152f2457657513a6dd Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 8 Apr 2022 16:39:09 +0100 Subject: [PATCH 03/21] Update to matrix-org/pinecone@2999ea2 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index a5f17d06d..f4ac8d123 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 - github.com/matrix-org/pinecone v0.0.0-20220408150209-9b3963248c9b + github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb diff --git a/go.sum b/go.sum index df09baf2c..063365168 100644 --- a/go.sum +++ b/go.sum @@ -1114,8 +1114,8 @@ github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 h1:Fkennny7+Z/5pygrhjFMZbz1j++P2hhhLoT7NO3p8DQ= github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= -github.com/matrix-org/pinecone v0.0.0-20220408150209-9b3963248c9b h1:08XczmTY01e63MYR1fme2S1HfZAVyuAb0GbHzsZJY6k= -github.com/matrix-org/pinecone v0.0.0-20220408150209-9b3963248c9b/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= +github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE= +github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From 6d25bd6ca57f518404000c47d69bcbfadb4fd2ef Mon Sep 17 00:00:00 2001 From: kegsay Date: Fri, 8 Apr 2022 17:53:24 +0100 Subject: [PATCH 04/21] syncapi: add more tests; fix more bugs (#2338) * syncapi: add more tests; fix more bugs bugfixes: - The postgres impl of TopologyTable.SelectEventIDsInRange did not use the provided txn - The postgres impl of EventsTable.SelectEvents did not preserve the ordering of the input event IDs in the output events slice - The sqlite impl of EventsTable.SelectEvents did not use a bulk `IN ($1)` query. Added tests: - `TestGetEventsInRangeWithTopologyToken` - `TestOutputRoomEventsTable` - `TestTopologyTable` * -p 1 for now --- .github/workflows/dendrite.yml | 2 +- syncapi/storage/interface.go | 2 +- .../postgres/output_room_events_table.go | 22 +- .../output_room_events_topology_table.go | 4 +- syncapi/storage/shared/syncserver.go | 8 +- syncapi/storage/sqlite3/account_data_table.go | 4 +- .../sqlite3/current_room_state_table.go | 4 +- syncapi/storage/sqlite3/invites_table.go | 4 +- .../sqlite3/output_room_events_table.go | 56 +++-- syncapi/storage/sqlite3/peeks_table.go | 4 +- syncapi/storage/sqlite3/presence_table.go | 4 +- syncapi/storage/sqlite3/receipt_table.go | 4 +- syncapi/storage/sqlite3/stream_id_table.go | 14 +- syncapi/storage/sqlite3/syncserver.go | 4 +- syncapi/storage/storage_test.go | 234 +++++++----------- syncapi/storage/tables/interface.go | 2 +- .../storage/tables/output_room_events_test.go | 82 ++++++ syncapi/storage/tables/topology_test.go | 91 +++++++ test/db.go | 1 + test/event.go | 39 +++ 20 files changed, 388 insertions(+), 197 deletions(-) create mode 100644 syncapi/storage/tables/output_room_events_test.go create mode 100644 syncapi/storage/tables/topology_test.go diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 4f337a866..8221bff96 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -111,7 +111,7 @@ jobs: key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test- - - run: go test ./... + - run: go test -p 1 ./... env: POSTGRES_HOST: localhost POSTGRES_USER: postgres diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 841f67261..cf3fd5532 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -104,7 +104,7 @@ type Database interface { // DeletePeek deletes all peeks for a given room by a given user // Returns an error if there was a problem communicating with the database. DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) - // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 14af6a949..a30e220ba 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -427,7 +427,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, ) ([]types.StreamEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) @@ -435,7 +435,25 @@ func (s *outputRoomEventsStatements) SelectEvents( return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") - return rowsToStreamEvents(rows) + streamEvents, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if preserveOrder { + eventMap := make(map[string]types.StreamEvent) + for _, ev := range streamEvents { + eventMap[ev.EventID()] = ev + } + var returnEvents []types.StreamEvent + for _, eventID := range eventIDs { + ev, ok := eventMap[eventID] + if ok { + returnEvents = append(returnEvents, ev) + } + } + return returnEvents, nil + } + return streamEvents, nil } func (s *outputRoomEventsStatements) DeleteEventsForRoom( diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 626386ba0..90b3b0083 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -148,9 +148,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( // is requested or not. var stmt *sql.Stmt if chronologicalOrder { - stmt = s.selectEventIDsInRangeASCStmt + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt) } else { - stmt = s.selectEventIDsInRangeDESCStmt + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) } // Query the event IDs. diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 1c45d5d9a..14db5795c 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs) + streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false) if err != nil { return nil, err } @@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e // Check if we have all of the event's previous events. If an event is // missing, add it to the room's backward extremities. - prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs()) + prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false) if err != nil { return err } @@ -457,7 +457,7 @@ func (d *Database) GetEventsInTopologicalRange( } // Retrieve the events' contents using their IDs. - events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs) + events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true) return } @@ -619,7 +619,7 @@ func (d *Database) fetchMissingStateEvents( ) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. - events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs) + events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 24c442240..5b2287e6d 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -51,13 +51,13 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt } -func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { +func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ db: db, streamIDStatements: streamID, diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 473aa49b0..464f32e04 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateForRoomStmt *sql.Stmt @@ -100,7 +100,7 @@ type currentRoomStateStatements struct { selectStateEventStmt *sql.Stmt } -func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { +func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ db: db, streamIDStatements: streamID, diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 0a6823cc0..58ab8461e 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" + type inviteEventsStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { +func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ db: db, streamIDStatements: streamID, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index acd959696..9da9d776e 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -58,7 +58,7 @@ const insertEventSQL = "" + "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" const selectEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)" const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + @@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" + type outputRoomEventsStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt updateEventJSONStmt *sql.Stmt deleteEventsForRoomStmt *sql.Stmt @@ -122,7 +121,7 @@ type outputRoomEventsStatements struct { selectContextAfterEventStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { +func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ db: db, streamIDStatements: streamID, @@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even } return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, - {&s.selectEventsStmt, selectEventsSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, @@ -421,21 +419,43 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, ) ([]types.StreamEvent, error) { - var returnEvents []types.StreamEvent - stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) - for _, eventID := range eventIDs { - rows, err := stmt.QueryContext(ctx, eventID) - if err != nil { - return nil, err - } - if streamEvents, err := rowsToStreamEvents(rows); err == nil { - returnEvents = append(returnEvents, streamEvents...) - } - internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + iEventIDs := make([]interface{}, len(eventIDs)) + for i := range eventIDs { + iEventIDs[i] = eventIDs[i] } - return returnEvents, nil + selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...) + } else { + rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...) + } + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + streamEvents, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if preserveOrder { + var returnEvents []types.StreamEvent + eventMap := make(map[string]types.StreamEvent) + for _, ev := range streamEvents { + eventMap[ev.EventID()] = ev + } + for _, eventID := range eventIDs { + ev, ok := eventMap[eventID] + if ok { + returnEvents = append(returnEvents, ev) + } + } + return returnEvents, nil + } + return streamEvents, nil } func (s *outputRoomEventsStatements) DeleteEventsForRoom( diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index c93c82051..5ee86448c 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" + type peekStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertPeekStmt *sql.Stmt deletePeekStmt *sql.Stmt deletePeeksStmt *sql.Stmt @@ -75,7 +75,7 @@ type peekStatements struct { selectMaxPeekIDStmt *sql.Stmt } -func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { +func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) { _, err := db.Exec(peeksSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index e7b78a705..00b16458d 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -75,7 +75,7 @@ const selectPresenceAfter = "" + type presenceStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertPresenceStmt *sql.Stmt upsertPresenceFromSyncStmt *sql.Stmt selectPresenceForUsersStmt *sql.Stmt @@ -83,7 +83,7 @@ type presenceStatements struct { selectPresenceAfterStmt *sql.Stmt } -func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) { +func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) { _, err := db.Exec(presenceSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index dea057719..bd778bf3c 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" + type receiptStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt } -func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { +func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) { _, err := db.Exec(receiptsSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index faa2c41fe..71980b806 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + " RETURNING stream_id" -type streamIDStatements struct { +type StreamIDStatements struct { db *sql.DB increaseStreamIDStmt *sql.Stmt } -func (s *streamIDStatements) prepare(db *sql.DB) (err error) { +func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) { s.db = db _, err = db.Exec(streamIDTableSchema) if err != nil { @@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { return } -func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos) return } -func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos) return } -func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos) return } -func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) return } -func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos) return diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 9d9d35988..dfc289482 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -30,7 +30,7 @@ type SyncServerDatasource struct { shared.Database db *sql.DB writer sqlutil.Writer - streamID streamIDStatements + streamID StreamIDStatements } // NewDatabase creates a new sync server database @@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e } func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { - if err = d.streamID.prepare(d.db); err != nil { + if err = d.streamID.Prepare(d.db); err != nil { return err } accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 403b50eaa..4e1634ece 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -3,6 +3,7 @@ package storage_test import ( "context" "fmt" + "reflect" "testing" "github.com/matrix-org/dendrite/setup/config" @@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver if err != nil { t.Fatalf("WriteEvent failed: %s", err) } - fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth()) + t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth()) positions = append(positions, pos) } return @@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - t.Parallel() alice := test.NewUser() r := test.NewRoom(t, alice) db, close := MustCreateDatabase(t, dbType) @@ -61,84 +61,84 @@ func TestRecentEventsPDU(t *testing.T) { db, close := MustCreateDatabase(t, dbType) defer close() alice := test.NewUser() - var filter gomatrixserverlib.RoomEventFilter - filter.Limit = 100 + // dummy room to make sure SQL queries are filtering on room ID + MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) + + // actual test room r := test.NewRoom(t, alice) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) events := r.Events() positions := MustWriteEvents(t, db, events) + + // dummy room to make sure SQL queries are filtering on room ID + MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) + latest, err := db.MaxStreamPositionForPDUs(ctx) if err != nil { t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) } testCases := []struct { - Name string - From types.StreamPosition - To types.StreamPosition - WantEvents []*gomatrixserverlib.HeaderedEvent - WantLimited bool + Name string + From types.StreamPosition + To types.StreamPosition + Limit int + ReverseOrder bool + WantEvents []*gomatrixserverlib.HeaderedEvent + WantLimited bool }{ // The purpose of this test is to make sure that incremental syncs are including up to the latest events. - // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. + // It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event. // It makes sure the response includes the final event. { - Name: "IncrementalSync penultimate", + Name: "penultimate", From: positions[len(positions)-2], // pretend we are at the penultimate event To: latest, + Limit: 100, WantEvents: events[len(events)-1:], WantLimited: false, }, - /* - // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the - // number of returned events. This is critical for big rooms hence the test here. - { - Name: "IncrementalSync limited", - DoSync: func() (*types.Response, error) { - from := types.StreamingToken{ // pretend we are 10 events behind - PDUPosition: positions[len(positions)-11], - } - res := types.NewResponse() - // limit is set to 5 - return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) - }, - // want the last 5 events, NOT the last 10. - WantTimeline: events[len(events)-5:], - }, - // The purpose of this test is to check that CompleteSync returns all the current state as well as - // honouring the `numRecentEventsPerRoom` value - { - Name: "CompleteSync limited", - DoSync: func() (*types.Response, error) { - res := types.NewResponse() - // limit set to 5 - return db.CompleteSync(ctx, res, testUserDeviceA, 5) - }, - // want the last 5 events - WantTimeline: events[len(events)-5:], - // want all state for the room - WantState: state, - }, - // The purpose of this test is to check that CompleteSync can return everything with a high enough - // `numRecentEventsPerRoom`. - { - Name: "CompleteSync", - DoSync: func() (*types.Response, error) { - res := types.NewResponse() - return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) - }, - WantTimeline: events, - // We want no state at all as that field in /sync is the delta between the token (beginning of time) - // and the START of the timeline. - }, */ + // The purpose of this test is to check that limits can be applied and work. + // This is critical for big rooms hence the test here. + { + Name: "limited", + From: 0, + To: latest, + Limit: 1, + WantEvents: events[len(events)-1:], + WantLimited: true, + }, + // The purpose of this test is to check that we can return every event with a high + // enough limit + { + Name: "large limited", + From: 0, + To: latest, + Limit: 100, + WantEvents: events, + WantLimited: false, + }, + // The purpose of this test is to check that we can return events in reverse order + { + Name: "reverse", + From: positions[len(positions)-3], // 2 events back + To: latest, + Limit: 100, + ReverseOrder: true, + WantEvents: test.Reversed(events[len(events)-2:]), + WantLimited: false, + }, } - for _, tc := range testCases { + for i := range testCases { + tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { + var filter gomatrixserverlib.RoomEventFilter + filter.Limit = tc.Limit gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ From: tc.From, To: tc.To, - }, &filter, true, true) + }, &filter, !tc.ReverseOrder, true) if err != nil { st.Fatalf("failed to do sync: %s", err) } @@ -148,100 +148,48 @@ func TestRecentEventsPDU(t *testing.T) { if len(gotEvents) != len(tc.WantEvents) { st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) } + for j := range gotEvents { + if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) { + st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON())) + } + } }) } }) } -/* -func TestGetEventsInRangeWithPrevBatch(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - positions := MustWriteEvents(t, db, events) - latest, err := db.SyncPosition(ctx) - if err != nil { - t.Fatalf("failed to get SyncPosition: %s", err) - } - from := types.StreamingToken{ - PDUPosition: positions[len(positions)-2], - } - - res := types.NewResponse() - res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) - if err != nil { - t.Fatalf("failed to IncrementalSync with latest token") - } - roomRes, ok := res.Rooms.Join[testRoomID] - if !ok { - t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) - } - // returns the last event "Message 10" - assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:])) - - prev := roomRes.Timeline.PrevBatch.String() - if prev == "" { - t.Fatalf("IncrementalSync expected prev_batch token") - } - prevBatchToken, err := types.NewTopologyTokenFromString(prev) - if err != nil { - t.Fatalf("failed to NewTopologyTokenFromString : %s", err) - } - // backpaginate 5 messages starting at the latest position. - // head towards the beginning of time - to := types.TopologyToken{} - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1])) -} - -// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token. -func TestGetEventsInRangeWithStreamToken(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - MustWriteEvents(t, db, events) - latest, err := db.SyncPosition(ctx) - if err != nil { - t.Fatalf("failed to get SyncPosition: %s", err) - } - // head towards the beginning of time - to := types.StreamingToken{} - - // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) -} - // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token func TestGetEventsInRangeWithTopologyToken(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - MustWriteEvents(t, db, events) - from, err := db.MaxTopologicalPosition(ctx, testRoomID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } - // head towards the beginning of time - to := types.TopologyToken{} + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := MustCreateDatabase(t, dbType) + defer close() + alice := test.NewUser() + r := test.NewRoom(t, alice) + for i := 0; i < 10; i++ { + r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)}) + } + events := r.Events() + _ = MustWriteEvents(t, db, events) - // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) + from, err := db.MaxTopologicalPosition(ctx, r.ID) + if err != nil { + t.Fatalf("failed to get MaxTopologicalPosition: %s", err) + } + t.Logf("max topo pos = %+v", from) + // head towards the beginning of time + to := types.TopologyToken{} + + // backpaginate 5 messages starting at the latest position. + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true) + if err != nil { + t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) + } + gots := db.StreamEventsToEvents(nil, paginatedEvents) + test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) + }) } +/* // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent // will appear FIRST when going backwards. This test creates a DAG like: @@ -651,12 +599,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ tok.Decrement() return &tok } - -func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { - out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[len(in)-i-1] - } - return out -} */ diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 8d368eec1..3cbeb0462 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -59,7 +59,7 @@ type Events interface { SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) - SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) + SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go new file mode 100644 index 000000000..7a81ffcd2 --- /dev/null +++ b/syncapi/storage/tables/output_room_events_test.go @@ -0,0 +1,82 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/test" +) + +func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Events + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresEventsTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqliteEventsTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestOutputRoomEventsTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newOutputRoomEventsTable(t, dbType) + defer close() + events := room.Events() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + for _, ev := range events { + _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false) + if err != nil { + return fmt.Errorf("failed to InsertEvent: %s", err) + } + } + // order = 2,0,3,1 + wantEventIDs := []string{ + events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(), + } + gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true) + if err != nil { + return fmt.Errorf("failed to SelectEvents: %s", err) + } + gotEventIDs := make([]string, len(gotEvents)) + for i := range gotEvents { + gotEventIDs[i] = gotEvents[i].EventID() + } + if !reflect.DeepEqual(gotEventIDs, wantEventIDs) { + return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs) + } + + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + }) +} diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go new file mode 100644 index 000000000..b6ece0b0d --- /dev/null +++ b/syncapi/storage/tables/topology_test.go @@ -0,0 +1,91 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Topology + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresTopologyTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteTopologyTable(db) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestTopologyTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newTopologyTable(t, dbType) + defer close() + events := room.Events() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + var highestPos types.StreamPosition + for i, ev := range events { + topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i)) + if err != nil { + return fmt.Errorf("failed to InsertEventInTopology: %s", err) + } + // topo pos = depth, depth starts at 1, hence 1+i + if topoPos != types.StreamPosition(1+i) { + return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i) + } + highestPos = topoPos + 1 + } + // check ordering works without limit + eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, events[:]) + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:])) + // check ordering works with limit + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, events[:3]) + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:])) + + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + }) +} diff --git a/test/db.go b/test/db.go index 9deec0a89..674fdf5c3 100644 --- a/test/db.go +++ b/test/db.go @@ -121,6 +121,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { for dbName, dbType := range dbs { dbt := dbType t.Run(dbName, func(tt *testing.T) { + tt.Parallel() testFn(tt, dbt) }) } diff --git a/test/event.go b/test/event.go index 487b09364..b2e2805ba 100644 --- a/test/event.go +++ b/test/event.go @@ -15,7 +15,9 @@ package test import ( + "bytes" "crypto/ed25519" + "testing" "time" "github.com/matrix-org/gomatrixserverlib" @@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier { e.unsigned = unsigned } } + +// Reverse a list of events +func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { + out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[len(in)-i-1] + } + return out +} + +func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) { + t.Helper() + if len(gotEventIDs) != len(wants) { + t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants)) + } + for i := range wants { + w := wants[i].EventID() + g := gotEventIDs[i] + if w != g { + t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w)) + } + } +} + +func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) { + t.Helper() + if len(gots) != len(wants) { + t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants)) + } + for i := range wants { + w := wants[i].JSON() + g := gots[i].JSON() + if !bytes.Equal(w, g) { + t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w)) + } + } +} From b4b2fbc36b1eb1b46640feadbe7e1729c864a898 Mon Sep 17 00:00:00 2001 From: kegsay Date: Sat, 9 Apr 2022 00:37:50 +0100 Subject: [PATCH 05/21] Remove dead code in the sync api (#2341) --- .../postgres/backwards_extremities_table.go | 14 ----------- syncapi/storage/postgres/memberships_table.go | 18 --------------- .../output_room_events_topology_table.go | 14 ----------- .../sqlite3/backwards_extremities_table.go | 14 ----------- syncapi/storage/sqlite3/memberships_table.go | 23 ------------------- .../output_room_events_topology_table.go | 8 ------- syncapi/storage/tables/interface.go | 5 ---- 7 files changed, 96 deletions(-) diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index d5cf563a6..d4515735c 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -47,14 +47,10 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" -const deleteBackwardExtremitiesForRoomSQL = "" + - "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" - type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt - deleteBackwardExtremitiesForRoomStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -72,9 +68,6 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { return nil, err } - if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } return s, nil } @@ -113,10 +106,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } - -func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( - ctx context.Context, txn *sql.Tx, roomID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) - return err -} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 1242a3221..39fa656cb 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -56,12 +56,6 @@ const upsertMembershipSQL = "" + " ON CONFLICT ON CONSTRAINT syncapi_memberships_unique" + " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" -const selectMembershipSQL = "" + - "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + - " WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" + - " ORDER BY stream_pos DESC" + - " LIMIT 1" - const selectMembershipCountSQL = "" + "SELECT COUNT(*) FROM (" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + @@ -69,7 +63,6 @@ const selectMembershipCountSQL = "" + type membershipsStatements struct { upsertMembershipStmt *sql.Stmt - selectMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt } @@ -82,9 +75,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { return nil, err } - if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil { - return nil, err - } if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil { return nil, err } @@ -111,14 +101,6 @@ func (s *membershipsStatements) UpsertMembership( return err } -func (s *membershipsStatements) SelectMembership( - ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string, -) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt) - err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos) - return -} - func (s *membershipsStatements) SelectMembershipCount( ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition, ) (count int, err error) { diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 90b3b0083..a1fc9b2a3 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -73,9 +73,6 @@ const selectMaxPositionInTopologySQL = "" + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + ") ORDER BY stream_position DESC LIMIT 1" -const deleteTopologyForRoomSQL = "" + - "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" @@ -88,7 +85,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt - deleteTopologyForRoomStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -114,9 +110,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { return nil, err } - if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { - return nil, err - } if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { return nil, err } @@ -203,10 +196,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } - -func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( - ctx context.Context, txn *sql.Tx, roomID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) - return err -} diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 662cb0252..c5674dded 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -47,15 +47,11 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" -const deleteBackwardExtremitiesForRoomSQL = "" + - "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" - type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt - deleteBackwardExtremitiesForRoomStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -75,9 +71,6 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { return nil, err } - if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } return s, nil } @@ -116,10 +109,3 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } - -func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( - ctx context.Context, txn *sql.Tx, roomID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) - return err -} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 776bf3da3..9f3530ccd 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -18,7 +18,6 @@ import ( "context" "database/sql" "fmt" - "strings" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" @@ -57,12 +56,6 @@ const upsertMembershipSQL = "" + " ON CONFLICT (room_id, user_id, membership)" + " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" -const selectMembershipSQL = "" + - "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + - " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" + - " ORDER BY stream_pos DESC" + - " LIMIT 1" - const selectMembershipCountSQL = "" + "SELECT COUNT(*) FROM (" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + @@ -111,22 +104,6 @@ func (s *membershipsStatements) UpsertMembership( return err } -func (s *membershipsStatements) SelectMembership( - ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string, -) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { - params := []interface{}{roomID, userID} - for _, membership := range memberships { - params = append(params, membership) - } - orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) - stmt, err := s.db.Prepare(orig) - if err != nil { - return "", 0, 0, err - } - err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) - return -} - func (s *membershipsStatements) SelectMembershipCount( ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition, ) (count int, err error) { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index b972ae285..b2fb77417 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -78,7 +78,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt - deleteTopologyForRoomStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -191,10 +190,3 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } - -func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( - ctx context.Context, txn *sql.Tx, roomID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) - return err -} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 3cbeb0462..a7df70248 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -84,8 +84,6 @@ type Topology interface { SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) - // DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely. - DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) } @@ -132,8 +130,6 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) - // DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely. - DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) } // SendToDevice tracks send-to-device messages which are sent to individual @@ -173,7 +169,6 @@ type Receipts interface { type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error - SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) } From 69f2ff7c82abe0731a05febde88098f4cd34ab8d Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 11 Apr 2022 09:05:23 +0200 Subject: [PATCH 06/21] Correctly use provided filters (#2339) * Apply filters correctly * Fix issues; Use prepareWithFilters * Update gmsl & tests * go.mod.. * PR comments --- go.mod | 3 +- go.sum | 7 +-- syncapi/routing/context.go | 4 +- .../postgres/current_room_state_table.go | 5 +- syncapi/storage/postgres/filtering.go | 32 ++++++++++-- .../postgres/output_room_events_table.go | 26 ++++++---- syncapi/storage/sqlite3/account_data_table.go | 39 +++++---------- syncapi/storage/sqlite3/filtering.go | 50 ++++++++++++------- syncapi/streams/stream_pdu.go | 6 ++- sytest-whitelist | 4 +- 10 files changed, 109 insertions(+), 67 deletions(-) diff --git a/go.mod b/go.mod index f4ac8d123..070f2b5b9 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 @@ -64,6 +64,7 @@ require ( golang.org/x/image v0.0.0-20220321031419-a8550c1d254a golang.org/x/mobile v0.0.0-20220407111146-e579adbbc4a2 golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 + golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 063365168..464f774be 100644 --- a/go.sum +++ b/go.sum @@ -1112,8 +1112,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5 h1:Fkennny7+Z/5pygrhjFMZbz1j++P2hhhLoT7NO3p8DQ= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220405134050-301e340659d5/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f h1:MZrl4TgTnlaOn2Cu9gJCoJ3oyW5mT4/3QIZGgZXzKl4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f/go.mod h1:V5eO8rn/C3rcxig37A/BCeKerLFS+9Avg/77FIeTZ48= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1977,8 +1977,9 @@ golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4= golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 h1:QyVthZKMsyaQwBTJE04jdNN0Pp5Fn9Qga0mrgxyERQM= +golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 2412bc2ae..aaa0c61bf 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -60,7 +60,9 @@ func Context( Headers: nil, } } - filter.Rooms = append(filter.Rooms, roomID) + if filter.Rooms != nil { + *filter.Rooms = append(*filter.Rooms, roomID) + } ctx := req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 69e6e30ec..fe68788d1 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState( excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) + senders, notSenders := getSendersStateFilterFilter(stateFilter) rows, err := stmt.QueryContext(ctx, roomID, - pq.StringArray(stateFilter.Senders), - pq.StringArray(stateFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), stateFilter.ContainsURL, diff --git a/syncapi/storage/postgres/filtering.go b/syncapi/storage/postgres/filtering.go index dcc421362..a2ca42156 100644 --- a/syncapi/storage/postgres/filtering.go +++ b/syncapi/storage/postgres/filtering.go @@ -16,21 +16,45 @@ package postgres import ( "strings" + + "github.com/matrix-org/gomatrixserverlib" ) // filterConvertWildcardToSQL converts wildcards as defined in // https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter // to SQL wildcards that can be used with LIKE() -func filterConvertTypeWildcardToSQL(values []string) []string { +func filterConvertTypeWildcardToSQL(values *[]string) []string { if values == nil { // Return nil instead of []string{} so IS NULL can work correctly when // the return value is passed into SQL queries return nil } - ret := make([]string, len(values)) - for i := range values { - ret[i] = strings.Replace(values[i], "*", "%", -1) + v := *values + ret := make([]string, len(v)) + for i := range v { + ret[i] = strings.Replace(v[i], "*", "%", -1) } return ret } + +// TODO: Replace when Dendrite uses Go 1.18 +func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) { + if filter.Senders != nil { + senders = *filter.Senders + } + if filter.NotSenders != nil { + notSenders = *filter.NotSenders + } + return senders, notSenders +} + +func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) { + if filter.Senders != nil { + senders = *filter.Senders + } + if filter.NotSenders != nil { + notSenders = *filter.NotSenders + } + return senders, notSenders +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index a30e220ba..269cd4494 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -204,11 +204,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange( stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) - + senders, notSenders := getSendersStateFilterFilter(stateFilter) rows, err := stmt.QueryContext( ctx, r.Low(), r.High(), pq.StringArray(roomIDs), - pq.StringArray(stateFilter.Senders), - pq.StringArray(stateFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), stateFilter.ContainsURL, @@ -353,10 +353,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( } else { stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } + senders, notSenders := getSendersRoomEventFilter(eventFilter) rows, err := stmt.QueryContext( ctx, roomID, r.Low(), r.High(), - pq.StringArray(eventFilter.Senders), - pq.StringArray(eventFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), eventFilter.Limit+1, @@ -398,11 +399,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, ) ([]types.StreamEvent, error) { + senders, notSenders := getSendersRoomEventFilter(eventFilter) stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) rows, err := stmt.QueryContext( ctx, roomID, r.Low(), r.High(), - pq.StringArray(eventFilter.Senders), - pq.StringArray(eventFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), eventFilter.Limit, @@ -480,10 +482,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn func (s *outputRoomEventsStatements) SelectContextBeforeEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ) (evts []*gomatrixserverlib.HeaderedEvent, err error) { + senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext( ctx, roomID, id, filter.Limit, - pq.StringArray(filter.Senders), - pq.StringArray(filter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), ) @@ -512,10 +515,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( func (s *outputRoomEventsStatements) SelectContextAfterEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { + senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext( ctx, roomID, id, filter.Limit, - pq.StringArray(filter.Senders), - pq.StringArray(filter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), ) diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 5b2287e6d..b0aeb70f2 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -41,10 +41,10 @@ const insertAccountDataSQL = "" + " ON CONFLICT (user_id, room_id, type) DO UPDATE" + " SET id = $5" +// further parameters are added by prepareWithFilters const selectAccountDataInRangeSQL = "" + "SELECT room_id, type FROM syncapi_account_data_type" + - " WHERE user_id = $1 AND id > $2 AND id <= $3" + - " ORDER BY id ASC" + " WHERE user_id = $1 AND id > $2 AND id <= $3" const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" @@ -94,18 +94,25 @@ func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, userID string, r types.Range, - accountDataFilterPart *gomatrixserverlib.EventFilter, + filter *gomatrixserverlib.EventFilter, ) (data map[string][]string, err error) { data = make(map[string][]string) + stmt, params, err := prepareWithFilters( + s.db, nil, selectAccountDataInRangeSQL, + []interface{}{ + userID, r.Low(), r.High(), + }, + filter.Senders, filter.NotSenders, + filter.Types, filter.NotTypes, + []string{}, filter.Limit, FilterOrderAsc, + ) - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") - var entries int - for rows.Next() { var dataType string var roomID string @@ -114,31 +121,11 @@ func (s *accountDataStatements) SelectAccountDataInRange( return } - // check if we should add this by looking at the filter. - // It would be nice if we could do this in SQL-land, but the mix of variadic - // and positional parameters makes the query annoyingly hard to do, it's easier - // and clearer to do it in Go-land. If there are no filters for [not]types then - // this gets skipped. - for _, includeType := range accountDataFilterPart.Types { - if includeType != dataType { // TODO: wildcard support - continue - } - } - for _, excludeType := range accountDataFilterPart.NotTypes { - if excludeType == dataType { // TODO: wildcard support - continue - } - } - if len(data[roomID]) > 0 { data[roomID] = append(data[roomID], dataType) } else { data[roomID] = []string{dataType} } - entries++ - if entries >= accountDataFilterPart.Limit { - break - } } return data, nil diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go index 11f3e647b..54b12ddf8 100644 --- a/syncapi/storage/sqlite3/filtering.go +++ b/syncapi/storage/sqlite3/filtering.go @@ -25,32 +25,48 @@ const ( // parts. func prepareWithFilters( db *sql.DB, txn *sql.Tx, query string, params []interface{}, - senders, notsenders, types, nottypes []string, excludeEventIDs []string, + senders, notsenders, types, nottypes *[]string, excludeEventIDs []string, limit int, order FilterOrder, ) (*sql.Stmt, []interface{}, error) { offset := len(params) - if count := len(senders); count > 0 { - query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range senders { - params, offset = append(params, v), offset+1 + if senders != nil { + if count := len(*senders); count > 0 { + query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *senders { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND sender = ""` } } - if count := len(notsenders); count > 0 { - query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range notsenders { - params, offset = append(params, v), offset+1 + if notsenders != nil { + if count := len(*notsenders); count > 0 { + query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *notsenders { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND sender NOT = ""` } } - if count := len(types); count > 0 { - query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range types { - params, offset = append(params, v), offset+1 + if types != nil { + if count := len(*types); count > 0 { + query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *types { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND type = ""` } } - if count := len(nottypes); count > 0 { - query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range nottypes { - params, offset = append(params, v), offset+1 + if nottypes != nil { + if count := len(*nottypes); count > 0 { + query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *nottypes { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND type NOT = ""` } } if count := len(excludeEventIDs); count > 0 { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index ab200e007..bcaf6ca31 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -423,8 +423,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty return err } req.IgnoredUsers = *ignores + userList := make([]string, 0, len(ignores.List)) for userID := range ignores.List { - eventFilter.NotSenders = append(eventFilter.NotSenders, userID) + userList = append(userList, userID) + } + if len(userList) > 0 { + eventFilter.NotSenders = &userList } return nil } diff --git a/sytest-whitelist b/sytest-whitelist index dc67c9935..a7aea05ed 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -696,4 +696,6 @@ Room state after a rejected message event is the same as before Room state after a rejected state event is the same as before Ignore user in existing room Ignore invite in full sync -Ignore invite in incremental sync \ No newline at end of file +Ignore invite in incremental sync +A filtered timeline reaches its limit +A change to displayname should not result in a full state sync \ No newline at end of file From ea92f80c128bcff0fd4e02df40b272f7d90a97da Mon Sep 17 00:00:00 2001 From: kegsay Date: Mon, 11 Apr 2022 10:23:01 +0100 Subject: [PATCH 07/21] Add database namespacing for unit tests (#2340) * Add database namespacing for unit tests Background: Running `go test ./...` will run tests in different packages concurrently. This can be stopped or limited by using `-p 1` (no concurrency). We want concurrency, but this causes problems when running Postgres DBs in CI. The problem is that, in CI, we have 1x postgres server exposing 1x postgres DB, which we wipe clean at the end of each test via `defer close()`. When tests run concurrently, calls to `close()` will delete data/tables which other tests are currently using, causing havoc. Fix this by: - Creating a database per package. - Namespacing the database name by a hash of the current working directory (the directory containing those `_test.go` files) This is exactly what SQLite does, quite unintentionally, via the use of `file:dendrite_test.db`, which dumps the file into the current working directory which is the package running the tests, hence deleting the file is safe when running concurrently. * Linting * Don't create the database in a txn * dupe db is not an error --- .github/workflows/dendrite.yml | 2 +- test/db.go | 58 +++++++++++++++++++++++++++++----- 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 8221bff96..4f337a866 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -111,7 +111,7 @@ jobs: key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test- - - run: go test -p 1 ./... + - run: go test ./... env: POSTGRES_HOST: localhost POSTGRES_USER: postgres diff --git a/test/db.go b/test/db.go index 674fdf5c3..6412feaa6 100644 --- a/test/db.go +++ b/test/db.go @@ -15,12 +15,16 @@ package test import ( + "crypto/sha256" "database/sql" + "encoding/hex" "fmt" "os" "os/exec" "os/user" "testing" + + "github.com/lib/pq" ) type DBType int @@ -30,7 +34,7 @@ var DBTypePostgres DBType = 2 var Quiet = false -func createLocalDB(dbName string) string { +func createLocalDB(dbName string) { if !Quiet { fmt.Println("Note: tests require a postgres install accessible to the current user") } @@ -43,7 +47,29 @@ func createLocalDB(dbName string) string { if err != nil && !Quiet { fmt.Println("createLocalDB returned error:", err) } - return dbName +} + +func createRemoteDB(t *testing.T, dbName, user, connStr string) { + db, err := sql.Open("postgres", connStr+" dbname=postgres") + if err != nil { + t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err) + } + _, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName)) + if err != nil { + pqErr, ok := err.(*pq.Error) + if !ok { + t.Fatalf("failed to CREATE DATABASE: %s", err) + } + // we ignore duplicate database error as we expect this + if pqErr.Code != "42P04" { + t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message) + } + } + _, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user)) + if err != nil { + t.Fatalf("failed to GRANT: %s", err) + } + _ = db.Close() } func currentUser() string { @@ -64,6 +90,7 @@ func currentUser() string { // TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { + // this will be made in the current working directory which namespaces concurrent package runs correctly dbname := "dendrite_test.db" return fmt.Sprintf("file:%s", dbname), func() { err := os.Remove(dbname) @@ -79,13 +106,9 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo if user == "" { user = currentUser() } - dbName := os.Getenv("POSTGRES_DB") - if dbName == "" { - dbName = createLocalDB("dendrite_test") - } connStr = fmt.Sprintf( - "user=%s dbname=%s sslmode=disable", - user, dbName, + "user=%s sslmode=disable", + user, ) // optional vars, used in CI password := os.Getenv("POSTGRES_PASSWORD") @@ -97,6 +120,25 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo connStr += fmt.Sprintf(" host=%s", host) } + // superuser database + postgresDB := os.Getenv("POSTGRES_DB") + // we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db. + // instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test + // and use that as the unique db name. We do this because packages are per-directory hence by hashing the + // working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages. + wd, err := os.Getwd() + if err != nil { + t.Fatalf("cannot get working directory: %s", err) + } + hash := sha256.Sum256([]byte(wd)) + dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16])) + if postgresDB == "" { // local server, use createdb + createLocalDB(dbName) + } else { // remote server, shell into the postgres user and CREATE DATABASE + createRemoteDB(t, dbName, user, connStr) + } + connStr += fmt.Sprintf(" dbname=%s", dbName) + return connStr, func() { // Drop all tables on the database to get a fresh instance db, err := sql.Open("postgres", connStr) From 29f216878994dccc68e34c90e5a0240c7698589f Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 13 Apr 2022 13:16:02 +0200 Subject: [PATCH 08/21] Make `/messages` filterable (#2347) * Make /messages filterable Fix bug when determining if an event contains an URL * Add newly passing test * Fix test --- syncapi/routing/messages.go | 6 +-- syncapi/storage/interface.go | 2 +- .../postgres/output_room_events_table.go | 38 +++++++++++++++++-- syncapi/storage/shared/syncserver.go | 13 ++++--- syncapi/storage/sqlite3/account_data_table.go | 3 +- .../sqlite3/current_room_state_table.go | 2 +- syncapi/storage/sqlite3/filtering.go | 5 ++- .../sqlite3/output_room_events_table.go | 33 +++++++++------- syncapi/storage/storage_test.go | 3 +- syncapi/storage/tables/interface.go | 2 +- .../storage/tables/output_room_events_test.go | 25 +++++++++++- sytest-whitelist | 3 +- 12 files changed, 98 insertions(+), 37 deletions(-) diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 36ba3a3e6..519aeff68 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, end types.TopologyToken, err error, ) { - eventFilter := r.filter - // Retrieve the events from the local database. - streamEvents, err := r.db.GetEventsInTopologicalRange( - r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering, - ) + streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) return diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index cf3fd5532..14cb08a52 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -105,7 +105,7 @@ type Database interface { // Returns an error if there was a problem communicating with the database. DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. - GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 269cd4494..17e2feab6 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -81,6 +81,15 @@ const insertEventSQL = "" + const selectEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +const selectEventsWithFilterSQL = "" + + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" + + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" + + " AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" + + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + + " AND ( $6::bool IS NULL OR contains_url = $6 )" + + " LIMIT $7" + const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + @@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" + type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt + selectEventsWitFilterStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt @@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventsStmt, selectEventsSQL}, + {&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, @@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent( // Parse content as JSON and search for an "url" key containsURL := false var content map[string]interface{} - if json.Unmarshal(event.Content(), &content) != nil { + if json.Unmarshal(event.Content(), &content) == nil { // Set containsURL to true if url is present _, containsURL = content["url"] } @@ -429,10 +440,29 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool, ) ([]types.StreamEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) - rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + var ( + stmt *sql.Stmt + rows *sql.Rows + err error + ) + if filter == nil { + stmt = sqlutil.TxStmt(txn, s.selectEventsStmt) + rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + } else { + senders, notSenders := getSendersRoomEventFilter(filter) + stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt) + rows, err = stmt.QueryContext(ctx, + pq.StringArray(eventIDs), + pq.StringArray(senders), + pq.StringArray(notSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), + filter.ContainsURL, + filter.Limit, + ) + } if err != nil { return nil, err } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 14db5795c..91eba44e1 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false) + streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) if err != nil { return nil, err } @@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e // Check if we have all of the event's previous events. If an event is // missing, add it to the room's backward extremities. - prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false) + prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false) if err != nil { return err } @@ -429,7 +429,8 @@ func (d *Database) updateRoomState( func (d *Database) GetEventsInTopologicalRange( ctx context.Context, from, to *types.TopologyToken, - roomID string, limit int, + roomID string, + filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool, ) (events []types.StreamEvent, err error) { var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition @@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange( // Select the event IDs from the defined range. var eIDs []string eIDs, err = d.Topology.SelectEventIDsInRange( - ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering, + ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, ) if err != nil { return } // Retrieve the events' contents using their IDs. - events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true) + events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true) return } @@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents( ) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. - events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false) + events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index b0aeb70f2..71a098177 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -104,8 +104,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( }, filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, - []string{}, filter.Limit, FilterOrderAsc, - ) + []string{}, nil, filter.Limit, FilterOrderAsc) rows, err := stmt.QueryContext(ctx, params...) if err != nil { diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 464f32e04..ccda005c1 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( }, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - excludeEventIDs, stateFilter.Limit, FilterOrderNone, + excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go index 54b12ddf8..05edb7b8c 100644 --- a/syncapi/storage/sqlite3/filtering.go +++ b/syncapi/storage/sqlite3/filtering.go @@ -26,7 +26,7 @@ const ( func prepareWithFilters( db *sql.DB, txn *sql.Tx, query string, params []interface{}, senders, notsenders, types, nottypes *[]string, excludeEventIDs []string, - limit int, order FilterOrder, + containsURL *bool, limit int, order FilterOrder, ) (*sql.Stmt, []interface{}, error) { offset := len(params) if senders != nil { @@ -69,6 +69,9 @@ func prepareWithFilters( query += ` AND type NOT = ""` } } + if containsURL != nil { + query += fmt.Sprintf(" AND contains_url = %v", *containsURL) + } if count := len(excludeEventIDs); count > 0 { query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset) for _, v := range excludeEventIDs { diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 9da9d776e..188f7582b 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -168,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( s.db, txn, stmtSQL, inputParams, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - nil, stateFilter.Limit, FilterOrderAsc, + nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, ) if err != nil { return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -277,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent( // Parse content as JSON and search for an "url" key containsURL := false var content map[string]interface{} - if json.Unmarshal(event.Content(), &content) != nil { + if json.Unmarshal(event.Content(), &content) == nil { // Set containsURL to true if url is present _, containsURL = content["url"] } @@ -345,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( }, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.Limit+1, FilterOrderDesc, + nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, ) if err != nil { return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -393,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( }, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.Limit, FilterOrderAsc, + nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -419,20 +419,27 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool, ) ([]types.StreamEvent, error) { iEventIDs := make([]interface{}, len(eventIDs)) for i := range eventIDs { iEventIDs[i] = eventIDs[i] } selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) - var rows *sql.Rows - var err error - if txn != nil { - rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...) - } else { - rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...) + + if filter == nil { + filter = &gomatrixserverlib.RoomEventFilter{Limit: 20} } + stmt, params, err := prepareWithFilters( + s.db, txn, selectSQL, iEventIDs, + filter.Senders, filter.NotSenders, + filter.Types, filter.NotTypes, + nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, + ) + if err != nil { + return nil, err + } + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err } @@ -527,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( }, filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, - nil, filter.Limit, FilterOrderDesc, + nil, filter.ContainsURL, filter.Limit, FilterOrderDesc, ) rows, err := stmt.QueryContext(ctx, params...) @@ -563,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( }, filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, - nil, filter.Limit, FilterOrderAsc, + nil, filter.ContainsURL, filter.Limit, FilterOrderAsc, ) rows, err := stmt.QueryContext(ctx, params...) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 4e1634ece..15bb769a2 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -180,7 +180,8 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { to := types.TopologyToken{} // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true) + filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) if err != nil { t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index a7df70248..993e2022b 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -59,7 +59,7 @@ type Events interface { SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) - SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error) + SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go index 7a81ffcd2..a143e5ecd 100644 --- a/syncapi/storage/tables/output_room_events_test.go +++ b/syncapi/storage/tables/output_room_events_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" ) func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) { @@ -61,7 +62,7 @@ func TestOutputRoomEventsTable(t *testing.T) { wantEventIDs := []string{ events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(), } - gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true) + gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true) if err != nil { return fmt.Errorf("failed to SelectEvents: %s", err) } @@ -73,6 +74,28 @@ func TestOutputRoomEventsTable(t *testing.T) { return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs) } + // Test that contains_url is correctly populated + urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{ + "body": "test.txt", + "url": "mxc://test.txt", + }) + if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil { + return fmt.Errorf("failed to InsertEvent: %s", err) + } + wantEventID := []string{urlEv.EventID()} + t := true + gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true) + if err != nil { + return fmt.Errorf("failed to SelectEvents: %s", err) + } + gotEventIDs = make([]string, len(gotEvents)) + for i := range gotEvents { + gotEventIDs[i] = gotEvents[i].EventID() + } + if !reflect.DeepEqual(gotEventIDs, wantEventID) { + return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID) + } + return nil }) if err != nil { diff --git a/sytest-whitelist b/sytest-whitelist index a7aea05ed..f63b96f52 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -698,4 +698,5 @@ Ignore user in existing room Ignore invite in full sync Ignore invite in incremental sync A filtered timeline reaches its limit -A change to displayname should not result in a full state sync \ No newline at end of file +A change to displayname should not result in a full state sync +Can fetch images in room \ No newline at end of file From 1140f39993f1d4fb80952bf853bb05df0b42ca20 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 13 Apr 2022 12:35:30 +0100 Subject: [PATCH 09/21] Precompute values for `userIDSet` in sync notifier (#2348) * Precompute values for `userIDSet` in sync notifier * Mutexes * Fixes * Sensible initial value * Update syncapi/notifier/notifier.go Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com> * Placate the almighty linter Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com> --- syncapi/notifier/notifier.go | 73 +++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 17 deletions(-) diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 443744b6f..82834239b 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -36,7 +36,7 @@ import ( type Notifier struct { lock *sync.RWMutex // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine - roomIDToJoinedUsers map[string]userIDSet + roomIDToJoinedUsers map[string]*userIDSet // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine roomIDToPeekingDevices map[string]peekingDeviceSet // The latest sync position @@ -54,7 +54,7 @@ type Notifier struct { // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). func NewNotifier() *Notifier { return &Notifier{ - roomIDToJoinedUsers: make(map[string]userIDSet), + roomIDToJoinedUsers: make(map[string]*userIDSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet), userDeviceStreams: make(map[string]map[string]*UserDeviceStream), lock: &sync.RWMutex{}, @@ -262,7 +262,7 @@ func (n *Notifier) SharedUsers(userID string) []string { func (n *Notifier) _sharedUsers(userID string) []string { n._sharedUserMap[userID] = struct{}{} for roomID, users := range n.roomIDToJoinedUsers { - if _, ok := users[userID]; !ok { + if ok := users.isIn(userID); !ok { continue } for _, userID := range n._joinedUsers(roomID) { @@ -282,8 +282,11 @@ func (n *Notifier) IsSharedUser(userA, userB string) bool { defer n.lock.RUnlock() var okA, okB bool for _, users := range n.roomIDToJoinedUsers { - _, okA = users[userA] - _, okB = users[userB] + okA = users.isIn(userA) + if !okA { + continue + } + okB = users.isIn(userB) if okA && okB { return true } @@ -345,11 +348,12 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { // This is just the bulk form of addJoinedUser for roomID, userIDs := range roomIDToUserIDs { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { - n.roomIDToJoinedUsers[roomID] = make(userIDSet, len(userIDs)) + n.roomIDToJoinedUsers[roomID] = newUserIDSet(len(userIDs)) } for _, userID := range userIDs { n.roomIDToJoinedUsers[roomID].add(userID) } + n.roomIDToJoinedUsers[roomID].precompute() } } @@ -440,16 +444,18 @@ func (n *Notifier) _fetchUserStreams(userID string) []*UserDeviceStream { func (n *Notifier) _addJoinedUser(roomID, userID string) { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { - n.roomIDToJoinedUsers[roomID] = make(userIDSet) + n.roomIDToJoinedUsers[roomID] = newUserIDSet(8) } n.roomIDToJoinedUsers[roomID].add(userID) + n.roomIDToJoinedUsers[roomID].precompute() } func (n *Notifier) _removeJoinedUser(roomID, userID string) { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { - n.roomIDToJoinedUsers[roomID] = make(userIDSet) + n.roomIDToJoinedUsers[roomID] = newUserIDSet(8) } n.roomIDToJoinedUsers[roomID].remove(userID) + n.roomIDToJoinedUsers[roomID].precompute() } func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) { @@ -521,19 +527,52 @@ func (n *Notifier) _removeEmptyUserStreams() { } // A string set, mainly existing for improving clarity of structs in this file. -type userIDSet map[string]struct{} - -func (s userIDSet) add(str string) { - s[str] = struct{}{} +type userIDSet struct { + sync.Mutex + set map[string]struct{} + precomputed []string } -func (s userIDSet) remove(str string) { - delete(s, str) +func newUserIDSet(cap int) *userIDSet { + return &userIDSet{ + set: make(map[string]struct{}, cap), + precomputed: nil, + } } -func (s userIDSet) values() (vals []string) { - vals = make([]string, 0, len(s)) - for str := range s { +func (s *userIDSet) add(str string) { + s.Lock() + defer s.Unlock() + s.set[str] = struct{}{} + s.precomputed = s.precomputed[:0] // invalidate cache +} + +func (s *userIDSet) remove(str string) { + s.Lock() + defer s.Unlock() + delete(s.set, str) + s.precomputed = s.precomputed[:0] // invalidate cache +} + +func (s *userIDSet) precompute() { + s.Lock() + defer s.Unlock() + s.precomputed = s.values() +} + +func (s *userIDSet) isIn(str string) bool { + s.Lock() + defer s.Unlock() + _, ok := s.set[str] + return ok +} + +func (s *userIDSet) values() (vals []string) { + if len(s.precomputed) > 0 { + return s.precomputed // only return if not invalidated + } + vals = make([]string, 0, len(s.set)) + for str := range s.set { vals = append(vals, str) } return From 3a5e9a0f284eef0fcb94a22035f9ffd2eb65eedf Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 13 Apr 2022 16:41:22 +0100 Subject: [PATCH 10/21] Use default sync filter if specified filter is not found (should fix #2350) (#2351) --- syncapi/sync/request.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 09a62e3dd..f04f172d3 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -15,6 +15,7 @@ package sync import ( + "database/sql" "encoding/json" "fmt" "net/http" @@ -60,10 +61,10 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil { + if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows { util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") return nil, fmt.Errorf("syncDB.GetFilter: %w", err) - } else { + } else if f != nil { filter = *f } } From 3ddbffd59ece5f74d951d6209882d9d954db4bc3 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 14 Apr 2022 14:32:48 +0200 Subject: [PATCH 11/21] Refactor media storage layer, add tests (#2352) * Refactor mediaapi storage layer * Verify filetype before trying to create thumbnails * Add media api storage tests * Fix returned values --- mediaapi/mediaapi.go | 2 +- mediaapi/routing/upload.go | 21 +++ mediaapi/routing/upload_test.go | 2 +- mediaapi/storage/interface.go | 8 ++ .../postgres/media_repository_table.go | 33 +++-- mediaapi/storage/postgres/mediaapi.go | 46 ++++++ mediaapi/storage/postgres/prepare.go | 38 ----- mediaapi/storage/postgres/sql.go | 36 ----- mediaapi/storage/postgres/thumbnail_table.go | 34 +++-- .../storage.go => shared/mediaapi.go} | 93 +++++------- .../storage/sqlite3/media_repository_table.go | 63 ++++---- .../storage/sqlite3/{sql.go => mediaapi.go} | 35 +++-- mediaapi/storage/sqlite3/prepare.go | 38 ----- mediaapi/storage/sqlite3/storage.go | 123 ---------------- mediaapi/storage/sqlite3/thumbnail_table.go | 63 ++++---- mediaapi/storage/storage.go | 8 +- mediaapi/storage/storage_test.go | 135 ++++++++++++++++++ mediaapi/storage/storage_wasm.go | 4 +- mediaapi/storage/tables/interface.go | 46 ++++++ mediaapi/types/types.go | 5 +- 20 files changed, 417 insertions(+), 416 deletions(-) create mode 100644 mediaapi/storage/postgres/mediaapi.go delete mode 100644 mediaapi/storage/postgres/prepare.go delete mode 100644 mediaapi/storage/postgres/sql.go rename mediaapi/storage/{postgres/storage.go => shared/mediaapi.go} (52%) rename mediaapi/storage/sqlite3/{sql.go => mediaapi.go} (51%) delete mode 100644 mediaapi/storage/sqlite3/prepare.go delete mode 100644 mediaapi/storage/sqlite3/storage.go create mode 100644 mediaapi/storage/storage_test.go create mode 100644 mediaapi/storage/tables/interface.go diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index c010981c0..e5daf480d 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -32,7 +32,7 @@ func AddPublicRoutes( userAPI userapi.UserInternalAPI, client *gomatrixserverlib.Client, ) { - mediaDB, err := storage.Open(&cfg.Database) + mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to media db") } diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index f762b2ff5..972c52af0 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -22,6 +22,7 @@ import ( "io" "net/http" "net/url" + "os" "path" "strings" @@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata( } go func() { + file, err := os.Open(string(finalPath)) + if err != nil { + r.Logger.WithError(err).Error("unable to open file") + return + } + defer file.Close() // nolint: errcheck + // http.DetectContentType only needs 512 bytes + buf := make([]byte, 512) + _, err = file.Read(buf) + if err != nil { + r.Logger.WithError(err).Error("unable to read file") + return + } + // Check if we need to generate thumbnails + fileType := http.DetectContentType(buf) + if !strings.HasPrefix(fileType, "image") { + r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails") + return + } + busy, err := thumbnailer.GenerateThumbnails( context.Background(), finalPath, thumbnailSizes, r.MediaMetadata, activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger, diff --git a/mediaapi/routing/upload_test.go b/mediaapi/routing/upload_test.go index e81254f35..b2c2f5a44 100644 --- a/mediaapi/routing/upload_test.go +++ b/mediaapi/routing/upload_test.go @@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) { _ = os.Mkdir(testdataPath, os.ModePerm) defer fileutils.RemoveDir(types.Path(testdataPath), nil) - db, err := storage.Open(&config.DatabaseOptions{ + db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{ ConnectionString: "file::memory:?cache=shared", MaxOpenConnections: 100, MaxIdleConnections: 2, diff --git a/mediaapi/storage/interface.go b/mediaapi/storage/interface.go index 843199719..d083be1eb 100644 --- a/mediaapi/storage/interface.go +++ b/mediaapi/storage/interface.go @@ -22,9 +22,17 @@ import ( ) type Database interface { + MediaRepository + Thumbnails +} + +type MediaRepository interface { StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) +} + +type Thumbnails interface { StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) diff --git a/mediaapi/storage/postgres/media_repository_table.go b/mediaapi/storage/postgres/media_repository_table.go index 1d3264ca9..41cee4878 100644 --- a/mediaapi/storage/postgres/media_repository_table.go +++ b/mediaapi/storage/postgres/media_repository_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -69,24 +71,25 @@ type mediaStatements struct { selectMediaByHashStmt *sql.Stmt } -func (s *mediaStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(mediaSchema) +func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) { + s := &mediaStatements{} + _, err := db.Exec(mediaSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, sqlutil.StatementList{ {&s.insertMediaStmt, insertMediaSQL}, {&s.selectMediaStmt, selectMediaSQL}, {&s.selectMediaByHashStmt, selectMediaByHashSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *mediaStatements) insertMedia( - ctx context.Context, mediaMetadata *types.MediaMetadata, +func (s *mediaStatements) InsertMedia( + ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - _, err := s.insertMediaStmt.ExecContext( + mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( ctx, mediaMetadata.MediaID, mediaMetadata.Origin, @@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia( return err } -func (s *mediaStatements) selectMedia( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMedia( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext( ctx, mediaMetadata.MediaID, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, @@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia( return &mediaMetadata, err } -func (s *mediaStatements) selectMediaByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMediaByHash( + ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext( ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, diff --git a/mediaapi/storage/postgres/mediaapi.go b/mediaapi/storage/postgres/mediaapi.go new file mode 100644 index 000000000..ea70e575b --- /dev/null +++ b/mediaapi/storage/postgres/mediaapi.go @@ -0,0 +1,46 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + // Import the postgres database driver. + _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/shared" + "github.com/matrix-org/dendrite/setup/config" +) + +// NewDatabase opens a postgres database. +func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err + } + mediaRepo, err := NewPostgresMediaRepositoryTable(db) + if err != nil { + return nil, err + } + thumbnails, err := NewPostgresThumbnailsTable(db) + if err != nil { + return nil, err + } + return &shared.Database{ + MediaRepository: mediaRepo, + Thumbnails: thumbnails, + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + }, nil +} diff --git a/mediaapi/storage/postgres/prepare.go b/mediaapi/storage/postgres/prepare.go deleted file mode 100644 index a2e01884e..000000000 --- a/mediaapi/storage/postgres/prepare.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// FIXME: This should be made internal! - -package postgres - -import ( - "database/sql" -) - -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return -} diff --git a/mediaapi/storage/postgres/sql.go b/mediaapi/storage/postgres/sql.go deleted file mode 100644 index 181cd15ff..000000000 --- a/mediaapi/storage/postgres/sql.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "database/sql" -) - -type statements struct { - media mediaStatements - thumbnail thumbnailStatements -} - -func (s *statements) prepare(db *sql.DB) (err error) { - if err = s.media.prepare(db); err != nil { - return - } - if err = s.thumbnail.prepare(db); err != nil { - return - } - - return -} diff --git a/mediaapi/storage/postgres/thumbnail_table.go b/mediaapi/storage/postgres/thumbnail_table.go index 3f28cdbbf..7e07b476e 100644 --- a/mediaapi/storage/postgres/thumbnail_table.go +++ b/mediaapi/storage/postgres/thumbnail_table.go @@ -21,6 +21,8 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE // Note: this selects all thumbnails for a media_origin and media_id const selectThumbnailsSQL = ` -SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 +SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC ` type thumbnailStatements struct { @@ -72,24 +74,25 @@ type thumbnailStatements struct { selectThumbnailsStmt *sql.Stmt } -func (s *thumbnailStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(thumbnailSchema) +func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) { + s := &thumbnailStatements{} + _, err := db.Exec(thumbnailSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, sqlutil.StatementList{ {&s.insertThumbnailStmt, insertThumbnailSQL}, {&s.selectThumbnailStmt, selectThumbnailSQL}, {&s.selectThumbnailsStmt, selectThumbnailsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *thumbnailStatements) insertThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, +func (s *thumbnailStatements) InsertThumbnail( + ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata, ) error { - thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - _, err := s.insertThumbnailStmt.ExecContext( + thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.Origin, @@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail( return err } -func (s *thumbnailStatements) selectThumbnail( +func (s *thumbnailStatements) SelectThumbnail( ctx context.Context, + txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, @@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail( ResizeMethod: resizeMethod, }, } - err := s.selectThumbnailStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.Origin, @@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail( return &thumbnailMetadata, err } -func (s *thumbnailStatements) selectThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *thumbnailStatements) SelectThumbnails( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ) ([]*types.ThumbnailMetadata, error) { - rows, err := s.selectThumbnailsStmt.QueryContext( + rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext( ctx, mediaID, mediaOrigin, ) if err != nil { diff --git a/mediaapi/storage/postgres/storage.go b/mediaapi/storage/shared/mediaapi.go similarity index 52% rename from mediaapi/storage/postgres/storage.go rename to mediaapi/storage/shared/mediaapi.go index 61ad468fe..c8d9ad6ab 100644 --- a/mediaapi/storage/postgres/storage.go +++ b/mediaapi/storage/shared/mediaapi.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,54 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package postgres +package shared import ( "context" "database/sql" - // Import the postgres database driver. - _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) -// Database is used to store metadata about a repository of media files. type Database struct { - statements statements - db *sql.DB -} - -// Open opens a postgres database. -func Open(dbProperties *config.DatabaseOptions) (*Database, error) { - var d Database - var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err - } - if err = d.statements.prepare(d.db); err != nil { - return nil, err - } - return &d, nil + DB *sql.DB + Writer sqlutil.Writer + MediaRepository tables.MediaRepository + Thumbnails tables.Thumbnails } // StoreMediaMetadata inserts the metadata about the uploaded media into the database. // Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreMediaMetadata( - ctx context.Context, mediaMetadata *types.MediaMetadata, -) error { - return d.statements.media.insertMedia(ctx, mediaMetadata) +func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata) + }) } // GetMediaMetadata returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this media. -func (d *Database) GetMediaMetadata( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin) +func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { + mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin) if err != nil && err == sql.ErrNoRows { return nil, nil } @@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata( // GetMediaMetadataByHash returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this media. -func (d *Database) GetMediaMetadataByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin) +func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { + mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin) if err != nil && err == sql.ErrNoRows { return nil, nil } @@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash( // StoreThumbnail inserts the metadata about the thumbnail into the database. // Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, -) error { - return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata) +func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata) + }) } // GetThumbnail returns metadata about a specific thumbnail. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this thumbnail. -func (d *Database) GetThumbnail( - ctx context.Context, - mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, - width, height int, - resizeMethod string, -) (*types.ThumbnailMetadata, error) { - thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail( - ctx, mediaID, mediaOrigin, width, height, resizeMethod, - ) - if err != nil && err == sql.ErrNoRows { - return nil, nil +func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) { + metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err } - return thumbnailMetadata, err + return metadata, err } // GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there are no thumbnails associated with this media. -func (d *Database) GetThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) ([]*types.ThumbnailMetadata, error) { - thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil +func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) { + metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err } - return thumbnails, err + return metadatas, err } diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index bcef609d8..78431967f 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -66,57 +67,53 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i type mediaStatements struct { db *sql.DB - writer sqlutil.Writer insertMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt selectMediaByHashStmt *sql.Stmt } -func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { - s.db = db - s.writer = writer - - _, err = db.Exec(mediaSchema) +func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) { + s := &mediaStatements{ + db: db, + } + _, err := db.Exec(mediaSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, sqlutil.StatementList{ {&s.insertMediaStmt, insertMediaSQL}, {&s.selectMediaStmt, selectMediaSQL}, {&s.selectMediaByHashStmt, selectMediaByHashSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *mediaStatements) insertMedia( - ctx context.Context, mediaMetadata *types.MediaMetadata, +func (s *mediaStatements) InsertMedia( + ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertMediaStmt) - _, err := stmt.ExecContext( - ctx, - mediaMetadata.MediaID, - mediaMetadata.Origin, - mediaMetadata.ContentType, - mediaMetadata.FileSizeBytes, - mediaMetadata.CreationTimestamp, - mediaMetadata.UploadName, - mediaMetadata.Base64Hash, - mediaMetadata.UserID, - ) - return err - }) + mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( + ctx, + mediaMetadata.MediaID, + mediaMetadata.Origin, + mediaMetadata.ContentType, + mediaMetadata.FileSizeBytes, + mediaMetadata.CreationTimestamp, + mediaMetadata.UploadName, + mediaMetadata.Base64Hash, + mediaMetadata.UserID, + ) + return err } -func (s *mediaStatements) selectMedia( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMedia( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext( ctx, mediaMetadata.MediaID, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, @@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia( return &mediaMetadata, err } -func (s *mediaStatements) selectMediaByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMediaByHash( + ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext( ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, diff --git a/mediaapi/storage/sqlite3/sql.go b/mediaapi/storage/sqlite3/mediaapi.go similarity index 51% rename from mediaapi/storage/sqlite3/sql.go rename to mediaapi/storage/sqlite3/mediaapi.go index 245bd40cc..abf329367 100644 --- a/mediaapi/storage/sqlite3/sql.go +++ b/mediaapi/storage/sqlite3/mediaapi.go @@ -16,23 +16,30 @@ package sqlite3 import ( - "database/sql" - + // Import the postgres database driver. "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/shared" + "github.com/matrix-org/dendrite/setup/config" ) -type statements struct { - media mediaStatements - thumbnail thumbnailStatements -} - -func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { - if err = s.media.prepare(db, writer); err != nil { - return +// NewDatabase opens a SQLIte database. +func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, err := sqlutil.Open(dbProperties) + if err != nil { + return nil, err } - if err = s.thumbnail.prepare(db, writer); err != nil { - return + mediaRepo, err := NewSQLiteMediaRepositoryTable(db) + if err != nil { + return nil, err } - - return + thumbnails, err := NewSQLiteThumbnailsTable(db) + if err != nil { + return nil, err + } + return &shared.Database{ + MediaRepository: mediaRepo, + Thumbnails: thumbnails, + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + }, nil } diff --git a/mediaapi/storage/sqlite3/prepare.go b/mediaapi/storage/sqlite3/prepare.go deleted file mode 100644 index 8fb3b56f3..000000000 --- a/mediaapi/storage/sqlite3/prepare.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// FIXME: This should be made internal! - -package sqlite3 - -import ( - "database/sql" -) - -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return -} diff --git a/mediaapi/storage/sqlite3/storage.go b/mediaapi/storage/sqlite3/storage.go deleted file mode 100644 index fa442173b..000000000 --- a/mediaapi/storage/sqlite3/storage.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - - // Import the postgres database driver. - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" -) - -// Database is used to store metadata about a repository of media files. -type Database struct { - statements statements - db *sql.DB - writer sqlutil.Writer -} - -// Open opens a postgres database. -func Open(dbProperties *config.DatabaseOptions) (*Database, error) { - d := Database{ - writer: sqlutil.NewExclusiveWriter(), - } - var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err - } - if err = d.statements.prepare(d.db, d.writer); err != nil { - return nil, err - } - return &d, nil -} - -// StoreMediaMetadata inserts the metadata about the uploaded media into the database. -// Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreMediaMetadata( - ctx context.Context, mediaMetadata *types.MediaMetadata, -) error { - return d.statements.media.insertMedia(ctx, mediaMetadata) -} - -// GetMediaMetadata returns metadata about media stored on this server. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there is no metadata associated with this media. -func (d *Database) GetMediaMetadata( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return mediaMetadata, err -} - -// GetMediaMetadataByHash returns metadata about media stored on this server. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there is no metadata associated with this media. -func (d *Database) GetMediaMetadataByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return mediaMetadata, err -} - -// StoreThumbnail inserts the metadata about the thumbnail into the database. -// Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, -) error { - return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata) -} - -// GetThumbnail returns metadata about a specific thumbnail. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there is no metadata associated with this thumbnail. -func (d *Database) GetThumbnail( - ctx context.Context, - mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, - width, height int, - resizeMethod string, -) (*types.ThumbnailMetadata, error) { - thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail( - ctx, mediaID, mediaOrigin, width, height, resizeMethod, - ) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return thumbnailMetadata, err -} - -// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there are no thumbnails associated with this media. -func (d *Database) GetThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) ([]*types.ThumbnailMetadata, error) { - thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return thumbnails, err -} diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go index 06b056b6e..5ff2fece0 100644 --- a/mediaapi/storage/sqlite3/thumbnail_table.go +++ b/mediaapi/storage/sqlite3/thumbnail_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -54,55 +55,48 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE // Note: this selects all thumbnails for a media_origin and media_id const selectThumbnailsSQL = ` -SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 +SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC ` type thumbnailStatements struct { - db *sql.DB - writer sqlutil.Writer insertThumbnailStmt *sql.Stmt selectThumbnailStmt *sql.Stmt selectThumbnailsStmt *sql.Stmt } -func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { - _, err = db.Exec(thumbnailSchema) +func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) { + s := &thumbnailStatements{} + _, err := db.Exec(thumbnailSchema) if err != nil { - return + return nil, err } - s.db = db - s.writer = writer - return statementList{ + return s, sqlutil.StatementList{ {&s.insertThumbnailStmt, insertThumbnailSQL}, {&s.selectThumbnailStmt, selectThumbnailSQL}, {&s.selectThumbnailsStmt, selectThumbnailsSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *thumbnailStatements) insertThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, -) error { - thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt) - _, err := stmt.ExecContext( - ctx, - thumbnailMetadata.MediaMetadata.MediaID, - thumbnailMetadata.MediaMetadata.Origin, - thumbnailMetadata.MediaMetadata.ContentType, - thumbnailMetadata.MediaMetadata.FileSizeBytes, - thumbnailMetadata.MediaMetadata.CreationTimestamp, - thumbnailMetadata.ThumbnailSize.Width, - thumbnailMetadata.ThumbnailSize.Height, - thumbnailMetadata.ThumbnailSize.ResizeMethod, - ) - return err - }) +func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error { + thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext( + ctx, + thumbnailMetadata.MediaMetadata.MediaID, + thumbnailMetadata.MediaMetadata.Origin, + thumbnailMetadata.MediaMetadata.ContentType, + thumbnailMetadata.MediaMetadata.FileSizeBytes, + thumbnailMetadata.MediaMetadata.CreationTimestamp, + thumbnailMetadata.ThumbnailSize.Width, + thumbnailMetadata.ThumbnailSize.Height, + thumbnailMetadata.ThumbnailSize.ResizeMethod, + ) + return err } -func (s *thumbnailStatements) selectThumbnail( +func (s *thumbnailStatements) SelectThumbnail( ctx context.Context, + txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, @@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail( ResizeMethod: resizeMethod, }, } - err := s.selectThumbnailStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, thumbnailMetadata.MediaMetadata.Origin, @@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail( return &thumbnailMetadata, err } -func (s *thumbnailStatements) selectThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *thumbnailStatements) SelectThumbnails( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, + mediaOrigin gomatrixserverlib.ServerName, ) ([]*types.ThumbnailMetadata, error) { - rows, err := s.selectThumbnailsStmt.QueryContext( + rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext( ctx, mediaID, mediaOrigin, ) if err != nil { diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index 56059f1c8..baa242e57 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -25,13 +25,13 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) -// Open opens a postgres database. -func Open(dbProperties *config.DatabaseOptions) (Database, error) { +// NewMediaAPIDatasource opens a database connection. +func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(dbProperties) + return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.Open(dbProperties) + return postgres.NewDatabase(dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/mediaapi/storage/storage_test.go b/mediaapi/storage/storage_test.go new file mode 100644 index 000000000..8d3403045 --- /dev/null +++ b/mediaapi/storage/storage_test.go @@ -0,0 +1,135 @@ +package storage_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/mediaapi/storage" + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("NewSyncServerDatasource returned %s", err) + } + return db, close +} +func TestMediaRepository(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + ctx := context.Background() + t.Run("can insert media & query media", func(t *testing.T) { + metadata := &types.MediaMetadata{ + MediaID: "testing", + Origin: "localhost", + ContentType: "image/png", + FileSizeBytes: 10, + UploadName: "upload test", + Base64Hash: "dGVzdGluZw==", + UserID: "@alice:localhost", + } + if err := db.StoreMediaMetadata(ctx, metadata); err != nil { + t.Fatalf("unable to store media metadata: %v", err) + } + // query by media id + gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin) + if err != nil { + t.Fatalf("unable to query media metadata: %v", err) + } + if !reflect.DeepEqual(metadata, gotMetadata) { + t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata) + } + // query by media hash + gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin) + if err != nil { + t.Fatalf("unable to query media metadata by hash: %v", err) + } + if !reflect.DeepEqual(metadata, gotMetadata) { + t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata) + } + }) + }) +} + +func TestThumbnailsStorage(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + ctx := context.Background() + t.Run("can insert thumbnails & query media", func(t *testing.T) { + thumbnails := []*types.ThumbnailMetadata{ + { + MediaMetadata: &types.MediaMetadata{ + MediaID: "testing", + Origin: "localhost", + ContentType: "image/png", + FileSizeBytes: 6, + }, + ThumbnailSize: types.ThumbnailSize{ + Width: 5, + Height: 5, + ResizeMethod: types.Crop, + }, + }, + { + MediaMetadata: &types.MediaMetadata{ + MediaID: "testing", + Origin: "localhost", + ContentType: "image/png", + FileSizeBytes: 7, + }, + ThumbnailSize: types.ThumbnailSize{ + Width: 1, + Height: 1, + ResizeMethod: types.Scale, + }, + }, + } + for i := range thumbnails { + if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil { + t.Fatalf("unable to store thumbnail metadata: %v", err) + } + } + // query by single thumbnail + gotMetadata, err := db.GetThumbnail(ctx, + thumbnails[0].MediaMetadata.MediaID, + thumbnails[0].MediaMetadata.Origin, + thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height, + thumbnails[0].ThumbnailSize.ResizeMethod, + ) + if err != nil { + t.Fatalf("unable to query thumbnail metadata: %v", err) + } + if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) { + t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) + } + if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) { + t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) + } + // query by all thumbnails + gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin) + if err != nil { + t.Fatalf("unable to query media metadata by hash: %v", err) + } + if len(gotMediadatas) != len(thumbnails) { + t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas)) + } + for i := range gotMediadatas { + if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) { + t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) + } + if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) { + t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) + } + } + }) + }) +} diff --git a/mediaapi/storage/storage_wasm.go b/mediaapi/storage/storage_wasm.go index a6e997b2a..f67f9d5e1 100644 --- a/mediaapi/storage/storage_wasm.go +++ b/mediaapi/storage/storage_wasm.go @@ -22,10 +22,10 @@ import ( ) // Open opens a postgres database. -func Open(dbProperties *config.DatabaseOptions) (Database, error) { +func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(dbProperties) + return sqlite3.NewDatabase(dbProperties) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/mediaapi/storage/tables/interface.go b/mediaapi/storage/tables/interface.go new file mode 100644 index 000000000..bf63bc6ab --- /dev/null +++ b/mediaapi/storage/tables/interface.go @@ -0,0 +1,46 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type Thumbnails interface { + InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error + SelectThumbnail( + ctx context.Context, txn *sql.Tx, + mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + width, height int, + resizeMethod string, + ) (*types.ThumbnailMetadata, error) + SelectThumbnails( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, + mediaOrigin gomatrixserverlib.ServerName, + ) ([]*types.ThumbnailMetadata, error) +} + +type MediaRepository interface { + InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error + SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) + SelectMediaByHash( + ctx context.Context, txn *sql.Tx, + mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, + ) (*types.MediaMetadata, error) +} diff --git a/mediaapi/types/types.go b/mediaapi/types/types.go index 0ba7010ad..ab28b3410 100644 --- a/mediaapi/types/types.go +++ b/mediaapi/types/types.go @@ -45,16 +45,13 @@ type RequestMethod string // MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org type MatrixUserID string -// UnixMs is the milliseconds since the Unix epoch -type UnixMs int64 - // MediaMetadata is metadata associated with a media file type MediaMetadata struct { MediaID MediaID Origin gomatrixserverlib.ServerName ContentType ContentType FileSizeBytes FileSizeBytes - CreationTimestamp UnixMs + CreationTimestamp gomatrixserverlib.Timestamp UploadName Filename Base64Hash Base64Hash UserID MatrixUserID From 57e3622b85fd4d80d9826404135f09e91ed47973 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 19 Apr 2022 10:46:45 +0200 Subject: [PATCH 12/21] Implement lazy loading on `/sync` (#2346) * Initial work on lazyloading * Partially implement lazy loading on /sync * Rename methods * Make missing tests pass * Preallocate slice, even if it will end up with fewer values * Let the cache handle the user mapping * Linter * Cap cache growth --- internal/caching/cache_lazy_load_members.go | 86 +++++++++++++++++ syncapi/streams/stream_pdu.go | 100 +++++++++++++++++++- syncapi/streams/streams.go | 4 +- syncapi/syncapi.go | 6 +- sytest-whitelist | 11 ++- 5 files changed, 200 insertions(+), 7 deletions(-) create mode 100644 internal/caching/cache_lazy_load_members.go diff --git a/internal/caching/cache_lazy_load_members.go b/internal/caching/cache_lazy_load_members.go new file mode 100644 index 000000000..71a317624 --- /dev/null +++ b/internal/caching/cache_lazy_load_members.go @@ -0,0 +1,86 @@ +package caching + +import ( + "fmt" + "time" + + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +const ( + LazyLoadCacheName = "lazy_load_members" + LazyLoadCacheMaxEntries = 128 + LazyLoadCacheMaxUserEntries = 128 + LazyLoadCacheMutable = true + LazyLoadCacheMaxAge = time.Minute * 30 +) + +type LazyLoadCache struct { + // InMemoryLRUCachePartition containing other InMemoryLRUCachePartitions + // with the actual cached members + userCaches *InMemoryLRUCachePartition +} + +// NewLazyLoadCache creates a new LazyLoadCache. +func NewLazyLoadCache() (*LazyLoadCache, error) { + cache, err := NewInMemoryLRUCachePartition( + LazyLoadCacheName, + LazyLoadCacheMutable, + LazyLoadCacheMaxEntries, + LazyLoadCacheMaxAge, + true, + ) + if err != nil { + return nil, err + } + go cacheCleaner(cache) + return &LazyLoadCache{ + userCaches: cache, + }, nil +} + +func (c *LazyLoadCache) lazyLoadCacheForUser(device *userapi.Device) (*InMemoryLRUCachePartition, error) { + cacheName := fmt.Sprintf("%s/%s", device.UserID, device.ID) + userCache, ok := c.userCaches.Get(cacheName) + if ok && userCache != nil { + if cache, ok := userCache.(*InMemoryLRUCachePartition); ok { + return cache, nil + } + } + cache, err := NewInMemoryLRUCachePartition( + LazyLoadCacheName, + LazyLoadCacheMutable, + LazyLoadCacheMaxUserEntries, + LazyLoadCacheMaxAge, + false, + ) + if err != nil { + return nil, err + } + c.userCaches.Set(cacheName, cache) + go cacheCleaner(cache) + return cache, nil +} + +func (c *LazyLoadCache) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) { + cache, err := c.lazyLoadCacheForUser(device) + if err != nil { + return + } + cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID) + cache.Set(cacheKey, eventID) +} + +func (c *LazyLoadCache) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool) { + cache, err := c.lazyLoadCacheForUser(device) + if err != nil { + return "", false + } + + cacheKey := fmt.Sprintf("%s/%s/%s/%s", device.UserID, device.ID, roomID, userID) + val, ok := cache.Get(cacheKey) + if !ok { + return "", ok + } + return val.(string), ok +} diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index bcaf6ca31..ddc2f55c2 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -26,7 +27,8 @@ type PDUStreamProvider struct { tasks chan func() workers atomic.Int32 - userAPI userapi.UserInternalAPI + // userID+deviceID -> lazy loading cache + lazyLoadCache *caching.LazyLoadCache } func (p *PDUStreamProvider) worker() { @@ -188,7 +190,7 @@ func (p *PDUStreamProvider) IncrementalSync( newPos = from for _, delta := range stateDeltas { var pos types.StreamPosition - if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil { + if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, &stateFilter, req.Response); err != nil { req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") return to } @@ -209,6 +211,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( r types.Range, delta types.StateDelta, eventFilter *gomatrixserverlib.RoomEventFilter, + stateFilter *gomatrixserverlib.StateFilter, res *types.Response, ) (types.StreamPosition, error) { if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave { @@ -247,7 +250,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // room that were returned. latestPosition := r.To updateLatestPosition := func(mostRecentEventID string) { - if _, pos, err := p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil { + var pos types.StreamPosition + if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil { switch { case r.Backwards && pos > latestPosition: fallthrough @@ -263,6 +267,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) } + if stateFilter.LazyLoadMembers { + if err != nil { + return r.From, err + } + delta.StateEvents, err = p.lazyLoadMembers( + ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers, + device, recentEvents, delta.StateEvents, + ) + if err != nil { + return r.From, err + } + } + hasMembershipChange := false for _, recentEvent := range recentStreamEvents { if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { @@ -402,6 +419,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( // "Can sync a room with a message with a transaction id" - which does a complete sync to check. recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) + + if stateFilter.LazyLoadMembers { + if err != nil { + return nil, err + } + stateEvents, err = p.lazyLoadMembers(ctx, roomID, + false, limited, stateFilter.IncludeRedundantMembers, + device, recentEvents, stateEvents, + ) + if err != nil { + return nil, err + } + } + jr = types.NewJoinResponse() jr.Summary.JoinedMemberCount = &joinedCount jr.Summary.InvitedMemberCount = &invitedCount @@ -412,6 +443,69 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return jr, nil } +func (p *PDUStreamProvider) lazyLoadMembers( + ctx context.Context, roomID string, + incremental, limited, includeRedundant bool, + device *userapi.Device, + timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, +) ([]*gomatrixserverlib.HeaderedEvent, error) { + if len(timelineEvents) == 0 { + return stateEvents, nil + } + // Work out which memberships to include + timelineUsers := make(map[string]struct{}) + if !incremental { + timelineUsers[device.UserID] = struct{}{} + } + // Add all users the client doesn't know about yet to a list + for _, event := range timelineEvents { + // Membership is not yet cached, add it to the list + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok { + timelineUsers[event.Sender()] = struct{}{} + } + } + // Preallocate with the same amount, even if it will end up with fewer values + newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents)) + // Remove existing membership events we don't care about, e.g. users not in the timeline.events + for _, event := range stateEvents { + if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil { + // If this is a gapped incremental sync, we still want this membership + isGappedIncremental := limited && incremental + // We want this users membership event, keep it in the list + _, ok := timelineUsers[event.Sender()] + wantMembership := ok || isGappedIncremental + if wantMembership { + newStateEvents = append(newStateEvents, event) + if !includeRedundant { + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, event.Sender(), event.EventID()) + } + delete(timelineUsers, event.Sender()) + } + } else { + newStateEvents = append(newStateEvents, event) + } + } + wantUsers := make([]string, 0, len(timelineUsers)) + for userID := range timelineUsers { + wantUsers = append(wantUsers, userID) + } + // Query missing membership events + memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &gomatrixserverlib.StateFilter{ + Limit: 100, + Senders: &wantUsers, + Types: &[]string{gomatrixserverlib.MRoomMember}, + }) + if err != nil { + return stateEvents, err + } + // cache the membership events + for _, membership := range memberships { + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, membership.Sender(), membership.EventID()) + } + stateEvents = append(newStateEvents, memberships...) + return stateEvents, nil +} + // addIgnoredUsersToFilter adds ignored users to the eventfilter and // the syncreq itself for further use in streams. func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index c7d06a296..d3195b78f 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -27,12 +27,12 @@ type Streams struct { func NewSyncStreamProviders( d storage.Database, userAPI userapi.UserInternalAPI, rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI, - eduCache *caching.EDUCache, notifier *notifier.Notifier, + eduCache *caching.EDUCache, lazyLoadCache *caching.LazyLoadCache, notifier *notifier.Notifier, ) *Streams { streams := &Streams{ PDUStreamProvider: &PDUStreamProvider{ StreamProvider: StreamProvider{DB: d}, - userAPI: userAPI, + lazyLoadCache: lazyLoadCache, }, TypingStreamProvider: &TypingStreamProvider{ StreamProvider: StreamProvider{DB: d}, diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 384121a8a..2f9165d91 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -57,8 +57,12 @@ func AddPublicRoutes( } eduCache := caching.NewTypingCache() + lazyLoadCache, err := caching.NewLazyLoadCache() + if err != nil { + logrus.WithError(err).Panicf("failed to create lazy loading cache") + } notifier := notifier.NewNotifier() - streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier) + streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, lazyLoadCache, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") diff --git a/sytest-whitelist b/sytest-whitelist index f63b96f52..c8dedd59c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -699,4 +699,13 @@ Ignore invite in full sync Ignore invite in incremental sync A filtered timeline reaches its limit A change to displayname should not result in a full state sync -Can fetch images in room \ No newline at end of file +Can fetch images in room +The only membership state included in an initial sync is for all the senders in the timeline +The only membership state included in an incremental sync is for senders in the timeline +Old members are included in gappy incr LL sync if they start speaking +We do send redundant membership state across incremental syncs if asked +Rejecting invite over federation doesn't break incremental /sync +Gapped incremental syncs include all state changes +Old leaves are present in gapped incremental syncs +Leaves are present in non-gapped incremental syncs +Members from the gap are included in gappy incr LL sync \ No newline at end of file From abf71649b0a27ddc564105f94db270696698ecd5 Mon Sep 17 00:00:00 2001 From: fcwoknhenuxdfiyv-nextcloud <84577563+fcwoknhenuxdfiyv@users.noreply.github.com> Date: Tue, 19 Apr 2022 10:46:54 +0200 Subject: [PATCH 13/21] Make sure resp.Username is defined before hashing. Fixes #2356 (#2357) Co-authored-by: Jason Quigley --- clientapi/routing/voip.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go index 13dca7ac0..c7ddaabcf 100644 --- a/clientapi/routing/voip.go +++ b/clientapi/routing/voip.go @@ -52,6 +52,7 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client if turnConfig.SharedSecret != "" { expiry := time.Now().Add(duration).Unix() + resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID) mac := hmac.New(sha1.New, []byte(turnConfig.SharedSecret)) _, err := mac.Write([]byte(resp.Username)) @@ -60,7 +61,6 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client return jsonerror.InternalServerError() } - resp.Username = fmt.Sprintf("%d:%s", expiry, device.UserID) resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil)) } else if turnConfig.Username != "" && turnConfig.Password != "" { resp.Username = turnConfig.Username From 7e745665a47058209f7f1fcf51c50ff91c43c7c4 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 19 Apr 2022 09:51:02 +0100 Subject: [PATCH 14/21] Change `pushkey_ts` to be seconds (fix #2354) (#2358) --- internal/pushgateway/pushgateway.go | 12 +++++------- userapi/api/api.go | 20 ++++++++++---------- userapi/internal/api.go | 2 +- userapi/storage/postgres/pusher_table.go | 3 +-- userapi/storage/sqlite3/pusher_table.go | 3 +-- userapi/storage/tables/interface.go | 3 +-- 6 files changed, 19 insertions(+), 24 deletions(-) diff --git a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go index 88c326eb2..1817a040b 100644 --- a/internal/pushgateway/pushgateway.go +++ b/internal/pushgateway/pushgateway.go @@ -3,8 +3,6 @@ package pushgateway import ( "context" "encoding/json" - - "github.com/matrix-org/gomatrixserverlib" ) // A Client is how interactions with a Push Gateway is done. @@ -47,11 +45,11 @@ type Counts struct { } type Device struct { - AppID string `json:"app_id"` // Required - Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. - PushKey string `json:"pushkey"` // Required - PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` - Tweaks map[string]interface{} `json:"tweaks,omitempty"` + AppID string `json:"app_id"` // Required + Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. + PushKey string `json:"pushkey"` // Required + PushKeyTS int64 `json:"pushkey_ts,omitempty"` + Tweaks map[string]interface{} `json:"tweaks,omitempty"` } type Prio string diff --git a/userapi/api/api.go b/userapi/api/api.go index b86774d14..6aa6a6842 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -492,16 +492,16 @@ type PerformPusherDeletionRequest struct { // Pusher represents a push notification subscriber type Pusher struct { - SessionID int64 `json:"session_id,omitempty"` - PushKey string `json:"pushkey"` - PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` - Kind PusherKind `json:"kind"` - AppID string `json:"app_id"` - AppDisplayName string `json:"app_display_name"` - DeviceDisplayName string `json:"device_display_name"` - ProfileTag string `json:"profile_tag"` - Language string `json:"lang"` - Data map[string]interface{} `json:"data"` + SessionID int64 `json:"session_id,omitempty"` + PushKey string `json:"pushkey"` + PushKeyTS int64 `json:"pushkey_ts,omitempty"` + Kind PusherKind `json:"kind"` + AppID string `json:"app_id"` + AppDisplayName string `json:"app_display_name"` + DeviceDisplayName string `json:"device_display_name"` + ProfileTag string `json:"profile_tag"` + Language string `json:"lang"` + Data map[string]interface{} `json:"data"` } type PusherKind string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 206c6f7de..6a16ea686 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) } if req.Pusher.PushKeyTS == 0 { - req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now()) + req.Pusher.PushKeyTS = int64(time.Now().Second()) } return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) } diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 670dc916f..2eb379ae4 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -95,7 +94,7 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, ) error { _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) logrus.Debugf("Created pusher %d", session_id) diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index e718792e1..d5bd1617b 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -95,7 +94,7 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, ) error { _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) logrus.Debugf("Created pusher %d", session_id) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 99c907b85..eb0cae314 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type AccountDataTable interface { @@ -96,7 +95,7 @@ type ThreePIDTable interface { } type PusherTable interface { - InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error From 711e377b9cb066c42c58f31b155aca194ae9928a Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 19 Apr 2022 10:34:33 +0100 Subject: [PATCH 15/21] Update `go-sqlite3-js` to matrix-org/go-sqlite3-js#2 (SQLite 3.36.0) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 070f2b5b9..e3ac6a171 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/libp2p/go-libp2p-record v0.1.3 github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 - github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d + github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrixserverlib v0.0.0-20220408160933-cf558306b56f github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 diff --git a/go.sum b/go.sum index 464f774be..e8b0a216b 100644 --- a/go.sum +++ b/go.sum @@ -1107,8 +1107,8 @@ github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBE github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 h1:eqE5OnGx9ZMWmrRbD3KF/3KtTunw0iQulI7YxOIdxo4= github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4/go.mod h1:3WluEZ9QXSwU30tWYqktnpC1x9mwZKx1r8uAv8Iq+a4= -github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d h1:mGhPVaTht5NViFN/UpdrIlRApmH2FWcVaKUH5MdBKiY= -github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= +github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA= +github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= From 85b1631ecfbeaa382c4bd17e7e310e668f5400dd Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 19 Apr 2022 10:48:32 +0100 Subject: [PATCH 16/21] Add newly passing test to list --- sytest-whitelist | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sytest-whitelist b/sytest-whitelist index c8dedd59c..aef18c512 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -708,4 +708,5 @@ Rejecting invite over federation doesn't break incremental /sync Gapped incremental syncs include all state changes Old leaves are present in gapped incremental syncs Leaves are present in non-gapped incremental syncs -Members from the gap are included in gappy incr LL sync \ No newline at end of file +Members from the gap are included in gappy incr LL sync +Presence can be set from sync \ No newline at end of file From 073972646fd38b6987c2b62f8db5500c2f87d2d0 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 19 Apr 2022 13:57:02 +0100 Subject: [PATCH 17/21] Use unix not second --- userapi/internal/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 6a16ea686..d1c12f05f 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -653,7 +653,7 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) } if req.Pusher.PushKeyTS == 0 { - req.Pusher.PushKeyTS = int64(time.Now().Second()) + req.Pusher.PushKeyTS = int64(time.Now().Unix()) } return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) } From a9f0a390c6e6efed9c09c72839960f57b5f3e5f2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 20 Apr 2022 15:13:04 +0100 Subject: [PATCH 18/21] Update to NATS Server v2.8.0 and nats.go v1.14.0 (#2359) --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index e3ac6a171..a9774c2c8 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/matrix-org/dendrite -replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e +replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9 -replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c +replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e require ( github.com/Arceliar/ironwood v0.0.0-20211125050254-8951369625d0 @@ -43,7 +43,7 @@ require ( github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb - github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d + github.com/nats-io/nats.go v1.14.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31 diff --git a/go.sum b/go.sum index e8b0a216b..ee4096fdc 100644 --- a/go.sum +++ b/go.sum @@ -1266,8 +1266,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= -github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY= -github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= +github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I= +github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -1276,10 +1276,10 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e h1:5tEHLzvDeS6IeqO2o9FFhsE3V2erYj8FlMt2J91wzsk= -github.com/neilalexander/nats-server/v2 v2.7.5-0.20220311134712-e2e4a244f30e/go.mod h1:1vZ2Nijh8tcyNe8BDVyTviCd9NYzRbubQYiEHsvOQWc= -github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= -github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= +github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9 h1:VGU5HYAwy8LRbSkrT+kCHvujVmwK8Aa/vc1O+eReTbM= +github.com/neilalexander/nats-server/v2 v2.8.1-0.20220419100629-2278c94774f9/go.mod h1:5vic7C58BFEVltiZhs7Kq81q2WcEPhJPsmNv1FOrdv0= +github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e h1:kNIzIzj2OvnlreA+sTJ12nWJzTP3OSLNKDL/Iq9mF6Y= +github.com/neilalexander/nats.go v1.13.1-0.20220419101051-b262d9f0be1e/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ= github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= From bb987cd64b118044a4f3351c377516813514ee19 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 20 Apr 2022 16:06:46 +0100 Subject: [PATCH 19/21] Lazy loading fixes (#2362) * Return some more usefully wrapped errors when doing sync * Remove unnecessary error check * Couple of guards around `sql.ErrNoRows` * Nolint --- syncapi/streams/stream_pdu.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index ddc2f55c2..0f11d55f6 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -3,6 +3,7 @@ package streams import ( "context" "database/sql" + "fmt" "sync" "time" @@ -205,6 +206,7 @@ func (p *PDUStreamProvider) IncrementalSync( return newPos } +// nolint:gocyclo func (p *PDUStreamProvider) addRoomDeltaToResponse( ctx context.Context, device *userapi.Device, @@ -228,13 +230,16 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( eventFilter, true, true, ) if err != nil { - return r.From, err + if err == sql.ErrNoRows { + return r.To, nil + } + return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) } recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents) if err != nil { - return r.From, err + return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err) } // If we didn't return any events at all then don't bother doing anything else. @@ -268,15 +273,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } if stateFilter.LazyLoadMembers { - if err != nil { - return r.From, err - } delta.StateEvents, err = p.lazyLoadMembers( ctx, delta.RoomID, true, limited, stateFilter.IncludeRedundantMembers, device, recentEvents, delta.StateEvents, ) - if err != nil { - return r.From, err + if err != nil && err != sql.ErrNoRows { + return r.From, fmt.Errorf("p.lazyLoadMembers: %w", err) } } From 54e7ea41c688a0bcc89150c99045b2dcdf8d0b12 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 20 Apr 2022 16:51:37 +0100 Subject: [PATCH 20/21] Eliminate more SQL no row errors in sync API (#2363) * Handle `sql.ErrNoRows` in main `/sync` codepaths * Catch more --- syncapi/storage/shared/syncserver.go | 34 ++++++++++++++++++++++++++-- syncapi/streams/stream_pdu.go | 7 ++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 91eba44e1..2143fd672 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -688,6 +688,9 @@ func (d *Database) GetStateDeltas( // user has ever interacted with — joined to, kicked/banned from, left. memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } @@ -705,17 +708,23 @@ func (d *Database) GetStateDeltas( // get all the state events ever (i.e. for all available rooms) between these two positions stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } // find out which rooms this user is peeking, if any. // We do this before joins so any peeks get overwritten peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) - if err != nil { + if err != nil && err != sql.ErrNoRows { return nil, nil, err } @@ -726,6 +735,9 @@ func (d *Database) GetStateDeltas( var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) if err != nil { + if err == sql.ErrNoRows { + continue + } return nil, nil, err } state[peek.RoomID] = s @@ -753,6 +765,9 @@ func (d *Database) GetStateDeltas( var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) if err != nil { + if err == sql.ErrNoRows { + continue + } return nil, nil, err } state[roomID] = s @@ -803,6 +818,9 @@ func (d *Database) GetStateDeltasForFullStateSync( // user has ever interacted with — joined to, kicked/banned from, left. memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } @@ -819,7 +837,7 @@ func (d *Database) GetStateDeltasForFullStateSync( deltas := make(map[string]types.StateDelta) peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) - if err != nil { + if err != nil && err != sql.ErrNoRows { return nil, nil, err } @@ -828,6 +846,9 @@ func (d *Database) GetStateDeltasForFullStateSync( if !peek.Deleted { s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) if stateErr != nil { + if stateErr == sql.ErrNoRows { + continue + } return nil, nil, stateErr } deltas[peek.RoomID] = types.StateDelta{ @@ -841,10 +862,16 @@ func (d *Database) GetStateDeltasForFullStateSync( // Get all the state events ever between these two positions stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } return nil, nil, err } @@ -869,6 +896,9 @@ func (d *Database) GetStateDeltasForFullStateSync( for _, joinedRoomID := range joinedRoomIDs { s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) if stateErr != nil { + if stateErr == sql.ErrNoRows { + continue + } return nil, nil, stateErr } deltas[joinedRoomID] = types.StateDelta{ diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 0f11d55f6..df5fb8e08 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -341,12 +341,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( wantFullState bool, device *userapi.Device, ) (jr *types.JoinResponse, err error) { + jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 recentStreamEvents, limited, err := p.DB.RecentEvents( ctx, roomID, r, eventFilter, true, true, ) if err != nil { + if err == sql.ErrNoRows { + return jr, nil + } return } @@ -430,12 +434,11 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( false, limited, stateFilter.IncludeRedundantMembers, device, recentEvents, stateEvents, ) - if err != nil { + if err != nil && err != sql.ErrNoRows { return nil, err } } - jr = types.NewJoinResponse() jr.Summary.JoinedMemberCount = &joinedCount jr.Summary.InvitedMemberCount = &invitedCount jr.Timeline.PrevBatch = prevBatch From 2258387d393426922344213948c6d814de53f465 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 20 Apr 2022 16:55:24 +0100 Subject: [PATCH 21/21] Update test list --- sytest-whitelist | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sytest-whitelist b/sytest-whitelist index aef18c512..979f12bf6 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -312,10 +312,10 @@ Inbound federation can return events Inbound federation can return missing events for world_readable visibility Inbound federation can return missing events for invite visibility Inbound federation can get public room list -POST /rooms/:room_id/redact/:event_id as power user redacts message -POST /rooms/:room_id/redact/:event_id as original message sender redacts message -POST /rooms/:room_id/redact/:event_id as random user does not redact message -POST /redact disallows redaction of event in different room +PUT /rooms/:room_id/redact/:event_id/:txn_id as power user redacts message +PUT /rooms/:room_id/redact/:event_id/:txn_id as original message sender redacts message +PUT /rooms/:room_id/redact/:event_id/:txn_id as random user does not redact message +PUT /redact disallows redaction of event in different room An event which redacts itself should be ignored A pair of events which redact each other should be ignored Redaction of a redaction redacts the redaction reason