Fix room summary returning wrong heroes (#2930)

This should fix #2910.
Probably makes Sytest/Complement a bit upset, since this not using
`sort.Strings` anymore.
This commit is contained in:
Till 2023-01-12 10:06:03 +01:00 committed by GitHub
parent 25dfbc6ec3
commit 0491a8e343
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 378 additions and 210 deletions

View file

@ -45,7 +45,7 @@ type DatabaseTransaction interface {
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)

View file

@ -19,6 +19,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -110,6 +111,15 @@ const selectSharedUsersSQL = "" +
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');" ") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');"
const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2`
const selectRoomHeroes = `
SELECT state_key FROM syncapi_current_room_state
WHERE type = 'm.room.member' AND room_id = $1 AND membership = ANY($2) AND state_key != $3
ORDER BY added_at, state_key
LIMIT 5
`
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
@ -122,6 +132,8 @@ type currentRoomStateStatements struct {
selectEventsWithEventIDsStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectSharedUsersStmt *sql.Stmt selectSharedUsersStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
selectRoomHeroesStmt *sql.Stmt
} }
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -141,40 +153,21 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
return nil, err return nil, err
} }
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { return s, sqlutil.StatementList{
return nil, err {&s.upsertRoomStateStmt, upsertRoomStateSQL},
} {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL},
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL},
return nil, err {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL},
} {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL},
if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { {&s.selectCurrentStateStmt, selectCurrentStateSQL},
return nil, err {&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
} {&s.selectJoinedUsersInRoomStmt, selectJoinedUsersInRoomSQL},
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { {&s.selectEventsWithEventIDsStmt, selectEventsWithEventIDsSQL},
return nil, err {&s.selectStateEventStmt, selectStateEventSQL},
} {&s.selectSharedUsersStmt, selectSharedUsersSQL},
if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { {&s.selectMembershipCountStmt, selectMembershipCount},
return nil, err {&s.selectRoomHeroesStmt, selectRoomHeroes},
} }.Prepare(db)
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
return nil, err
}
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err
}
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
return nil, err
}
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
return nil, err
}
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err
}
if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
return nil, err
}
return s, nil
} }
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
@ -447,3 +440,34 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomHeroesStmt)
rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(memberships), excludeUserID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroesStmt: rows.close() failed")
var stateKey string
result := make([]string, 0, 5)
for rows.Next() {
if err = rows.Scan(&stateKey); err != nil {
return nil, err
}
result = append(result, stateKey)
}
return result, rows.Err()
}
func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return 0, err
}
return count, nil
}

View file

@ -19,10 +19,8 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" +
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
") t WHERE t.membership = $3" ") t WHERE t.membership = $3"
const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
const selectMembershipBeforeSQL = "" + const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
@ -81,7 +76,6 @@ WHERE ($3::text IS NULL OR t.membership = $3)
type membershipsStatements struct { type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt
selectHeroesStmt *sql.Stmt
selectMembershipForUserStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt
selectMembersStmt *sql.Stmt selectMembersStmt *sql.Stmt
} }
@ -95,7 +89,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectHeroesStmt, selectHeroesSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
{&s.selectMembersStmt, selectMembersSQL}, {&s.selectMembersStmt, selectMembersSQL},
}.Prepare(db) }.Prepare(db)
@ -129,26 +122,6 @@ func (s *membershipsStatements) SelectMembershipCount(
return return
} }
func (s *membershipsStatements) SelectHeroes(
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
) (heroes []string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt)
var rows *sql.Rows
rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships))
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
var hero string
for rows.Next() {
if err = rows.Scan(&hero); err != nil {
return
}
heroes = append(heroes, hero)
}
return heroes, rows.Err()
}
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership. // string as the membership.

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -92,8 +93,61 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
} }
func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) {
return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) summary := &types.Summary{Heroes: []string{}}
joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join)
if err != nil {
return summary, err
}
inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite)
if err != nil {
return summary, err
}
summary.InvitedMemberCount = &inviteCount
summary.JoinedMemberCount = &joinCount
// Get the room name and canonical alias, if any
filter := gomatrixserverlib.DefaultStateFilter()
filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias}
filterRooms := []string{roomID}
filter.Types = &filterTypes
filter.Rooms = &filterRooms
evs, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, &filter, nil)
if err != nil {
return summary, err
}
for _, ev := range evs {
switch ev.Type() {
case gomatrixserverlib.MRoomName:
if gjson.GetBytes(ev.Content(), "name").Str != "" {
return summary, nil
}
case gomatrixserverlib.MRoomCanonicalAlias:
if gjson.GetBytes(ev.Content(), "alias").Str != "" {
return summary, nil
}
}
}
// If there's no room name or canonical alias, get the room heroes, excluding the user
heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite})
if err != nil {
return summary, err
}
// "When no joined or invited members are available, this should consist of the banned and left users"
if len(heroes) == 0 {
heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban})
if err != nil {
return summary, err
}
}
summary.Heroes = heroes
return summary, nil
} }
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {

View file

@ -19,6 +19,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
@ -95,6 +96,15 @@ const selectSharedUsersSQL = "" +
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');" ") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');"
const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2`
const selectRoomHeroes = `
SELECT state_key FROM syncapi_current_room_state
WHERE type = 'm.room.member' AND room_id = $1 AND state_key != $2 AND membership IN ($3)
ORDER BY added_at, state_key
LIMIT 5
`
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *StreamIDStatements streamIDStatements *StreamIDStatements
@ -107,6 +117,8 @@ type currentRoomStateStatements struct {
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
selectMembershipCountStmt *sql.Stmt
//selectRoomHeroes *sql.Stmt - prepared at runtime due to variadic
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@ -129,31 +141,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
return nil, err return nil, err
} }
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { return s, sqlutil.StatementList{
return nil, err {&s.upsertRoomStateStmt, upsertRoomStateSQL},
} {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL},
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL},
return nil, err {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL},
} {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL},
if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { {&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
return nil, err {&s.selectStateEventStmt, selectStateEventSQL},
} {&s.selectMembershipCountStmt, selectMembershipCount},
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { }.Prepare(db)
return nil, err
}
if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
return nil, err
}
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err
}
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
// return nil, err
//}
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err
}
return s, nil
} }
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
@ -485,3 +482,53 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
return result, err return result, err
} }
func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) {
params := make([]interface{}, len(memberships)+2)
params[0] = roomID
params[1] = excludeUserID
for k, v := range memberships {
params[k+2] = v
}
query := strings.Replace(selectRoomHeroes, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
var stmt *sql.Stmt
var err error
if txn != nil {
stmt, err = txn.Prepare(query)
} else {
stmt, err = s.db.Prepare(query)
}
if err != nil {
return []string{}, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "selectRoomHeroes: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroes: rows.close() failed")
var stateKey string
result := make([]string, 0, 5)
for rows.Next() {
if err = rows.Scan(&stateKey); err != nil {
return nil, err
}
result = append(result, stateKey)
}
return result, rows.Err()
}
func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return 0, err
}
return count, nil
}

