Merge branch 'matthew/peeking' into matthew/fix-writer-txn-rollbacks

This commit is contained in:
Kegan Dougal 2020-09-09 17:20:31 +01:00
commit abd9559d69
13 changed files with 165 additions and 102 deletions

View file

@ -22,7 +22,10 @@ import (
"io" "io"
"os" "os"
"regexp" "regexp"
"runtime"
"strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
@ -31,6 +34,7 @@ import (
) )
var tracingEnabled = os.Getenv("DENDRITE_TRACE_SQL") == "1" var tracingEnabled = os.Getenv("DENDRITE_TRACE_SQL") == "1"
var goidToWriter sync.Map
type traceInterceptor struct { type traceInterceptor struct {
sqlmw.NullInterceptor sqlmw.NullInterceptor
@ -40,6 +44,8 @@ func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.St
startedAt := time.Now() startedAt := time.Now()
rows, err := stmt.QueryContext(ctx, args) rows, err := stmt.QueryContext(ctx, args)
trackGoID(query)
logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args)
return rows, err return rows, err
@ -49,6 +55,8 @@ func (in *traceInterceptor) StmtExecContext(ctx context.Context, stmt driver.Stm
startedAt := time.Now() startedAt := time.Now()
result, err := stmt.ExecContext(ctx, args) result, err := stmt.ExecContext(ctx, args)
trackGoID(query)
logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args)
return result, err return result, err
@ -75,6 +83,19 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [
return err return err
} }
func trackGoID(query string) {
thisGoID := goid()
if _, ok := goidToWriter.Load(thisGoID); ok {
return // we're on a writer goroutine
}
q := strings.TrimSpace(query)
if strings.HasPrefix(q, "SELECT") {
return // SELECTs can go on other goroutines
}
logrus.Warnf("unsafe goid: SQL executed not on an ExclusiveWriter: %s", q)
}
// Open opens a database specified by its database driver name and a driver-specific data source name, // Open opens a database specified by its database driver name and a driver-specific data source name,
// usually consisting of at least a database name and connection information. Includes tracing driver // usually consisting of at least a database name and connection information. Includes tracing driver
// if DENDRITE_TRACE_SQL=1 // if DENDRITE_TRACE_SQL=1
@ -119,3 +140,14 @@ func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) {
func init() { func init() {
registerDrivers() registerDrivers()
} }
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}

View file

@ -60,6 +60,12 @@ func (w *ExclusiveWriter) run() {
if !w.running.CAS(false, true) { if !w.running.CAS(false, true) {
return return
} }
if tracingEnabled {
gid := goid()
goidToWriter.Store(gid, w)
defer goidToWriter.Delete(gid)
}
defer w.running.Store(false) defer w.running.Store(false)
for task := range w.todo { for task := range w.todo {
if task.db != nil && task.txn != nil { if task.db != nil && task.txn != nil {

View file

@ -53,7 +53,7 @@ const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"

View file

@ -50,7 +50,7 @@ const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"

View file

@ -79,89 +79,103 @@ func (u *MembershipUpdater) IsLeave() bool {
// SetToInvite implements types.MembershipUpdater // SetToInvite implements types.MembershipUpdater
func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) var inserted bool
if err != nil { err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
} if err != nil {
inserted, err := u.d.InvitesTable.InsertInviteEvent( return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
}
if u.membership != tables.MembershipStateInvite {
if err = u.d.MembershipTable.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
); err != nil {
return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
} }
} inserted, err = u.d.InvitesTable.InsertInviteEvent(
return inserted, nil u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
}
if u.membership != tables.MembershipStateInvite {
if err = u.d.MembershipTable.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inserted, err
} }
// SetToJoin implements types.MembershipUpdater // SetToJoin implements types.MembershipUpdater
func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
var inviteEventIDs []string var inviteEventIDs []string
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err != nil { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
// If this is a join event update, there is no invite to update
if !isUpdate {
inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
} }
}
// Look up the NID of the new join event // If this is a join event update, there is no invite to update
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if !isUpdate {
if err != nil { inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
return nil, fmt.Errorf("u.d.EventNIDs: %w", err) u.ctx, u.txn, u.roomNID, u.targetUserNID,
} )
if err != nil {
if u.membership != tables.MembershipStateJoin || isUpdate { return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
if err = u.d.MembershipTable.UpdateMembership( }
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
tables.MembershipStateJoin, nIDs[eventID],
); err != nil {
return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
} }
}
return inviteEventIDs, nil // Look up the NID of the new join event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateJoin || isUpdate {
if err = u.d.MembershipTable.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
tables.MembershipStateJoin, nIDs[eventID],
); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inviteEventIDs, err
} }
// SetToLeave implements types.MembershipUpdater // SetToLeave implements types.MembershipUpdater
func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) var inviteEventIDs []string
if err != nil {
return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
}
// Look up the NID of the new leave event err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil { if err != nil {
return nil, fmt.Errorf("u.d.EventNIDs: %w", err) return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
if u.membership != tables.MembershipStateLeaveOrBan {
if err = u.d.MembershipTable.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
tables.MembershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
} }
} inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
return inviteEventIDs, nil u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
}
// Look up the NID of the new leave event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateLeaveOrBan {
if err = u.d.MembershipTable.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
tables.MembershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inviteEventIDs, err
} }

View file

@ -216,10 +216,12 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom
} }
// cancel any peeks for it // cancel any peeks for it
sp, err = s.db.DeletePeeks(ctx, ev.RoomID(), *ev.StateKey()) peekSP, peekErr := s.db.DeletePeeks(ctx, ev.RoomID(), *ev.StateKey())
// XXX: should we do anything with this new streampos? if peekErr != nil {
if err != nil { return sp, fmt.Errorf("s.db.DeletePeeks: %w", peekErr)
return sp, fmt.Errorf("s.db.DeletePeeks: %w", err) }
if peekSP > 0 {
sp = peekSP
} }
} }
return sp, nil return sp, nil

