From 37e81e3c0652b7bb46ef144077c5be9d098d4e03 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 13 May 2020 17:34:34 +0100 Subject: [PATCH] Factor out account data --- .../storage/postgres/account_data_table.go | 24 ++++--- syncapi/storage/postgres/syncserver.go | 48 ++++--------- syncapi/storage/shared/syncserver.go | 38 ++++++++++- syncapi/storage/sqlite3/account_data_table.go | 23 ++++--- syncapi/storage/sqlite3/syncserver.go | 68 ++++--------------- syncapi/storage/tables/interface.go | 6 ++ 6 files changed, 91 insertions(+), 116 deletions(-) diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index d1e3b527f..58fb21983 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -70,32 +71,33 @@ type accountDataStatements struct { selectMaxAccountDataIDStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(accountDataSchema) +func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountData, error) { + s := &accountDataStatements{} + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return + return nil, err } if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { - return + return nil, err } if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *accountDataStatements) insertAccountData( - ctx context.Context, +func (s *accountDataStatements) InsertAccountData( + ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) return } -func (s *accountDataStatements) selectAccountDataInRange( +func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, userID string, oldPos, newPos types.StreamPosition, @@ -137,7 +139,7 @@ func (s *accountDataStatements) selectAccountDataInRange( return data, rows.Err() } -func (s *accountDataStatements) selectMaxAccountDataID( +func (s *accountDataStatements) SelectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 9883c3629..a8f31f55c 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -50,15 +50,14 @@ type stateDelta struct { // SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { + shared.Database db *sql.DB common.PartitionOffsetStatements - accountData accountDataStatements events outputRoomEventsStatements roomstate currentRoomStateStatements eduCache *cache.EDUCache topology outputRoomEventsTopologyStatements backwardExtremities tables.BackwardsExtremities - shared *shared.Database } // NewSyncServerDatasource creates a new sync server database @@ -71,7 +70,8 @@ func NewSyncServerDatasource(dbDataSourceName string, dbProperties common.DbProp if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { return nil, err } - if err = d.accountData.prepare(d.db); err != nil { + accountData, err := NewPostgresAccountDataTable(d.db) + if err != nil { return nil, err } if err = d.events.prepare(d.db); err != nil { @@ -92,9 +92,10 @@ func NewSyncServerDatasource(dbDataSourceName string, dbProperties common.DbProp return nil, err } d.eduCache = cache.New() - d.shared = &shared.Database{ - DB: d.db, - Invites: invites, + d.Database = shared.Database{ + DB: d.db, + Invites: invites, + AccountData: accountData, } return &d, nil } @@ -339,14 +340,14 @@ func (d *SyncServerDatasource) syncStreamPositionTx( if err != nil { return 0, err } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) + maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn) if err != nil { return 0, err } if maxAccountDataID > maxID { maxID = maxAccountDataID } - maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) + maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return 0, err } @@ -364,14 +365,14 @@ func (d *SyncServerDatasource) syncPositionTx( if err != nil { return sp, err } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) + maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn) if err != nil { return sp, err } if maxAccountDataID > maxEventID { maxEventID = maxAccountDataID } - maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) + maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return sp, err } @@ -653,31 +654,6 @@ var txReadOnlySnapshot = sql.TxOptions{ ReadOnly: true, } -func (d *SyncServerDatasource) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, error) { - return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) -} - -func (d *SyncServerDatasource) UpsertAccountData( - ctx context.Context, userID, roomID, dataType string, -) (types.StreamPosition, error) { - return d.accountData.insertAccountData(ctx, userID, roomID, dataType) -} - -func (d *SyncServerDatasource) AddInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, -) (sp types.StreamPosition, err error) { - return d.shared.AddInviteEvent(ctx, inviteEvent) -} - -func (d *SyncServerDatasource) RetireInviteEvent( - ctx context.Context, inviteEventID string, -) error { - return d.shared.RetireInviteEvent(ctx, inviteEventID) -} - func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { d.eduCache.SetTimeoutCallback(fn) } @@ -700,7 +676,7 @@ func (d *SyncServerDatasource) addInvitesToResponse( fromPos, toPos types.StreamPosition, res *types.Response, ) error { - invites, err := d.shared.Invites.SelectInviteEventsInRange( + invites, err := d.Database.Invites.SelectInviteEventsInRange( ctx, txn, userID, fromPos, toPos, ) if err != nil { diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index e89976df6..5ebce6859 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -13,10 +13,14 @@ import ( // Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite // For now this contains the shared functions type Database struct { - DB *sql.DB - Invites tables.Invites + DB *sql.DB + Invites tables.Invites + AccountData tables.AccountData } +// AddInviteEvent stores a new invite event for a user. +// If the invite was successfully stored this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. func (d *Database) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ) (sp types.StreamPosition, err error) { @@ -27,6 +31,8 @@ func (d *Database) AddInviteEvent( return } +// RetireInviteEvent removes an old invite event from the database. +// Returns an error if there was a problem communicating with the database. func (d *Database) RetireInviteEvent( ctx context.Context, inviteEventID string, ) error { @@ -35,3 +41,31 @@ func (d *Database) RetireInviteEvent( err := d.Invites.DeleteInviteEvent(ctx, inviteEventID) return err } + +// GetAccountDataInRange returns all account data for a given user inserted or +// updated between two given positions +// Returns a map following the format data[roomID] = []dataTypes +// If no data is retrieved, returns an empty map +// If there was an issue with the retrieval, returns an error +func (d *Database) GetAccountDataInRange( + ctx context.Context, userID string, oldPos, newPos types.StreamPosition, + accountDataFilterPart *gomatrixserverlib.EventFilter, +) (map[string][]string, error) { + return d.AccountData.SelectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) +} + +// UpsertAccountData keeps track of new or updated account data, by saving the type +// of the new/updated data, and the user ID and room ID the data is related to (empty) +// room ID means the data isn't specific to any room) +// If no data with the given type, user ID and room ID exists in the database, +// creates a new row, else update the existing one +// Returns an error if there was an issue with the upsert +func (d *Database) UpsertAccountData( + ctx context.Context, userID, roomID, dataType string, +) (sp types.StreamPosition, err error) { + err = common.WithTransaction(d.DB, func(txn *sql.Tx) error { + sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) + return err + }) + return +} diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 3dbf961b4..6c1ccb824 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -55,25 +56,25 @@ type accountDataStatements struct { selectAccountDataInRangeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { - s.streamIDStatements = streamID - _, err = db.Exec(accountDataSchema) +func NewSqliteAccountDataTable(db *sql.DB) (tables.AccountData, error) { + s := &accountDataStatements{} + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return + return nil, err } if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { - return + return nil, err } if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { @@ -85,7 +86,7 @@ func (s *accountDataStatements) insertAccountData( return } -func (s *accountDataStatements) selectAccountDataInRange( +func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, userID string, oldPos, newPos types.StreamPosition, @@ -146,7 +147,7 @@ func (s *accountDataStatements) selectAccountDataInRange( return data, nil } -func (s *accountDataStatements) selectMaxAccountDataID( +func (s *accountDataStatements) SelectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index a2253dcd1..93da1e56d 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -53,16 +53,15 @@ type stateDelta struct { // SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { + shared.Database db *sql.DB common.PartitionOffsetStatements streamID streamIDStatements - accountData accountDataStatements events outputRoomEventsStatements roomstate currentRoomStateStatements eduCache *cache.EDUCache topology outputRoomEventsTopologyStatements backwardExtremities tables.BackwardsExtremities - shared *shared.Database } // NewSyncServerDatasource creates a new sync server database @@ -98,7 +97,8 @@ func (d *SyncServerDatasource) prepare() (err error) { if err = d.streamID.prepare(d.db); err != nil { return err } - if err = d.accountData.prepare(d.db, &d.streamID); err != nil { + accountData, err := NewSqliteAccountDataTable(d.db) + if err != nil { return err } if err = d.events.prepare(d.db, &d.streamID); err != nil { @@ -118,9 +118,10 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } - d.shared = &shared.Database{ - DB: d.db, - Invites: invites, + d.Database = shared.Database{ + DB: d.db, + Invites: invites, + AccountData: accountData, } return nil } @@ -403,14 +404,14 @@ func (d *SyncServerDatasource) syncStreamPositionTx( if err != nil { return 0, err } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) + maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn) if err != nil { return 0, err } if maxAccountDataID > maxID { maxID = maxAccountDataID } - maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) + maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return 0, err } @@ -428,14 +429,14 @@ func (d *SyncServerDatasource) syncPositionTx( if err != nil { return nil, err } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) + maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn) if err != nil { return nil, err } if maxAccountDataID > maxEventID { maxEventID = maxAccountDataID } - maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) + maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return nil, err } @@ -729,51 +730,6 @@ var txReadOnlySnapshot = sql.TxOptions{ ReadOnly: true, } -// GetAccountDataInRange returns all account data for a given user inserted or -// updated between two given positions -// Returns a map following the format data[roomID] = []dataTypes -// If no data is retrieved, returns an empty map -// If there was an issue with the retrieval, returns an error -func (d *SyncServerDatasource) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, error) { - return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) -} - -// UpsertAccountData keeps track of new or updated account data, by saving the type -// of the new/updated data, and the user ID and room ID the data is related to (empty) -// room ID means the data isn't specific to any room) -// If no data with the given type, user ID and room ID exists in the database, -// creates a new row, else update the existing one -// Returns an error if there was an issue with the upsert -func (d *SyncServerDatasource) UpsertAccountData( - ctx context.Context, userID, roomID, dataType string, -) (sp types.StreamPosition, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType) - return err - }) - return -} - -// AddInviteEvent stores a new invite event for a user. -// If the invite was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) AddInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, -) (sp types.StreamPosition, err error) { - return d.shared.AddInviteEvent(ctx, inviteEvent) -} - -// RetireInviteEvent removes an old invite event from the database. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) RetireInviteEvent( - ctx context.Context, inviteEventID string, -) error { - return d.shared.RetireInviteEvent(ctx, inviteEventID) -} - func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { d.eduCache.SetTimeoutCallback(fn) } @@ -800,7 +756,7 @@ func (d *SyncServerDatasource) addInvitesToResponse( fromPos, toPos types.StreamPosition, res *types.Response, ) error { - invites, err := d.shared.Invites.SelectInviteEventsInRange( + invites, err := d.Database.Invites.SelectInviteEventsInRange( ctx, txn, userID, fromPos, toPos, ) if err != nil { diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1a9940524..d9145cc2d 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -8,6 +8,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +type AccountData interface { + InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) + SelectAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error) + SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) +} + type Invites interface { InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) DeleteInviteEvent(ctx context.Context, inviteEventID string) error