fix GetRoomsByMembership to work with pseudo IDs

This commit is contained in:
Sam Wedgwood 2023-08-11 00:32:19 +01:00
parent 6d3bc3937b
commit 96a05270e1
16 changed files with 266 additions and 97 deletions

View file

@ -33,23 +33,36 @@ func GetJoinedRooms(
device *userapi.Device, device *userapi.Device,
rsAPI api.ClientRoomserverAPI, rsAPI api.ClientRoomserverAPI,
) util.JSONResponse { ) util.JSONResponse {
var res api.QueryRoomsForUserResponse deviceUserID, err := spec.NewUserID(device.UserID, true)
err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ if err != nil {
UserID: device.UserID, util.GetLogger(req.Context()).WithError(err).Error("Invalid device user ID")
WantMembership: "join", return util.JSONResponse{
}, &res) Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
rooms, err := rsAPI.QueryRoomsForUser(req.Context(), *deviceUserID, "join")
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.Unknown("internal server error"),
} }
} }
if res.RoomIDs == nil {
res.RoomIDs = []string{} var roomIDStrs []string
if rooms == nil {
roomIDStrs = []string{}
} else {
roomIDStrs = make([]string, len(rooms))
for i, roomID := range rooms {
roomIDStrs[i] = roomID.String()
} }
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getJoinedRoomsResponse{res.RoomIDs}, JSON: getJoinedRoomsResponse{roomIDStrs},
} }
} }

View file

