mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-20 12:33:09 -06:00
implement storage_sync query for membership
This commit is contained in:
parent
ac074e2f79
commit
2f0415e0df
|
|
@ -109,6 +109,7 @@ type DatabaseTransaction interface {
|
||||||
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
||||||
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
||||||
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
|
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
|
||||||
|
IsMemberOfRoom(ctx context.Context, roomID string, userID string) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
|
|
|
||||||
|
|
@ -100,6 +100,9 @@ const selectJoinedUsersSQL = "" +
|
||||||
const selectJoinedUsersInRoomSQL = "" +
|
const selectJoinedUsersInRoomSQL = "" +
|
||||||
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)"
|
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)"
|
||||||
|
|
||||||
|
const selectRoomMembershipOfUserSQL = "" +
|
||||||
|
"SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2"
|
||||||
|
|
||||||
const selectStateEventSQL = "" +
|
const selectStateEventSQL = "" +
|
||||||
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
||||||
|
|
||||||
|
|
@ -123,6 +126,7 @@ type currentRoomStateStatements struct {
|
||||||
selectEventsWithEventIDsStmt *sql.Stmt
|
selectEventsWithEventIDsStmt *sql.Stmt
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
selectSharedUsersStmt *sql.Stmt
|
selectSharedUsersStmt *sql.Stmt
|
||||||
|
selectRoomMembershipOfUserStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||||
|
|
@ -166,6 +170,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
|
||||||
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
|
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -449,3 +456,27 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *currentRoomStateStatements) SelectRoomMembershipOfUser(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
roomID string,
|
||||||
|
userID string,
|
||||||
|
) (string, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed")
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var membership string
|
||||||
|
if err := rows.Scan(&membership); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return membership, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -688,3 +688,16 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID,
|
||||||
|
|
||||||
return events, prevBatch, nextBatch, nil
|
return events, prevBatch, nextBatch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) IsMemberOfRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
roomID string,
|
||||||
|
userID string,
|
||||||
|
) (bool, error) {
|
||||||
|
membership, err := d.CurrentRoomState.SelectRoomMembershipOfUser(ctx, d.txn, roomID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return membership == gomatrixserverlib.Join, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,9 @@ const selectJoinedUsersSQL = "" +
|
||||||
const selectJoinedUsersInRoomSQL = "" +
|
const selectJoinedUsersInRoomSQL = "" +
|
||||||
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)"
|
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)"
|
||||||
|
|
||||||
|
const selectRoomMembershipOfUserSQL = "" +
|
||||||
|
"SELECT membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND state_key = $2"
|
||||||
|
|
||||||
const selectStateEventSQL = "" +
|
const selectStateEventSQL = "" +
|
||||||
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
|
||||||
|
|
||||||
|
|
@ -107,6 +110,7 @@ type currentRoomStateStatements struct {
|
||||||
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
||||||
|
selectRoomMembershipOfUserStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||||
|
|
@ -147,6 +151,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
|
||||||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectRoomMembershipOfUserStmt, err = db.Prepare(selectRoomMembershipOfUserSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
||||||
// return nil, err
|
// return nil, err
|
||||||
//}
|
//}
|
||||||
|
|
@ -484,3 +491,26 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
|
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *currentRoomStateStatements) SelectRoomMembershipOfUser(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string, userID string,
|
||||||
|
) (string, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectRoomMembershipOfUserStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomMembershipOfUserStmt: rows.close() failed")
|
||||||
|
|
||||||
|
membership := ""
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&membership); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// Found the membership info
|
||||||
|
if membership != "" {
|
||||||
|
return membership, rows.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return membership, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -115,6 +115,8 @@ type CurrentRoomState interface {
|
||||||
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
|
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
|
||||||
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
||||||
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
||||||
|
// SelectRoomMembershipOfUser returns the membership of the user in the room.
|
||||||
|
SelectRoomMembershipOfUser(ctx context.Context, txn *sql.Tx, roomID string, userID string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackwardsExtremities keeps track of backwards extremities for a room.
|
// BackwardsExtremities keeps track of backwards extremities for a room.
|
||||||
|
|
|
||||||
|
|
@ -81,12 +81,12 @@ func (p *InviteStreamProvider) IncrementalSync(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
joinedUsers, err := snapshot.AllJoinedUsersInRoom(ctx, []string{roomID})
|
isMember, err := snapshot.IsMemberOfRoom(ctx, roomID, req.Device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !contains(joinedUsers[roomID], req.Device.UserID) {
|
if !isMember {
|
||||||
lr := types.NewLeaveResponse()
|
lr := types.NewLeaveResponse()
|
||||||
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
|
h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...))
|
||||||
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
|
lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue