mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
implement storage_sync query for membership
This commit is contained in:
parent
ac074e2f79
commit
2f0415e0df
|
|
@ -109,6 +109,7 @@ type DatabaseTransaction interface {
|
|||
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
||||
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
||||
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
|
||||
IsMemberOfRoom(ctx context.Context, roomID string, userID string) (bool, error)
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
|
|
|
|||
|
|
@ -100,6 +100,9 @@ const selectJoinedUsersSQL = "" +
|
|||
const selectJoinedUsersInRoomSQL = "" +
|
||||
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)"
|
||||
|
||||
const selectRoomMembershipOfUserSQL = "" +
|
||||
"SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2"
|
||||
|
||||
const selectStateEventSQL = "" +
|
||||
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
||||
|
||||
|
|
@ -123,6 +126,7 @@ type currentRoomStateStatements struct {
|
|||
selectEventsWithEventIDsStmt *sql.Stmt
|
||||
selectStateEventStmt *sql.Stmt
|
||||
selectSharedUsersStmt *sql.Stmt
|
||||
selectRoomMembershipOfUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||
|
|
@ -166,6 +170,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
|
|||
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -449,3 +456,27 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
|||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectRoomMembershipOfUser(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomID string,
|
||||
userID string,
|
||||
) (string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomID, userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var membership string
|
||||
if err := rows.Scan(&membership); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return membership, nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -688,3 +688,16 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID,
|
|||
|
||||
return events, prevBatch, nextBatch, nil
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) IsMemberOfRoom(
|
||||
ctx context.Context,
|
||||
roomID string,
|
||||
userID string,
|
||||
) (bool, error) {
|
||||
membership, err := d.CurrentRoomState.SelectRoomMembershipOfUser(ctx, d.txn, roomID, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return membership == gomatrixserverlib.Join, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -84,6 +84,9 @@ const selectJoinedUsersSQL = "" +
|
|||
const selectJoinedUsersInRoomSQL = "" +
|
||||
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)"
|
||||
|
||||
const selectRoomMembershipOfUserSQL = "" +
|
||||
"SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2"
|
||||
|
||||
const selectStateEventSQL = "" +
|
||||
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
||||
|
||||
|
|
@ -107,6 +110,7 @@ type currentRoomStateStatements struct {
|
|||
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
selectStateEventStmt *sql.Stmt
|
||||
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
||||
selectRoomMembershipOfUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||
|
|
@ -147,6 +151,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
|
|||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
|
|
@ -484,3 +491,26 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
|||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectRoomMembershipOfUser(
|
||||
ctx context.Context, txn *sql.Tx, roomID string, userID string,
|
||||
) (string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomID, userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed")
|
||||
|
||||
membership := ""
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&membership); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Found the membership info
|
||||
if membership != "" {
|
||||
return membership, rows.Err()
|
||||
}
|
||||
}
|
||||
return membership, rows.Err()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -115,6 +115,8 @@ type CurrentRoomState interface {
|
|||
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
|
||||
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
||||
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
||||
// SelectRoomMembershipOfUser returns the membership of the user in the room.
|
||||
SelectRoomMembershipOfUser(ctx context.Context, txn *sql.Tx, roomID string, userID string) (string, error)
|
||||
}
|
||||
|
||||
// BackwardsExtremities keeps track of backwards extremities for a room.
|
||||
|
|
|
|||
|
|
@ -81,12 +81,12 @@ func (p *InviteStreamProvider) IncrementalSync(
|
|||
continue
|
||||
}
|
||||
|
||||
joinedUsers, err := snapshot.AllJoinedUsersInRoom(ctx, []string{roomID})
|
||||
isMember, err := snapshot.IsMemberOfRoom(ctx, roomID, req.Device.UserID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if !contains(joinedUsers[roomID], req.Device.UserID) {
|
||||
if !isMember {
|
||||
lr := types.NewLeaveResponse()
|
||||
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
|
||||
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
|
||||
|
|
|
|||
Loading…
Reference in a new issue