@ -251,11 +251,15 @@ func updateProfile(
profile *authtypes.Profile, profile *authtypes.Profile,
userID string, evTime time.Time, userID string, evTime time.Time,
) (util.JSONResponse, error) { ) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse deviceUserID, err := spec.NewUserID(device.UserID, true)
err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ if err != nil {
UserID: device.UserID, return util.JSONResponse{
WantMembership: "join", Code: http.StatusInternalServerError,
}, &res) JSON: spec.Unknown("internal server error"),
}, err
}
rooms, err := rsAPI.QueryRoomsForUser(ctx, *deviceUserID, "join")
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
return util.JSONResponse{ return util.JSONResponse{
@ -264,6 +268,11 @@ func updateProfile(
}, err }, err
} }
roomIDStrs := make([]string, len(rooms))
for i, room := range rooms {
roomIDStrs[i] = room.String()
}
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
@ -274,7 +283,7 @@ func updateProfile(
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, evTime, rsAPI, ctx, roomIDStrs, *profile, userID, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:

View file

@ -94,34 +94,42 @@ func SendServerNotice(
} }
} }
userID, err := spec.NewUserID(r.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid user ID"),
}
}
// get rooms for specified user // get rooms for specified user
allUserRooms := []string{} allUserRooms := []spec.RoomID{}
userRooms := api.QueryRoomsForUserResponse{}
// Get rooms the user is either joined, invited or has left. // Get rooms the user is either joined, invited or has left.
for _, membership := range []string{"join", "invite", "leave"} { for _, membership := range []string{"join", "invite", "leave"} {
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ userRooms, err := rsAPI.QueryRoomsForUser(ctx, *userID, membership)
UserID: r.UserID, if err != nil {
WantMembership: membership,
}, &userRooms); err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
allUserRooms = append(allUserRooms, userRooms.RoomIDs...) allUserRooms = append(allUserRooms, userRooms...)
} }
// get rooms of the sender // get rooms of the sender
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) senderUserID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName), true)
senderRooms := api.QueryRoomsForUserResponse{} if err != nil {
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ return util.JSONResponse{
UserID: senderUserID, Code: http.StatusInternalServerError,
WantMembership: "join", JSON: spec.Unknown("internal server error"),
}, &senderRooms); err != nil { }
}
senderRooms, err := rsAPI.QueryRoomsForUser(ctx, *senderUserID, "join")
if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
// check if we have rooms in common // check if we have rooms in common
commonRooms := []string{} commonRooms := []spec.RoomID{}
for _, userRoomID := range allUserRooms { for _, userRoomID := range allUserRooms {
for _, senderRoomID := range senderRooms.RoomIDs { for _, senderRoomID := range senderRooms {
if userRoomID == senderRoomID { if userRoomID == senderRoomID {
commonRooms = append(commonRooms, senderRoomID) commonRooms = append(commonRooms, senderRoomID)
} }
@ -139,7 +147,7 @@ func SendServerNotice(
// create a new room for the user // create a new room for the user
if len(commonRooms) == 0 { if len(commonRooms) == 0 {
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID) powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID.String())
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
pl, err := json.Marshal(powerLevelContent) pl, err := json.Marshal(powerLevelContent)
if err != nil { if err != nil {
@ -195,7 +203,7 @@ func SendServerNotice(
} }
} }
roomID = commonRooms[0] roomID = commonRooms[0].String()
membershipRes := api.QueryMembershipForUserResponse{} membershipRes := api.QueryMembershipForUserResponse{}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes)
if err != nil { if err != nil {

View file

@ -117,19 +117,27 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
return true return true
} }
var queryRes roomserverAPI.QueryRoomsForUserResponse userID, err := spec.NewUserID(m.UserID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: m.UserID, sentry.CaptureException(err)
WantMembership: "join", logger.WithError(err).Error("invalid user ID")
}, &queryRes) return true
}
roomIDs, err := t.rsAPI.QueryRoomsForUser(t.ctx, *userID, "join")
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined rooms for user") logger.WithError(err).Error("failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(roomIDs))
for i, room := range roomIDs {
roomIDStrs[i] = room.String()
}
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
@ -179,18 +187,27 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
} }
logger := logrus.WithField("user_id", output.UserID) logger := logrus.WithField("user_id", output.UserID)
var queryRes roomserverAPI.QueryRoomsForUserResponse outputUserID, err := spec.NewUserID(output.UserID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: output.UserID, sentry.CaptureException(err)
WantMembership: "join", logrus.WithError(err).Errorf("invalid user ID")
}, &queryRes) return true
}
rooms, err := t.rsAPI.QueryRoomsForUser(t.ctx, *outputUserID, "join")
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(rooms))
for i, room := range rooms {
roomIDStrs[i] = room.String()
}
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -94,16 +95,23 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
return true return true
} }
var queryRes roomserverAPI.QueryRoomsForUserResponse parsedUserID, err := spec.NewUserID(userID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: userID, util.GetLogger(ctx).WithError(err).WithField("user_id", userID).Error("failed to extract domain from receipt sender")
WantMembership: "join", return true
}, &queryRes) }
roomIDs, err := t.rsAPI.QueryRoomsForUser(t.ctx, *parsedUserID, "join")
if err != nil { if err != nil {
log.WithError(err).Error("failed to calculate joined rooms for user") log.WithError(err).Error("failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(roomIDs))
for i, roomID := range roomIDs {
roomIDStrs[i] = roomID.String()
}
presence := msg.Header.Get("presence") presence := msg.Header.Get("presence")
ts, err := strconv.Atoi(msg.Header.Get("last_active_ts")) ts, err := strconv.Atoi(msg.Header.Get("last_active_ts"))
@ -112,7 +120,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
} }
// send this presence to all servers who share rooms with this user. // send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) joined, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
log.WithError(err).Error("failed to get joined hosts") log.WithError(err).Error("failed to get joined hosts")
return true return true

View file

@ -33,7 +33,7 @@ import (
type fedRoomserverAPI struct { type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error queryRoomsForUser func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
} }
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
@ -54,11 +54,11 @@ func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.Input
} }
// keychange consumer calls this // keychange consumer calls this
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if f.queryRoomsForUser == nil { if f.queryRoomsForUser == nil {
return nil return nil, nil
} }
return f.queryRoomsForUser(ctx, req, res) return f.queryRoomsForUser(ctx, userID, desiredMembership)
} }
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate // TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
@ -199,18 +199,22 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID) fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator) room := test.NewRoom(t, creator)
roomID, err := spec.NewRoomID(room.ID)
if err != nil {
t.Fatalf("Invalid room ID: %q", roomID)
}
rsapi := &fedRoomserverAPI{ rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if req.Asynchronous { if req.Asynchronous {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous") t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
} }
}, },
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { queryRoomsForUser: func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if req.UserID == joiningUser.ID && req.WantMembership == "join" { if userID.String() == joiningUser.ID && desiredMembership == "join" {
res.RoomIDs = []string{room.ID} return []spec.RoomID{*roomID}, nil
return nil
} }
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req) return nil, fmt.Errorf("unexpected queryRoomsForUser: %v, %v", userID, desiredMembership)
}, },
} }
fc := &fedClient{ fc := &fedClient{

View file

@ -222,7 +222,7 @@ type ClientRoomserverAPI interface {
DefaultRoomVersionAPI DefaultRoomVersionAPI
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms. // QueryKnownUsers returns a list of users that we know about from our joined rooms.
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
@ -300,7 +300,7 @@ type FederationRoomserverAPI interface {
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error) QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error

View file

@ -291,16 +291,6 @@ type QuerySharedUsersResponse struct {
UserIDsToCount map[string]int UserIDsToCount map[string]int
} }
type QueryRoomsForUserRequest struct {
UserID string
// The desired membership of the user. If this is the empty string then no rooms are returned.
WantMembership string
}
type QueryRoomsForUserResponse struct {
RoomIDs []string
}
type QueryBulkStateContentRequest struct { type QueryBulkStateContentRequest struct {
// Returns state events in these rooms // Returns state events in these rooms
RoomIDs []string RoomIDs []string

View file

@ -161,12 +161,12 @@ func (r *Admin) PerformAdminEvacuateUser(
return nil, fmt.Errorf("can only evacuate local users using this endpoint") return nil, fmt.Errorf("can only evacuate local users using this endpoint")
} }
roomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Join) roomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Join)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Invite) inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Invite)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, err return nil, err
} }

