Allow selecon multiple membership types, get latest

This commit is contained in:
Neil Alexander 2021-01-20 17:59:10 +00:00
parent 5113aa8433
commit 275b21b89d
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 25 additions and 13 deletions

View file

@ -59,7 +59,9 @@ const upsertMembershipSQL = "" +
const selectMembershipSQL = "" + const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership = $3" " WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
type membershipsStatements struct { type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
@ -102,9 +104,9 @@ func (s *membershipsStatements) UpsertMembership(
} }
func (s *membershipsStatements) SelectMembership( func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, membership string, ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { ) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt) stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, membership).Scan(&eventID, &streamPos, &topologyPos) err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
return return
} }

View file

@ -19,6 +19,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
@ -59,15 +60,19 @@ const upsertMembershipSQL = "" +
const selectMembershipSQL = "" + const selectMembershipSQL = "" +
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
" WHERE room_id = $1 AND user_id = $2 AND membership = $3" " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
" ORDER BY stream_pos DESC" +
" LIMIT 1"
type membershipsStatements struct { type membershipsStatements struct {
db *sql.DB
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipStmt *sql.Stmt
} }
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
s := &membershipsStatements{} s := &membershipsStatements{
db: db,
}
_, err := db.Exec(membershipsSchema) _, err := db.Exec(membershipsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,9 +80,6 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -102,9 +104,17 @@ func (s *membershipsStatements) UpsertMembership(
} }
func (s *membershipsStatements) SelectMembership( func (s *membershipsStatements) SelectMembership(
ctx context.Context, txn *sql.Tx, roomID, userID, membership string, ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { ) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt) params := []interface{}{roomID, userID}
err = stmt.QueryRowContext(ctx, roomID, userID, membership).Scan(&eventID, &streamPos, &topologyPos) for _, membership := range memberships {
params = append(params, membership)
}
orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
stmt, err := s.db.Prepare(orig)
if err != nil {
return "", 0, 0, err
}
err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos)
return return
} }

View file

@ -165,5 +165,5 @@ type Receipts interface {
type Memberships interface { type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, membership string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
} }