diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go index 376e6f494..235296388 100644 --- a/internal/sqlutil/trace.go +++ b/internal/sqlutil/trace.go @@ -31,33 +31,16 @@ import ( ) var tracingEnabled = os.Getenv("DENDRITE_TRACE_SQL") == "1" -var dbToWriter map[string]Writer -var CtxDBInstance = "db_instance" -var instCount = 0 type traceInterceptor struct { sqlmw.NullInterceptor - conn driver.Conn } func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.StmtQueryContext, query string, args []driver.NamedValue) (driver.Rows, error) { startedAt := time.Now() rows, err := stmt.QueryContext(ctx, args) - key := ctx.Value(CtxDBInstance) - var safe string - if key != nil { - w := dbToWriter[key.(string)] - if w == nil { - safe = fmt.Sprintf("no writer for key %s", key) - } else { - safe = w.Safe() - } - } - if safe != "" && !strings.HasPrefix(query, "SELECT ") { - logrus.Infof("unsafe: %s -- %s", safe, query) - } - logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).WithField("safe", safe).Debug("executed sql query ", query, " args: ", args) + logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) return rows, err } @@ -65,21 +48,8 @@ func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.St func (in *traceInterceptor) StmtExecContext(ctx context.Context, stmt driver.StmtExecContext, query string, args []driver.NamedValue) (driver.Result, error) { startedAt := time.Now() result, err := stmt.ExecContext(ctx, args) - key := ctx.Value(CtxDBInstance) - var safe string - if key != nil { - w := dbToWriter[key.(string)] - if w == nil { - safe = fmt.Sprintf("no writer for key %s", key) - } else { - safe = w.Safe() - } - } - if safe != "" && !strings.HasPrefix(query, "SELECT ") { - logrus.Infof("unsafe: %s -- %s", safe, query) - } - logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).WithField("safe", safe).Debug("executed sql query ", query, " args: ", args) + logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) return result, err } @@ -105,18 +75,6 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [ return err } -func OpenWithWriter(dbProperties *config.DatabaseOptions, w Writer) (*sql.DB, context.Context, error) { - db, err := Open(dbProperties) - if err != nil { - return nil, nil, err - } - instCount++ - ctxVal := fmt.Sprintf("%d", instCount) - dbToWriter[ctxVal] = w - ctx := context.WithValue(context.TODO(), CtxDBInstance, ctxVal) - return db, ctx, nil -} - // Open opens a database specified by its database driver name and a driver-specific data source name, // usually consisting of at least a database name and connection information. Includes tracing driver // if DENDRITE_TRACE_SQL=1 @@ -160,5 +118,4 @@ func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) { func init() { registerDrivers() - dbToWriter = make(map[string]Writer) } diff --git a/internal/sqlutil/writer.go b/internal/sqlutil/writer.go index f966e250c..5d93fef4d 100644 --- a/internal/sqlutil/writer.go +++ b/internal/sqlutil/writer.go @@ -43,6 +43,4 @@ type Writer interface { // Queue up one or more database write operations within the // provided function to be executed when it is safe to do so. Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error - - Safe() string } diff --git a/internal/sqlutil/writer_dummy.go b/internal/sqlutil/writer_dummy.go index fbca3e773..f426c2bc3 100644 --- a/internal/sqlutil/writer_dummy.go +++ b/internal/sqlutil/writer_dummy.go @@ -26,7 +26,3 @@ func (w *DummyWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) err return f(txn) } } - -func (w *DummyWriter) Safe() string { - return "DummyWriter" -} diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go index 933661958..002bc32cf 100644 --- a/internal/sqlutil/writer_exclusive.go +++ b/internal/sqlutil/writer_exclusive.go @@ -3,10 +3,6 @@ package sqlutil import ( "database/sql" "errors" - "fmt" - "runtime" - "strconv" - "strings" "go.uber.org/atomic" ) @@ -16,9 +12,8 @@ import ( // contend on database locks in, e.g. SQLite. Only one task will run // at a time on a given ExclusiveWriter. type ExclusiveWriter struct { - running atomic.Bool - todo chan transactionWriterTask - writerID int + running atomic.Bool + todo chan transactionWriterTask } func NewExclusiveWriter() Writer { @@ -35,15 +30,6 @@ type transactionWriterTask struct { wait chan error } -func (w *ExclusiveWriter) Safe() string { - a := goid() - b := w.writerID - if a == b { - return "" - } - return fmt.Sprintf("%v != %v", a, b) -} - // Do queues a task to be run by a TransactionWriter. The function // provided will be ran within a transaction as supplied by the // txn parameter if one is supplied, and if not, will take out a @@ -74,7 +60,6 @@ func (w *ExclusiveWriter) run() { if !w.running.CAS(false, true) { return } - w.writerID = goid() defer w.running.Store(false) for task := range w.todo { if task.db != nil && task.txn != nil { @@ -89,14 +74,3 @@ func (w *ExclusiveWriter) run() { close(task.wait) } } - -func goid() int { - var buf [64]byte - n := runtime.Stack(buf[:], false) - idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] - id, err := strconv.Atoi(idField) - if err != nil { - panic(fmt.Sprintf("cannot get goroutine id: %v", err)) - } - return id -} diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index c013a6350..bc90e175b 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -61,14 +61,14 @@ const markPeeksAsOldSQL = "" + "UPDATE syncapi_peeks SET new=false WHERE user_id = $1 and device_id = $2" type peekStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - insertPeekStmt *sql.Stmt - deletePeekStmt *sql.Stmt - deletePeeksStmt *sql.Stmt - selectPeeksStmt *sql.Stmt - selectPeekingDevicesStmt *sql.Stmt - markPeeksAsOldStmt *sql.Stmt + db *sql.DB + streamIDStatements *streamIDStatements + insertPeekStmt *sql.Stmt + deletePeekStmt *sql.Stmt + deletePeeksStmt *sql.Stmt + selectPeeksStmt *sql.Stmt + selectPeekingDevicesStmt *sql.Stmt + markPeeksAsOldStmt *sql.Stmt } func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { @@ -77,7 +77,7 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks return nil, err } s := &peekStatements{ - db: db, + db: db, streamIDStatements: streamID, } if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { @@ -155,7 +155,7 @@ func (s *peekStatements) SelectPeeks( return peeks, rows.Err() } -func (s *peekStatements) MarkPeeksAsOld ( +func (s *peekStatements) MarkPeeksAsOld( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.markPeeksAsOldStmt).ExecContext(ctx, userID, deviceID) diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index b6d485352..86d83ec98 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -16,7 +16,6 @@ package sqlite3 import ( - "context" "database/sql" // Import the sqlite3 package @@ -25,11 +24,7 @@ import ( "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/shared" - "github.com/matrix-org/dendrite/syncapi/types" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // SyncServerDatasource represents a sync server datasource which manages @@ -40,7 +35,6 @@ type SyncServerDatasource struct { writer sqlutil.Writer sqlutil.PartitionOffsetStatements streamID streamIDStatements - dbctx context.Context } // NewDatabase creates a new sync server database @@ -48,12 +42,10 @@ type SyncServerDatasource struct { func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error - d.writer = sqlutil.NewExclusiveWriter() - d.db, d.dbctx, err = sqlutil.OpenWithWriter(dbProperties, d.writer) - if err != nil { + if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - + d.writer = sqlutil.NewExclusiveWriter() if err = d.prepare(); err != nil { return nil, err } @@ -107,7 +99,7 @@ func (d *SyncServerDatasource) prepare() (err error) { DB: d.db, Writer: d.writer, Invites: invites, - Peeks: peeks, + Peeks: peeks, AccountData: accountData, OutputEvents: events, BackwardExtremities: bwExtrem, @@ -119,172 +111,3 @@ func (d *SyncServerDatasource) prepare() (err error) { } return nil } - -func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { - return d.Database.Events(d.dbctx, eventIDs) -} -func (d *SyncServerDatasource) WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent, - addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) { - return d.Database.WriteEvent(d.dbctx, ev, addStateEvents, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync) -} -func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { - return d.Database.AllJoinedUsersInRooms(d.dbctx) -} - -func (d *SyncServerDatasource) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { - return d.Database.GetStateEvent(d.dbctx, roomID, evType, stateKey) -} - -// GetStateEventsForRoom fetches the state events for a given room. -// Returns an empty slice if no state events could be found for this room. -// Returns an error if there was an issue with the retrieval. -func (d *SyncServerDatasource) GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { - return d.Database.GetStateEventsForRoom(d.dbctx, roomID, stateFilterPart) -} - -// SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.StreamingToken, error) { - return d.Database.SyncPosition(d.dbctx) -} - -func (d *SyncServerDatasource) IncrementalSync(ctx context.Context, res *types.Response, device userapi.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) { - return d.Database.IncrementalSync(d.dbctx, res, device, fromPos, toPos, numRecentEventsPerRoom, wantFullState) -} - -// CompleteSync returns a complete /sync API response for the given user. A response object -// must be provided for CompleteSync to populate - it will not create one. -func (d *SyncServerDatasource) CompleteSync(ctx context.Context, res *types.Response, device userapi.Device, numRecentEventsPerRoom int) (*types.Response, error) { - return d.Database.CompleteSync(d.dbctx, res, device, numRecentEventsPerRoom) -} - -// GetAccountDataInRange returns all account data for a given user inserted or -// updated between two given positions -// Returns a map following the format data[roomID] = []dataTypes -// If no data is retrieved, returns an empty map -// If there was an issue with the retrieval, returns an error -func (d *SyncServerDatasource) GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) { - return d.Database.GetAccountDataInRange(d.dbctx, userID, r, accountDataFilterPart) -} - -// UpsertAccountData keeps track of new or updated account data, by saving the type -// of the new/updated data, and the user ID and room ID the data is related to (empty) -// room ID means the data isn't specific to any room) -// If no data with the given type, user ID and room ID exists in the database, -// creates a new row, else update the existing one -// Returns an error if there was an issue with the upsert -func (d *SyncServerDatasource) UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error) { - return d.Database.UpsertAccountData(d.dbctx, userID, roomID, dataType) -} - -// AddInviteEvent stores a new invite event for a user. -// If the invite was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error) { - return d.Database.AddInviteEvent(d.dbctx, inviteEvent) -} - -// RetireInviteEvent removes an old invite event from the database. Returns the new position of the retired invite. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) RetireInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error) { - return d.Database.RetireInviteEvent(d.dbctx, inviteEventID) -} - -// GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit. -func (d *SyncServerDatasource) GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) { - return d.Database.GetEventsInStreamingRange(d.dbctx, from, to, roomID, limit, backwardOrdering) -} - -// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. -func (d *SyncServerDatasource) GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) { - return d.Database.GetEventsInTopologicalRange(d.dbctx, from, to, roomID, limit, backwardOrdering) -} - -// EventPositionInTopology returns the depth and stream position of the given event. -func (d *SyncServerDatasource) EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) { - return d.Database.EventPositionInTopology(d.dbctx, eventID) -} - -// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. -func (d *SyncServerDatasource) BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) { - return d.Database.BackwardExtremitiesForRoom(d.dbctx, roomID) -} - -func (d *SyncServerDatasource) MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) { - return d.Database.MaxTopologicalPosition(d.dbctx, roomID) -} - -// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: -// - "events": a list of send-to-device events that should be included in the sync -// - "changes": a list of send-to-device events that should be updated in the database by -// CleanSendToDeviceUpdates -// - "deletions": a list of send-to-device events which have been confirmed as sent and -// can be deleted altogether by CleanSendToDeviceUpdates -// The token supplied should be the current requested sync token, e.g. from the "since" -// parameter. -func (d *SyncServerDatasource) SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error) { - return d.Database.SendToDeviceUpdatesForSync(d.dbctx, userID, deviceID, token) -} - -// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. -func (d *SyncServerDatasource) StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) { - return d.Database.StoreNewSendForDeviceMessage(d.dbctx, streamPos, userID, deviceID, event) -} - -// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the -// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows -// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after -// starting to wait for an incremental sync with timeout). -// The token supplied should be the current requested sync token, e.g. from the "since" -// parameter. -func (d *SyncServerDatasource) CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) { - return d.Database.CleanSendToDeviceUpdates(d.dbctx, toUpdate, toDelete, token) -} - -// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent. -func (d *SyncServerDatasource) SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error) { - return d.Database.SendToDeviceUpdatesWaiting(d.dbctx, userID, deviceID) -} - -// GetFilter looks up the filter associated with a given local user and filter ID. -// Returns a filter structure. Otherwise returns an error if no such filter exists -// or if there was an error talking to the database. -func (d *SyncServerDatasource) GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) { - return d.Database.GetFilter(d.dbctx, localpart, filterID) -} - -// PutFilter puts the passed filter into the database. -// Returns the filterID as a string. Otherwise returns an error if something -// goes wrong. -func (d *SyncServerDatasource) PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) { - return d.Database.PutFilter(d.dbctx, localpart, filter) -} - -// RedactEvent wipes an event in the database and sets the unsigned.redacted_because key to the redaction event -func (d *SyncServerDatasource) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error { - return d.Database.RedactEvent(d.dbctx, redactedEventID, redactedBecause) -} - -// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. -func (d *SyncServerDatasource) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { - return d.Database.AllPeekingDevicesInRooms(d.dbctx) -} - -// AddPeek adds a new peek to our DB for a given room by a given user's device. -// Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatasource) AddPeek(ctx context.Context, roomID, userID, deviceID string) (types.StreamPosition, error) { - return d.Database.AddPeek(d.dbctx, roomID, userID, deviceID) -} - - -func (d *SyncServerDatasource) PartitionOffsets( - ctx context.Context, topic string, -) ([]sqlutil.PartitionOffset, error) { - return d.PartitionOffsetStatements.PartitionOffsets(d.dbctx, topic) -} - -// SetPartitionOffset implements PartitionStorer -func (d *SyncServerDatasource) SetPartitionOffset( - ctx context.Context, topic string, partition int32, offset int64, -) error { - return d.PartitionOffsetStatements.SetPartitionOffset(d.dbctx, topic, partition, offset) -}