Fix room summary returning wrong heroes (#2930)
This should fix #2910. Probably makes Sytest/Complement a bit upset, since this not using `sort.Strings` anymore.
This commit is contained in:
parent
25dfbc6ec3
commit
0491a8e343
|
@ -45,7 +45,7 @@ type DatabaseTransaction interface {
|
||||||
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
||||||
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
|
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
|
||||||
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
|
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
|
||||||
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
|
GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
|
||||||
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||||
GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error)
|
GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error)
|
||||||
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
|
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
@ -110,6 +111,15 @@ const selectSharedUsersSQL = "" +
|
||||||
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
||||||
") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');"
|
") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');"
|
||||||
|
|
||||||
|
const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2`
|
||||||
|
|
||||||
|
const selectRoomHeroes = `
|
||||||
|
SELECT state_key FROM syncapi_current_room_state
|
||||||
|
WHERE type = 'm.room.member' AND room_id = $1 AND membership = ANY($2) AND state_key != $3
|
||||||
|
ORDER BY added_at, state_key
|
||||||
|
LIMIT 5
|
||||||
|
`
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
|
@ -122,6 +132,8 @@ type currentRoomStateStatements struct {
|
||||||
selectEventsWithEventIDsStmt *sql.Stmt
|
selectEventsWithEventIDsStmt *sql.Stmt
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
selectSharedUsersStmt *sql.Stmt
|
selectSharedUsersStmt *sql.Stmt
|
||||||
|
selectMembershipCountStmt *sql.Stmt
|
||||||
|
selectRoomHeroesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||||
|
@ -141,40 +153,21 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return nil, err
|
{&s.upsertRoomStateStmt, upsertRoomStateSQL},
|
||||||
}
|
{&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL},
|
||||||
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
|
{&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL},
|
||||||
return nil, err
|
{&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL},
|
||||||
}
|
{&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL},
|
||||||
if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil {
|
{&s.selectCurrentStateStmt, selectCurrentStateSQL},
|
||||||
return nil, err
|
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
|
||||||
}
|
{&s.selectJoinedUsersInRoomStmt, selectJoinedUsersInRoomSQL},
|
||||||
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
|
{&s.selectEventsWithEventIDsStmt, selectEventsWithEventIDsSQL},
|
||||||
return nil, err
|
{&s.selectStateEventStmt, selectStateEventSQL},
|
||||||
}
|
{&s.selectSharedUsersStmt, selectSharedUsersSQL},
|
||||||
if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
|
{&s.selectMembershipCountStmt, selectMembershipCount},
|
||||||
return nil, err
|
{&s.selectRoomHeroesStmt, selectRoomHeroes},
|
||||||
}
|
}.Prepare(db)
|
||||||
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
|
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
|
||||||
|
@ -447,3 +440,34 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectRoomHeroesStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(memberships), excludeUserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroesStmt: rows.close() failed")
|
||||||
|
|
||||||
|
var stateKey string
|
||||||
|
result := make([]string, 0, 5)
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&stateKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, stateKey)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
|
@ -19,10 +19,8 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" +
|
||||||
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
||||||
") t WHERE t.membership = $3"
|
") t WHERE t.membership = $3"
|
||||||
|
|
||||||
const selectHeroesSQL = "" +
|
|
||||||
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
|
|
||||||
|
|
||||||
const selectMembershipBeforeSQL = "" +
|
const selectMembershipBeforeSQL = "" +
|
||||||
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
|
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
|
||||||
|
|
||||||
|
@ -81,7 +76,6 @@ WHERE ($3::text IS NULL OR t.membership = $3)
|
||||||
type membershipsStatements struct {
|
type membershipsStatements struct {
|
||||||
upsertMembershipStmt *sql.Stmt
|
upsertMembershipStmt *sql.Stmt
|
||||||
selectMembershipCountStmt *sql.Stmt
|
selectMembershipCountStmt *sql.Stmt
|
||||||
selectHeroesStmt *sql.Stmt
|
|
||||||
selectMembershipForUserStmt *sql.Stmt
|
selectMembershipForUserStmt *sql.Stmt
|
||||||
selectMembersStmt *sql.Stmt
|
selectMembersStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
@ -95,7 +89,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.upsertMembershipStmt, upsertMembershipSQL},
|
{&s.upsertMembershipStmt, upsertMembershipSQL},
|
||||||
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
|
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
|
||||||
{&s.selectHeroesStmt, selectHeroesSQL},
|
|
||||||
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
|
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
|
||||||
{&s.selectMembersStmt, selectMembersSQL},
|
{&s.selectMembersStmt, selectMembersSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
|
@ -129,26 +122,6 @@ func (s *membershipsStatements) SelectMembershipCount(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipsStatements) SelectHeroes(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
|
||||||
) (heroes []string, err error) {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt)
|
|
||||||
var rows *sql.Rows
|
|
||||||
rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
|
|
||||||
var hero string
|
|
||||||
for rows.Next() {
|
|
||||||
if err = rows.Scan(&hero); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
heroes = append(heroes, hero)
|
|
||||||
}
|
|
||||||
return heroes, rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
|
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
|
||||||
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
|
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
|
||||||
// string as the membership.
|
// string as the membership.
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
@ -92,8 +93,61 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe
|
||||||
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
|
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
|
func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) {
|
||||||
return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships)
|
summary := &types.Summary{Heroes: []string{}}
|
||||||
|
|
||||||
|
joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join)
|
||||||
|
if err != nil {
|
||||||
|
return summary, err
|
||||||
|
}
|
||||||
|
inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite)
|
||||||
|
if err != nil {
|
||||||
|
return summary, err
|
||||||
|
}
|
||||||
|
summary.InvitedMemberCount = &inviteCount
|
||||||
|
summary.JoinedMemberCount = &joinCount
|
||||||
|
|
||||||
|
// Get the room name and canonical alias, if any
|
||||||
|
filter := gomatrixserverlib.DefaultStateFilter()
|
||||||
|
filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias}
|
||||||
|
filterRooms := []string{roomID}
|
||||||
|
|
||||||
|
filter.Types = &filterTypes
|
||||||
|
filter.Rooms = &filterRooms
|
||||||
|
evs, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, &filter, nil)
|
||||||
|
if err != nil {
|
||||||
|
return summary, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range evs {
|
||||||
|
switch ev.Type() {
|
||||||
|
case gomatrixserverlib.MRoomName:
|
||||||
|
if gjson.GetBytes(ev.Content(), "name").Str != "" {
|
||||||
|
return summary, nil
|
||||||
|
}
|
||||||
|
case gomatrixserverlib.MRoomCanonicalAlias:
|
||||||
|
if gjson.GetBytes(ev.Content(), "alias").Str != "" {
|
||||||
|
return summary, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there's no room name or canonical alias, get the room heroes, excluding the user
|
||||||
|
heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite})
|
||||||
|
if err != nil {
|
||||||
|
return summary, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// "When no joined or invited members are available, this should consist of the banned and left users"
|
||||||
|
if len(heroes) == 0 {
|
||||||
|
heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban})
|
||||||
|
if err != nil {
|
||||||
|
return summary, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
summary.Heroes = heroes
|
||||||
|
|
||||||
|
return summary, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
|
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -95,6 +96,15 @@ const selectSharedUsersSQL = "" +
|
||||||
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
||||||
") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');"
|
") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');"
|
||||||
|
|
||||||
|
const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2`
|
||||||
|
|
||||||
|
const selectRoomHeroes = `
|
||||||
|
SELECT state_key FROM syncapi_current_room_state
|
||||||
|
WHERE type = 'm.room.member' AND room_id = $1 AND state_key != $2 AND membership IN ($3)
|
||||||
|
ORDER BY added_at, state_key
|
||||||
|
LIMIT 5
|
||||||
|
`
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *StreamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
|
@ -107,6 +117,8 @@ type currentRoomStateStatements struct {
|
||||||
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
||||||
|
selectMembershipCountStmt *sql.Stmt
|
||||||
|
//selectRoomHeroes *sql.Stmt - prepared at runtime due to variadic
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||||
|
@ -129,31 +141,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return nil, err
|
{&s.upsertRoomStateStmt, upsertRoomStateSQL},
|
||||||
}
|
{&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL},
|
||||||
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
|
{&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL},
|
||||||
return nil, err
|
{&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL},
|
||||||
}
|
{&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL},
|
||||||
if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil {
|
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
|
||||||
return nil, err
|
{&s.selectStateEventStmt, selectStateEventSQL},
|
||||||
}
|
{&s.selectMembershipCountStmt, selectMembershipCount},
|
||||||
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
|
}.Prepare(db)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
|
||||||
// return nil, err
|
|
||||||
//}
|
|
||||||
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
|
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
|
||||||
|
@ -485,3 +482,53 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
|
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) {
|
||||||
|
params := make([]interface{}, len(memberships)+2)
|
||||||
|
params[0] = roomID
|
||||||
|
params[1] = excludeUserID
|
||||||
|
for k, v := range memberships {
|
||||||
|
params[k+2] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
query := strings.Replace(selectRoomHeroes, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
stmt, err = txn.Prepare(query)
|
||||||
|
} else {
|
||||||
|
stmt, err = s.db.Prepare(query)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return []string{}, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, stmt, "selectRoomHeroes: stmt.close() failed")
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroes: rows.close() failed")
|
||||||
|
|
||||||
|
var stateKey string
|
||||||
|
result := make([]string, 0, 5)
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&stateKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, stateKey)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
|
@ -18,11 +18,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" +
|
||||||
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
|
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
|
||||||
") t WHERE t.membership = $3"
|
") t WHERE t.membership = $3"
|
||||||
|
|
||||||
const selectHeroesSQL = "" +
|
|
||||||
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
|
|
||||||
|
|
||||||
const selectMembershipBeforeSQL = "" +
|
const selectMembershipBeforeSQL = "" +
|
||||||
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
|
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
|
||||||
|
|
||||||
|
@ -99,7 +94,6 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||||
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
|
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
|
||||||
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
|
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
|
||||||
{&s.selectMembersStmt, selectMembersSQL},
|
{&s.selectMembersStmt, selectMembersSQL},
|
||||||
// {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
|
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,39 +125,6 @@ func (s *membershipsStatements) SelectMembershipCount(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipsStatements) SelectHeroes(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
|
||||||
) (heroes []string, err error) {
|
|
||||||
stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
|
|
||||||
stmt, err := s.db.PrepareContext(ctx, stmtSQL)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed")
|
|
||||||
params := []interface{}{
|
|
||||||
roomID, userID,
|
|
||||||
}
|
|
||||||
for _, membership := range memberships {
|
|
||||||
params = append(params, membership)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
|
||||||
var rows *sql.Rows
|
|
||||||
rows, err = stmt.QueryContext(ctx, params...)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
|
|
||||||
var hero string
|
|
||||||
for rows.Next() {
|
|
||||||
if err = rows.Scan(&hero); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
heroes = append(heroes, hero)
|
|
||||||
}
|
|
||||||
return heroes, rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
|
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
|
||||||
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
|
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
|
||||||
// string as the membership.
|
// string as the membership.
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/test/testrig"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
@ -664,3 +665,181 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
|
||||||
return &tok
|
return &tok
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
func pointer[t any](s t) *t {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomSummary(t *testing.T) {
|
||||||
|
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
bob := test.NewUser(t)
|
||||||
|
charlie := test.NewUser(t)
|
||||||
|
|
||||||
|
// Create some dummy users
|
||||||
|
moreUsers := []*test.User{}
|
||||||
|
moreUserIDs := []string{}
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
u := test.NewUser(t)
|
||||||
|
moreUsers = append(moreUsers, u)
|
||||||
|
moreUserIDs = append(moreUserIDs, u.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
wantSummary *types.Summary
|
||||||
|
additionalEvents func(t *testing.T, room *test.Room)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "after initial creation",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invited user",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "invite",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invited user, but declined",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "invite",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "leave",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "joined user after invitation",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "invite",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple joined user",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(charlie.ID))
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple joined/invited user",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "invite",
|
||||||
|
}, test.WithStateKey(charlie.ID))
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple joined/invited/left user",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "invite",
|
||||||
|
}, test.WithStateKey(charlie.ID))
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "leave",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "leaving user after joining",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "leave",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "many users", // heroes ordered by stream id
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
for _, x := range moreUsers {
|
||||||
|
room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(x.ID))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "canonical alias set",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{
|
||||||
|
"alias": "myalias",
|
||||||
|
}, test.WithStateKey(""))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "room name set",
|
||||||
|
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}},
|
||||||
|
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{
|
||||||
|
"name": "my room name",
|
||||||
|
}, test.WithStateKey(""))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close, closeBase := MustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
defer closeBase()
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
r := test.NewRoom(t, alice)
|
||||||
|
|
||||||
|
if tc.additionalEvents != nil {
|
||||||
|
tc.additionalEvents(t, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// write the room before creating a transaction
|
||||||
|
MustWriteEvents(t, db, r.Events())
|
||||||
|
|
||||||
|
transaction, err := db.NewDatabaseTransaction(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer transaction.Rollback()
|
||||||
|
|
||||||
|
summary, err := transaction.GetRoomSummary(ctx, r.ID, alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.wantSummary, summary)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -115,6 +115,9 @@ type CurrentRoomState interface {
|
||||||
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
|
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
|
||||||
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
||||||
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
||||||
|
|
||||||
|
SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error)
|
||||||
|
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackwardsExtremities keeps track of backwards extremities for a room.
|
// BackwardsExtremities keeps track of backwards extremities for a room.
|
||||||
|
@ -185,7 +188,6 @@ type Receipts interface {
|
||||||
type Memberships interface {
|
type Memberships interface {
|
||||||
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
||||||
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
|
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
|
||||||
SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
|
|
||||||
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
|
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
|
||||||
SelectMemberships(
|
SelectMemberships(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
|
|
@ -3,8 +3,6 @@ package tables_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"reflect"
|
|
||||||
"sort"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -88,43 +86,9 @@ func TestMembershipsTable(t *testing.T) {
|
||||||
|
|
||||||
testUpsert(t, ctx, table, userEvents[0], alice, room)
|
testUpsert(t, ctx, table, userEvents[0], alice, room)
|
||||||
testMembershipCount(t, ctx, table, room)
|
testMembershipCount(t, ctx, table, room)
|
||||||
testHeroes(t, ctx, table, alice, room, users)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) {
|
|
||||||
|
|
||||||
// Re-slice and sort the expected users
|
|
||||||
users = users[1:]
|
|
||||||
sort.Strings(users)
|
|
||||||
type testCase struct {
|
|
||||||
name string
|
|
||||||
memberships []string
|
|
||||||
wantHeroes []string
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []testCase{
|
|
||||||
{name: "no memberships queried", memberships: []string{}},
|
|
||||||
{name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to select heroes: %s", err)
|
|
||||||
}
|
|
||||||
if gotLen := len(got); gotLen != len(tc.wantHeroes) {
|
|
||||||
t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(got, tc.wantHeroes) {
|
|
||||||
t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) {
|
func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) {
|
||||||
t.Run("membership counts are correct", func(t *testing.T) {
|
t.Run("membership counts are correct", func(t *testing.T) {
|
||||||
// After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users)
|
// After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users)
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
@ -14,11 +13,9 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// The max number of per-room goroutines to have running.
|
// The max number of per-room goroutines to have running.
|
||||||
|
@ -339,7 +336,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
case gomatrixserverlib.Join:
|
case gomatrixserverlib.Join:
|
||||||
jr := types.NewJoinResponse()
|
jr := types.NewJoinResponse()
|
||||||
if hasMembershipChange {
|
if hasMembershipChange {
|
||||||
p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition)
|
jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Warn("failed to get room summary")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
jr.Timeline.PrevBatch = &prevBatch
|
jr.Timeline.PrevBatch = &prevBatch
|
||||||
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
|
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
|
||||||
|
@ -411,45 +411,6 @@ func applyHistoryVisibilityFilter(
|
||||||
return events, nil
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
|
|
||||||
// Work out how many members are in the room.
|
|
||||||
joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
|
|
||||||
invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
|
|
||||||
|
|
||||||
jr.Summary.JoinedMemberCount = &joinedCount
|
|
||||||
jr.Summary.InvitedMemberCount = &invitedCount
|
|
||||||
|
|
||||||
fetchStates := []gomatrixserverlib.StateKeyTuple{
|
|
||||||
{EventType: gomatrixserverlib.MRoomName},
|
|
||||||
{EventType: gomatrixserverlib.MRoomCanonicalAlias},
|
|
||||||
}
|
|
||||||
// Check if the room has a name or a canonical alias
|
|
||||||
latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{}
|
|
||||||
err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Check if the room has a name or canonical alias, if so, return.
|
|
||||||
for _, ev := range latestState.StateEvents {
|
|
||||||
switch ev.Type() {
|
|
||||||
case gomatrixserverlib.MRoomName:
|
|
||||||
if gjson.GetBytes(ev.Content(), "name").Str != "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case gomatrixserverlib.MRoomCanonicalAlias:
|
|
||||||
if gjson.GetBytes(ev.Content(), "alias").Str != "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sort.Strings(heroes)
|
|
||||||
jr.Summary.Heroes = heroes
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
snapshot storage.DatabaseTransaction,
|
snapshot storage.DatabaseTransaction,
|
||||||
|
@ -493,7 +454,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From)
|
jr.Summary, err = snapshot.GetRoomSummary(ctx, roomID, device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Warn("failed to get room summary")
|
||||||
|
}
|
||||||
|
|
||||||
// We don't include a device here as we don't need to send down
|
// We don't include a device here as we don't need to send down
|
||||||
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
|
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
|
||||||
|
|
Loading…
Reference in a new issue