View file

@ -18,11 +18,9 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" +
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
") t WHERE t.membership = $3" ") t WHERE t.membership = $3"
const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
const selectMembershipBeforeSQL = "" + const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
@ -99,7 +94,6 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
{&s.selectMembersStmt, selectMembersSQL}, {&s.selectMembersStmt, selectMembersSQL},
// {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
}.Prepare(db) }.Prepare(db)
} }
@ -131,39 +125,6 @@ func (s *membershipsStatements) SelectMembershipCount(
return return
} }
func (s *membershipsStatements) SelectHeroes(
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
) (heroes []string, err error) {
stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
stmt, err := s.db.PrepareContext(ctx, stmtSQL)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed")
params := []interface{}{
roomID, userID,
}
for _, membership := range memberships {
params = append(params, membership)
}
stmt = sqlutil.TxStmt(txn, stmt)
var rows *sql.Rows
rows, err = stmt.QueryContext(ctx, params...)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
var hero string
for rows.Next() {
if err = rows.Scan(&hero); err != nil {
return
}
heroes = append(heroes, hero)
}
return heroes, rows.Err()
}
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership. // string as the membership.

View file

@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
) )
var ctx = context.Background() var ctx = context.Background()
@ -664,3 +665,181 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
return &tok return &tok
} }
*/ */
func pointer[t any](s t) *t {
return &s
}
func TestRoomSummary(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := test.NewUser(t)
// Create some dummy users
moreUsers := []*test.User{}
moreUserIDs := []string{}
for i := 0; i < 10; i++ {
u := test.NewUser(t)
moreUsers = append(moreUsers, u)
moreUserIDs = append(moreUserIDs, u.ID)
}
testCases := []struct {
name string
wantSummary *types.Summary
additionalEvents func(t *testing.T, room *test.Room)
}{
{
name: "after initial creation",
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{}},
},
{
name: "invited user",
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
},
},
{
name: "invited user, but declined",
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "leave",
}, test.WithStateKey(bob.ID))
},
},
{
name: "joined user after invitation",
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
},
},
{
name: "multiple joined user",
wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(charlie.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
},
},
{
name: "multiple joined/invited user",
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(charlie.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
},
},
{
name: "multiple joined/invited/left user",
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(charlie.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "leave",
}, test.WithStateKey(bob.ID))
},
},
{
name: "leaving user after joining",
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "leave",
}, test.WithStateKey(bob.ID))
},
},
{
name: "many users", // heroes ordered by stream id
wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]},
additionalEvents: func(t *testing.T, room *test.Room) {
for _, x := range moreUsers {
room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(x.ID))
}
},
},
{
name: "canonical alias set",
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{
"alias": "myalias",
}, test.WithStateKey(""))
},
},
{
name: "room name set",
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}},
additionalEvents: func(t *testing.T, room *test.Room) {
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{
"name": "my room name",
}, test.WithStateKey(""))
},
},
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
defer closeBase()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := test.NewRoom(t, alice)
if tc.additionalEvents != nil {
tc.additionalEvents(t, r)
}
// write the room before creating a transaction
MustWriteEvents(t, db, r.Events())
transaction, err := db.NewDatabaseTransaction(ctx)
assert.NoError(t, err)
defer transaction.Rollback()
summary, err := transaction.GetRoomSummary(ctx, r.ID, alice.ID)
assert.NoError(t, err)
assert.Equal(t, tc.wantSummary, summary)
})
}
})
}

