Factor out account data

This commit is contained in:
Kegan Dougal 2020-05-13 17:34:34 +01:00
parent a25d477cdb
commit 37e81e3c06
6 changed files with 91 additions and 116 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -70,32 +71,33 @@ type accountDataStatements struct {
selectMaxAccountDataIDStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountData, error) {
_, err = db.Exec(accountDataSchema) s := &accountDataStatements{}
_, err := db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return nil, err
} }
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
return return nil, err
} }
if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil {
return return nil, err
} }
if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil {
return return nil, err
} }
return return s, nil
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string, userID, roomID, dataType string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, err error) {
err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos)
return return
} }
func (s *accountDataStatements) selectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
oldPos, newPos types.StreamPosition, oldPos, newPos types.StreamPosition,
@ -137,7 +139,7 @@ func (s *accountDataStatements) selectAccountDataInRange(
return data, rows.Err() return data, rows.Err()
} }
func (s *accountDataStatements) selectMaxAccountDataID( func (s *accountDataStatements) SelectMaxAccountDataID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
var nullableID sql.NullInt64 var nullableID sql.NullInt64

View file

@ -50,15 +50,14 @@ type stateDelta struct {
// SyncServerDatasource represents a sync server datasource which manages // SyncServerDatasource represents a sync server datasource which manages
// both the database for PDUs and caches for EDUs. // both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct { type SyncServerDatasource struct {
shared.Database
db *sql.DB db *sql.DB
common.PartitionOffsetStatements common.PartitionOffsetStatements
accountData accountDataStatements
events outputRoomEventsStatements events outputRoomEventsStatements
roomstate currentRoomStateStatements roomstate currentRoomStateStatements
eduCache *cache.EDUCache eduCache *cache.EDUCache
topology outputRoomEventsTopologyStatements topology outputRoomEventsTopologyStatements
backwardExtremities tables.BackwardsExtremities backwardExtremities tables.BackwardsExtremities
shared *shared.Database
} }
// NewSyncServerDatasource creates a new sync server database // NewSyncServerDatasource creates a new sync server database
@ -71,7 +70,8 @@ func NewSyncServerDatasource(dbDataSourceName string, dbProperties common.DbProp
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
return nil, err return nil, err
} }
if err = d.accountData.prepare(d.db); err != nil { accountData, err := NewPostgresAccountDataTable(d.db)
if err != nil {
return nil, err return nil, err
} }
if err = d.events.prepare(d.db); err != nil { if err = d.events.prepare(d.db); err != nil {
@ -92,9 +92,10 @@ func NewSyncServerDatasource(dbDataSourceName string, dbProperties common.DbProp
return nil, err return nil, err
} }
d.eduCache = cache.New() d.eduCache = cache.New()
d.shared = &shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Invites: invites, Invites: invites,
AccountData: accountData,
} }
return &d, nil return &d, nil
} }
@ -339,14 +340,14 @@ func (d *SyncServerDatasource) syncStreamPositionTx(
if err != nil { if err != nil {
return 0, err return 0, err
} }
maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if maxAccountDataID > maxID { if maxAccountDataID > maxID {
maxID = maxAccountDataID maxID = maxAccountDataID
} }
maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -364,14 +365,14 @@ func (d *SyncServerDatasource) syncPositionTx(
if err != nil { if err != nil {
return sp, err return sp, err
} }
maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn)
if err != nil { if err != nil {
return sp, err return sp, err
} }
if maxAccountDataID > maxEventID { if maxAccountDataID > maxEventID {
maxEventID = maxAccountDataID maxEventID = maxAccountDataID
} }
maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn)
if err != nil { if err != nil {
return sp, err return sp, err
} }
@ -653,31 +654,6 @@ var txReadOnlySnapshot = sql.TxOptions{
ReadOnly: true, ReadOnly: true,
} }
func (d *SyncServerDatasource) GetAccountDataInRange(
ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, error) {
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
}
func (d *SyncServerDatasource) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string,
) (types.StreamPosition, error) {
return d.accountData.insertAccountData(ctx, userID, roomID, dataType)
}
func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) {
return d.shared.AddInviteEvent(ctx, inviteEvent)
}
func (d *SyncServerDatasource) RetireInviteEvent(
ctx context.Context, inviteEventID string,
) error {
return d.shared.RetireInviteEvent(ctx, inviteEventID)
}
func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.eduCache.SetTimeoutCallback(fn) d.eduCache.SetTimeoutCallback(fn)
} }
@ -700,7 +676,7 @@ func (d *SyncServerDatasource) addInvitesToResponse(
fromPos, toPos types.StreamPosition, fromPos, toPos types.StreamPosition,
res *types.Response, res *types.Response,
) error { ) error {
invites, err := d.shared.Invites.SelectInviteEventsInRange( invites, err := d.Database.Invites.SelectInviteEventsInRange(
ctx, txn, userID, fromPos, toPos, ctx, txn, userID, fromPos, toPos,
) )
if err != nil { if err != nil {

View file

@ -13,10 +13,14 @@ import (
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite // Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
// For now this contains the shared functions // For now this contains the shared functions
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Invites tables.Invites Invites tables.Invites
AccountData tables.AccountData
} }
// AddInviteEvent stores a new invite event for a user.
// If the invite was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database.
func (d *Database) AddInviteEvent( func (d *Database) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
@ -27,6 +31,8 @@ func (d *Database) AddInviteEvent(
return return
} }
// RetireInviteEvent removes an old invite event from the database.
// Returns an error if there was a problem communicating with the database.
func (d *Database) RetireInviteEvent( func (d *Database) RetireInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) error { ) error {
@ -35,3 +41,31 @@ func (d *Database) RetireInviteEvent(
err := d.Invites.DeleteInviteEvent(ctx, inviteEventID) err := d.Invites.DeleteInviteEvent(ctx, inviteEventID)
return err return err
} }
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
func (d *Database) GetAccountDataInRange(
ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, error) {
return d.AccountData.SelectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
}
// UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (empty)
// room ID means the data isn't specific to any room)
// If no data with the given type, user ID and room ID exists in the database,
// creates a new row, else update the existing one
// Returns an error if there was an issue with the upsert
func (d *Database) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string,
) (sp types.StreamPosition, err error) {
err = common.WithTransaction(d.DB, func(txn *sql.Tx) error {
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
return err
})
return
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -55,25 +56,25 @@ type accountDataStatements struct {
selectAccountDataInRangeStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { func NewSqliteAccountDataTable(db *sql.DB) (tables.AccountData, error) {
s.streamIDStatements = streamID s := &accountDataStatements{}
_, err = db.Exec(accountDataSchema) _, err := db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return nil, err
} }
if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
return return nil, err
} }
if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil {
return return nil, err
} }
if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil {
return return nil, err
} }
return return s, nil
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string, userID, roomID, dataType string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, err error) {
@ -85,7 +86,7 @@ func (s *accountDataStatements) insertAccountData(
return return
} }
func (s *accountDataStatements) selectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
oldPos, newPos types.StreamPosition, oldPos, newPos types.StreamPosition,
@ -146,7 +147,7 @@ func (s *accountDataStatements) selectAccountDataInRange(
return data, nil return data, nil
} }
func (s *accountDataStatements) selectMaxAccountDataID( func (s *accountDataStatements) SelectMaxAccountDataID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
var nullableID sql.NullInt64 var nullableID sql.NullInt64

View file

@ -53,16 +53,15 @@ type stateDelta struct {
// SyncServerDatasource represents a sync server datasource which manages // SyncServerDatasource represents a sync server datasource which manages
// both the database for PDUs and caches for EDUs. // both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct { type SyncServerDatasource struct {
shared.Database
db *sql.DB db *sql.DB
common.PartitionOffsetStatements common.PartitionOffsetStatements
streamID streamIDStatements streamID streamIDStatements
accountData accountDataStatements
events outputRoomEventsStatements events outputRoomEventsStatements
roomstate currentRoomStateStatements roomstate currentRoomStateStatements
eduCache *cache.EDUCache eduCache *cache.EDUCache
topology outputRoomEventsTopologyStatements topology outputRoomEventsTopologyStatements
backwardExtremities tables.BackwardsExtremities backwardExtremities tables.BackwardsExtremities
shared *shared.Database
} }
// NewSyncServerDatasource creates a new sync server database // NewSyncServerDatasource creates a new sync server database
@ -98,7 +97,8 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err = d.streamID.prepare(d.db); err != nil { if err = d.streamID.prepare(d.db); err != nil {
return err return err
} }
if err = d.accountData.prepare(d.db, &d.streamID); err != nil { accountData, err := NewSqliteAccountDataTable(d.db)
if err != nil {
return err return err
} }
if err = d.events.prepare(d.db, &d.streamID); err != nil { if err = d.events.prepare(d.db, &d.streamID); err != nil {
@ -118,9 +118,10 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
d.shared = &shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Invites: invites, Invites: invites,
AccountData: accountData,
} }
return nil return nil
} }
@ -403,14 +404,14 @@ func (d *SyncServerDatasource) syncStreamPositionTx(
if err != nil { if err != nil {
return 0, err return 0, err
} }
maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if maxAccountDataID > maxID { if maxAccountDataID > maxID {
maxID = maxAccountDataID maxID = maxAccountDataID
} }
maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -428,14 +429,14 @@ func (d *SyncServerDatasource) syncPositionTx(
if err != nil { if err != nil {
return nil, err return nil, err
} }
maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if maxAccountDataID > maxEventID { if maxAccountDataID > maxEventID {
maxEventID = maxAccountDataID maxEventID = maxAccountDataID
} }
maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -729,51 +730,6 @@ var txReadOnlySnapshot = sql.TxOptions{
ReadOnly: true, ReadOnly: true,
} }
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
func (d *SyncServerDatasource) GetAccountDataInRange(
ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, error) {
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
}
// UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (empty)
// room ID means the data isn't specific to any room)
// If no data with the given type, user ID and room ID exists in the database,
// creates a new row, else update the existing one
// Returns an error if there was an issue with the upsert
func (d *SyncServerDatasource) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string,
) (sp types.StreamPosition, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType)
return err
})
return
}
// AddInviteEvent stores a new invite event for a user.
// If the invite was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) {
return d.shared.AddInviteEvent(ctx, inviteEvent)
}
// RetireInviteEvent removes an old invite event from the database.
// Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatasource) RetireInviteEvent(
ctx context.Context, inviteEventID string,
) error {
return d.shared.RetireInviteEvent(ctx, inviteEventID)
}
func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.eduCache.SetTimeoutCallback(fn) d.eduCache.SetTimeoutCallback(fn)
} }
@ -800,7 +756,7 @@ func (d *SyncServerDatasource) addInvitesToResponse(
fromPos, toPos types.StreamPosition, fromPos, toPos types.StreamPosition,
res *types.Response, res *types.Response,
) error { ) error {
invites, err := d.shared.Invites.SelectInviteEventsInRange( invites, err := d.Database.Invites.SelectInviteEventsInRange(
ctx, txn, userID, fromPos, toPos, ctx, txn, userID, fromPos, toPos,
) )
if err != nil { if err != nil {

View file

@ -8,6 +8,12 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type AccountData interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error)
SelectAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error)
SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
type Invites interface { type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error)
DeleteInviteEvent(ctx context.Context, inviteEventID string) error DeleteInviteEvent(ctx context.Context, inviteEventID string) error