View file

@ -846,13 +846,20 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
return nil return nil
} }
func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { func (r *Queryer) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership) roomIDStrs, err := r.DB.GetRoomsByMembership(ctx, userID, desiredMembership)
if err != nil { if err != nil {
return err return nil, err
} }
res.RoomIDs = roomIDs roomIDs := make([]spec.RoomID, len(roomIDStrs))
return nil for i, roomIDStr := range roomIDStrs {
roomID, err := spec.NewRoomID(roomIDStr)
if err != nil {
return nil, nil
}
roomIDs[i] = *roomID
}
return roomIDs, nil
} }
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
@ -895,7 +902,12 @@ func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersReq
} }
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") parsedUserID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return err
}
roomIDs, err := r.DB.GetRoomsByMembership(ctx, *parsedUserID, "join")
if err != nil { if err != nil {
return err return err
} }

View file

@ -158,7 +158,7 @@ type Database interface {
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error)
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) GetRoomsByMembership(ctx context.Context, userID spec.UserID, membership string) ([]string, error)
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)

View file

@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
const selectAllUserRoomPublicKeyForUserSQL = `SELECT room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt
selectAllUserRoomPublicKeysForUser *sql.Stmt
} }
func CreateUserRoomKeysTable(db *sql.DB) error { func CreateUserRoomKeysTable(db *sql.DB) error {
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL},
{&s.selectAllUserRoomPublicKeysForUser, selectAllUserRoomPublicKeyForUserSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -150,3 +154,24 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectAllUserRoomPublicKeysForUser)
rows, err := stmt.QueryContext(ctx, userNID)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
resultMap := make(map[types.RoomNID]ed25519.PublicKey)
var roomNID types.RoomNID
var pubkey ed25519.PublicKey
for rows.Next() {
if err = rows.Scan(&roomNID, &pubkey); err != nil {
return nil, err
}
resultMap[roomNID] = pubkey
}
return resultMap, err
}

View file

@ -1347,7 +1347,7 @@ func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evTy
} }
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { func (d *Database) GetRoomsByMembership(ctx context.Context, userID spec.UserID, membership string) ([]string, error) {
var membershipState tables.MembershipState var membershipState tables.MembershipState
switch membership { switch membership {
case "join": case "join":
@ -1361,17 +1361,73 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
default: default:
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
} }
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
// Convert provided user ID to NID
userNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID.String())
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} else {
return nil, fmt.Errorf("SelectEventStateKeyNID: cannot map user ID to state key NIDs: %w", err)
} }
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
} }
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
// Use this NID to fetch all associated room keys (for pseudo ID rooms)
roomKeyMap, err := d.UserRoomKeyTable.SelectAllPublicKeysForUser(ctx, nil, userNID)
if err != nil {
if err == sql.ErrNoRows {
roomKeyMap = map[types.RoomNID]ed25519.PublicKey{}
} else {
return nil, fmt.Errorf("SelectAllPublicKeysForUser: could not select user room public keys for user: %w", err)
}
}
var eventStateKeyNIDs []types.EventStateKeyNID
// If there are room keys (i.e. this user is in pseudo ID rooms), then gather the appropriate NIDs
if len(roomKeyMap) != 0 {
// Convert keys to string representation
userRoomKeys := make([]string, len(roomKeyMap))
i := 0
for _, key := range roomKeyMap {
userRoomKeys[i] = spec.Base64Bytes(key).Encode()
i += 1
}
// Convert the string representation to its NID
pseudoIDStateKeys, sqlErr := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userRoomKeys)
if sqlErr != nil {
if sqlErr == sql.ErrNoRows {
pseudoIDStateKeys = map[string]types.EventStateKeyNID{}
} else {
return nil, fmt.Errorf("BulkSelectEventStateKeyNID: could not select state keys for public room keys: %w", err)
}
}
// Collect all NIDs together
eventStateKeyNIDs = make([]types.EventStateKeyNID, len(pseudoIDStateKeys)+1)
eventStateKeyNIDs[0] = userNID
i = 1
for _, nid := range pseudoIDStateKeys {
eventStateKeyNIDs[i] = nid
i += 1
}
} else {
// If there are no room keys (so no pseudo ID rooms), we only need to care about the user ID NID.
eventStateKeyNIDs = []types.EventStateKeyNID{userNID}
}
// Fetch rooms that match membership for each NID
roomNIDs := []types.RoomNID{}
for _, nid := range eventStateKeyNIDs {
var roomNIDsChunk []types.RoomNID
roomNIDsChunk, err = d.MembershipTable.SelectRoomsWithMembership(ctx, nil, nid, membershipState)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
} }
roomNIDs = append(roomNIDs, roomNIDsChunk...)
}
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)

View file

@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
const selectAllUserRoomPublicKeyForUserSQL = `SELECT room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
db *sql.DB db *sql.DB
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt
selectAllUserRoomPublicKeysForUser *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime //selectUserNIDsStmt *sql.Stmt //prepared at runtime
} }
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectAllUserRoomPublicKeysForUser, selectAllUserRoomPublicKeyForUserSQL},
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
@ -165,3 +169,24 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectAllUserRoomPublicKeysForUser)
rows, err := stmt.QueryContext(ctx, userNID)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
resultMap := make(map[types.RoomNID]ed25519.PublicKey)
var roomNID types.RoomNID
var pubkey ed25519.PublicKey
for rows.Next() {
if err = rows.Scan(&roomNID, &pubkey); err != nil {
return nil, err
}
resultMap[roomNID] = pubkey
}
return resultMap, err
}

View file

@ -198,6 +198,8 @@ type UserRoomKeys interface {
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
// If a senderKey can't be found, it is omitted in the result. // If a senderKey can't be found, it is omitted in the result.
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
// SelectAllPublicKeysForUser returns all known public keys for a user. Returns a map from room NID -> public key
SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -69,8 +69,8 @@ func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spe
} }
// QueryRoomsForUser retrieves a list of room IDs matching the given query. // QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
return nil return nil, nil
} }
// QueryBulkStateContent does a bulk query for state event content in the given rooms. // QueryBulkStateContent does a bulk query for state event content in the given rooms.