From 275b21b89d040cc13187bd874902945f1c61e5eb Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 20 Jan 2021 17:59:10 +0000 Subject: [PATCH] Allow selecon multiple membership types, get latest --- syncapi/storage/postgres/memberships_table.go | 8 ++++-- syncapi/storage/sqlite3/memberships_table.go | 28 +++++++++++++------ syncapi/storage/tables/interface.go | 2 +- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index f679d5369..7c2c407ac 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -59,7 +59,9 @@ const upsertMembershipSQL = "" + const selectMembershipSQL = "" + "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + - " WHERE room_id = $1 AND user_id = $2 AND membership = $3" + " WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" + + " ORDER BY stream_pos DESC" + + " LIMIT 1" type membershipsStatements struct { upsertMembershipStmt *sql.Stmt @@ -102,9 +104,9 @@ func (s *membershipsStatements) UpsertMembership( } func (s *membershipsStatements) SelectMembership( - ctx context.Context, txn *sql.Tx, roomID, userID, membership string, + ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string, ) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt) - err = stmt.QueryRowContext(ctx, roomID, userID, membership).Scan(&eventID, &streamPos, &topologyPos) + err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos) return } diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 5ad98ff5e..725d8d9ff 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" @@ -59,15 +60,19 @@ const upsertMembershipSQL = "" + const selectMembershipSQL = "" + "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + - " WHERE room_id = $1 AND user_id = $2 AND membership = $3" + " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" + + " ORDER BY stream_pos DESC" + + " LIMIT 1" type membershipsStatements struct { + db *sql.DB upsertMembershipStmt *sql.Stmt - selectMembershipStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { - s := &membershipsStatements{} + s := &membershipsStatements{ + db: db, + } _, err := db.Exec(membershipsSchema) if err != nil { return nil, err @@ -75,9 +80,6 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { return nil, err } - if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil { - return nil, err - } return s, nil } @@ -102,9 +104,17 @@ func (s *membershipsStatements) UpsertMembership( } func (s *membershipsStatements) SelectMembership( - ctx context.Context, txn *sql.Tx, roomID, userID, membership string, + ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string, ) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt) - err = stmt.QueryRowContext(ctx, roomID, userID, membership).Scan(&eventID, &streamPos, &topologyPos) + params := []interface{}{roomID, userID} + for _, membership := range memberships { + params = append(params, membership) + } + orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) + stmt, err := s.db.Prepare(orig) + if err != nil { + return "", 0, 0, err + } + err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) return } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1f4fc674c..997486dd4 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -165,5 +165,5 @@ type Receipts interface { type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error - SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, membership string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) + SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) }