View file

@ -78,7 +78,7 @@ func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) {
return nil, err return nil, err
} }
s := &peekStatements{ s := &peekStatements{
db: db, db: db,
} }
if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil {
return nil, err return nil, err

View file

@ -86,7 +86,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,
Invites: invites, Invites: invites,
Peeks: peeks, Peeks: peeks,
AccountData: accountData, AccountData: accountData,
OutputEvents: events, OutputEvents: events,
Topology: topology, Topology: topology,

View file

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite // Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
@ -169,8 +169,8 @@ func (d *Database) RetireInviteEvent(
func (d *Database) AddPeek( func (d *Database) AddPeek(
ctx context.Context, roomID, userID, deviceID string, ctx context.Context, roomID, userID, deviceID string,
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
_ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Peeks.InsertPeek(ctx, nil, roomID, userID, deviceID) sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID)
return err return err
}) })
return return
@ -182,11 +182,12 @@ func (d *Database) AddPeek(
func (d *Database) DeletePeeks( func (d *Database) DeletePeeks(
ctx context.Context, roomID, userID string, ctx context.Context, roomID, userID string,
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
_ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Peeks.DeletePeeks(ctx, nil, roomID, userID) sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID)
return err return err
}) })
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
sp = 0
err = nil err = nil
} }
return return
@ -230,7 +231,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
"transaction_id", in[i].TransactionID.TransactionID, "transaction_id", in[i].TransactionID.TransactionID,
) )
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ log.WithFields(log.Fields{
"event_id": out[i].EventID(), "event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event") }).WithError(err).Warnf("Failed to add transaction ID to event")
} }
@ -624,7 +625,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
return err return err
} }
if len(redactedEvents) == 0 { if len(redactedEvents) == 0 {
logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") log.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction")
return nil return nil
} }
eventToRedact := redactedEvents[0].Unwrap() eventToRedact := redactedEvents[0].Unwrap()

View file

@ -53,8 +53,9 @@ const deletePeeksSQL = "" +
// we care about all the peeks which were created in this range, deleted in this range, // we care about all the peeks which were created in this range, deleted in this range,
// or were created before this range but haven't been deleted yet. // or were created before this range but haven't been deleted yet.
// BEWARE: sqlite chokes on out of order substitution strings.
const selectPeeksInRangeSQL = "" + const selectPeeksInRangeSQL = "" +
"SELECT room_id, deleted, (id > $3 AND id <= $4) AS changed FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted) OR (id > $3 AND id <= $4))" "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))"
const selectPeekingDevicesSQL = "" + const selectPeekingDevicesSQL = "" +
"SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false" "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false"
@ -128,13 +129,23 @@ func (s *peekStatements) DeletePeek(
func (s *peekStatements) DeletePeeks( func (s *peekStatements) DeletePeeks(
ctx context.Context, txn *sql.Tx, roomID, userID string, ctx context.Context, txn *sql.Tx, roomID, userID string,
) (streamPos types.StreamPosition, err error) { ) (types.StreamPosition, error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil { if err != nil {
return return 0, err
} }
_, err = sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID) result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID)
return if err != nil {
return 0, err
}
numAffected, err := result.RowsAffected()
if err != nil {
return 0, err
}
if numAffected == 0 {
return 0, sql.ErrNoRows
}
return streamPos, nil
} }
func (s *peekStatements) SelectPeeksInRange( func (s *peekStatements) SelectPeeksInRange(
@ -148,11 +159,11 @@ func (s *peekStatements) SelectPeeksInRange(
for rows.Next() { for rows.Next() {
peek := types.Peek{} peek := types.Peek{}
var changed bool var id types.StreamPosition
if err = rows.Scan(&peek.RoomID, &peek.Deleted, &changed); err != nil { if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil {
return return
} }
peek.New = changed && !peek.Deleted peek.New = (id > r.Low() && id <= r.High()) && !peek.Deleted
peeks = append(peeks, peek) peeks = append(peeks, peek)
} }

View file

@ -339,6 +339,7 @@ func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) {
} }
// Not thread-safe: must be called on the OnNewEvent goroutine only // Not thread-safe: must be called on the OnNewEvent goroutine only
// nolint:unused
func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) { func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) {
if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { if _, ok := n.roomIDToPeekingDevices[roomID]; !ok {
n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet)
@ -409,6 +410,7 @@ func (s peekingDeviceSet) add(d types.PeekingDevice) {
s[d] = true s[d] = true
} }
// nolint:unused
func (s peekingDeviceSet) remove(d types.PeekingDevice) { func (s peekingDeviceSet) remove(d types.PeekingDevice) {
delete(s, d) delete(s, d)
} }

View file

@ -1,11 +0,0 @@
package sync
import "github.com/matrix-org/dendrite/syncapi/types"
type SyncProvider interface {
WaitFor()
}
type SyncStream interface {
GetLatestPosition() types.StreamPosition
}

View file

@ -465,3 +465,9 @@ After changing password, can log in with new password
After changing password, existing session still works After changing password, existing session still works
After changing password, different sessions can optionally be kept After changing password, different sessions can optionally be kept
After changing password, a different session no longer works by default After changing password, a different session no longer works by default
Local users can peek into world_readable rooms by room ID
We can't peek into rooms with shared history_visibility
We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility
Local users can peek by room alias
Peeked rooms only turn up in the sync for the device who peeked them