View file

@ -115,6 +115,9 @@ type CurrentRoomState interface {
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error)
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (int, error)
} }
// BackwardsExtremities keeps track of backwards extremities for a room. // BackwardsExtremities keeps track of backwards extremities for a room.
@ -185,7 +188,6 @@ type Receipts interface {
type Memberships interface { type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
SelectMemberships( SelectMemberships(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,

View file

@ -3,8 +3,6 @@ package tables_test
import ( import (
"context" "context"
"database/sql" "database/sql"
"reflect"
"sort"
"testing" "testing"
"time" "time"
@ -88,43 +86,9 @@ func TestMembershipsTable(t *testing.T) {
testUpsert(t, ctx, table, userEvents[0], alice, room) testUpsert(t, ctx, table, userEvents[0], alice, room)
testMembershipCount(t, ctx, table, room) testMembershipCount(t, ctx, table, room)
testHeroes(t, ctx, table, alice, room, users)
}) })
} }
func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) {
// Re-slice and sort the expected users
users = users[1:]
sort.Strings(users)
type testCase struct {
name string
memberships []string
wantHeroes []string
}
testCases := []testCase{
{name: "no memberships queried", memberships: []string{}},
{name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships)
if err != nil {
t.Fatalf("unable to select heroes: %s", err)
}
if gotLen := len(got); gotLen != len(tc.wantHeroes) {
t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen)
}
if !reflect.DeepEqual(got, tc.wantHeroes) {
t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got)
}
})
}
}
func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) {
t.Run("membership counts are correct", func(t *testing.T) { t.Run("membership counts are correct", func(t *testing.T) {
// After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users) // After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users)

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"sort"
"time" "time"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
@ -14,11 +13,9 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/syncapi/notifier"
) )
// The max number of per-room goroutines to have running. // The max number of per-room goroutines to have running.
@ -339,7 +336,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if hasMembershipChange { if hasMembershipChange {
p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition) jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID)
if err != nil {
logrus.WithError(err).Warn("failed to get room summary")
}
} }
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
@ -411,45 +411,6 @@ func applyHistoryVisibilityFilter(
return events, nil return events, nil
} }
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room.
joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
fetchStates := []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomName},
{EventType: gomatrixserverlib.MRoomCanonicalAlias},
}
// Check if the room has a name or a canonical alias
latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{}
err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState)
if err != nil {
return
}
// Check if the room has a name or canonical alias, if so, return.
for _, ev := range latestState.StateEvents {
switch ev.Type() {
case gomatrixserverlib.MRoomName:
if gjson.GetBytes(ev.Content(), "name").Str != "" {
return
}
case gomatrixserverlib.MRoomCanonicalAlias:
if gjson.GetBytes(ev.Content(), "alias").Str != "" {
return
}
}
}
heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
if err != nil {
return
}
sort.Strings(heroes)
jr.Summary.Heroes = heroes
}
func (p *PDUStreamProvider) getJoinResponseForCompleteSync( func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
ctx context.Context, ctx context.Context,
snapshot storage.DatabaseTransaction, snapshot storage.DatabaseTransaction,
@ -493,7 +454,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return return
} }
p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From) jr.Summary, err = snapshot.GetRoomSummary(ctx, roomID, device.UserID)
if err != nil {
logrus.WithError(err).Warn("failed to get room summary")
}
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: