From 2aec9de269da1634a3f65c439326a3242181d396 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 13 Jun 2023 12:55:16 +0100 Subject: [PATCH] Connect senderID query to user room keys table --- clientapi/routing/directory.go | 20 +++++++++-- clientapi/routing/membership.go | 26 +++++++++++--- clientapi/routing/profile.go | 6 +++- clientapi/routing/redaction.go | 15 +++++--- clientapi/routing/sendevent.go | 9 ++++- clientapi/threepid/invites.go | 6 +++- federationapi/federationapi_test.go | 2 +- federationapi/internal/perform.go | 6 +++- federationapi/routing/join.go | 2 +- federationapi/routing/leave.go | 2 +- roomserver/api/api.go | 2 +- roomserver/internal/perform/perform_admin.go | 6 +++- .../internal/perform/perform_create_room.go | 2 +- roomserver/internal/perform/perform_invite.go | 2 +- roomserver/internal/perform/perform_leave.go | 6 +++- .../internal/perform/perform_upgrade.go | 6 +++- roomserver/internal/query/query.go | 20 ++++++----- roomserver/storage/interface.go | 4 +-- .../storage/postgres/user_room_keys_table.go | 19 +++++++++++ roomserver/storage/shared/storage.go | 34 ++++++++++++++++--- roomserver/storage/shared/storage_test.go | 7 +++- .../storage/sqlite3/user_room_keys_table.go | 19 +++++++++++ roomserver/storage/tables/interface.go | 2 ++ .../tables/user_room_keys_table_test.go | 7 ++++ setup/mscs/msc2836/msc2836_test.go | 2 +- syncapi/internal/history_visibility.go | 6 +++- syncapi/storage/shared/storage_consumer.go | 15 ++++++-- userapi/consumers/roomserver.go | 8 ++++- 28 files changed, 217 insertions(+), 44 deletions(-) diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 621df0cca..de374f082 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -233,7 +233,15 @@ func RemoveLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) + validRoomID, err := spec.NewRoomID(roomIDRes.RoomID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID) if err != nil { return util.JSONResponse{ Code: http.StatusNotFound, @@ -321,7 +329,15 @@ func SetVisibility( JSON: spec.BadJSON("userID for this device is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 03e85edbf..bafc37b67 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -64,7 +64,14 @@ func SendBan( JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -155,7 +162,14 @@ func SendKick( JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -428,7 +442,11 @@ func buildMembershipEvent( if err != nil { return nil, err } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) if err != nil { return nil, err } @@ -437,7 +455,7 @@ func buildMembershipEvent( if err != nil { return nil, err } - targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID) + targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID) if err != nil { return nil, err } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index e734e2e4f..8a44834e1 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -368,7 +368,11 @@ func buildMembershipEvents( return nil, err } for _, roomID := range roomIDs { - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index da48e84de..42f029395 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -54,7 +54,14 @@ func SendRedaction( JSON: spec.Forbidden("userID doesn't have power level to redact"), } } - senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if queryErr != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -103,8 +110,8 @@ func SendRedaction( JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), } } - pl, err := plEvent.PowerLevels() - if err != nil { + pl, plErr := plEvent.PowerLevels() + if plErr != nil { return util.JSONResponse{ Code: 403, JSON: spec.Forbidden( @@ -134,7 +141,7 @@ func SendRedaction( Type: spec.MRoomRedaction, Redacts: eventID, } - err := proto.SetContent(r) + err = proto.SetContent(r) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") return util.JSONResponse{ diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 4d0a9f24a..860c972cb 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -273,7 +273,14 @@ func generateSendEvent( JSON: spec.BadJSON("Bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index e7ffbac2b..d15cc6d46 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -359,7 +359,11 @@ func emit3PIDInviteEvent( if err != nil { return err } - sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) if err != nil { return err } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index dcddc6947..5410d16fe 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -40,7 +40,7 @@ func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID stri return spec.NewUserID(string(senderID), true) } -func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { +func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { return spec.SenderID(userID.String()), nil } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index f230299d9..2d66ef681 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -456,7 +456,11 @@ func (r *FederationInternalAPI) PerformLeave( // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" - senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID) + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { return err } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index d14801921..b9d241572 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -98,7 +98,7 @@ func MakeJoin( Roomserver: rsAPI, } - senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") return util.JSONResponse{ diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 716276bec..dfd426de6 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -87,7 +87,7 @@ func MakeLeave( return event, stateEvents, nil } - senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") return util.JSONResponse{ diff --git a/roomserver/api/api.go b/roomserver/api/api.go index bafde91c9..e9615a1b1 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -77,7 +77,7 @@ type InputRoomEventsAPI interface { } type QuerySenderIDAPI interface { - QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) + QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) } diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index dccedea14..bc3f805ab 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -293,7 +293,11 @@ func (r *Admin) PerformAdminDownloadState( stateIDs = append(stateIDs, stateEvent.EventID()) } - senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 59b1efa61..96b7f9199 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -431,7 +431,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo JSON: spec.InternalServerError{}, } } - inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID) + inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID) if queryErr != nil { util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed") return "", &util.JSONResponse{ diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 6143f9043..25deb64d6 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -149,7 +149,7 @@ func (r *Inviter) PerformInvite( return err } - invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser) + invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser) if err != nil { return fmt.Errorf("failed looking up senderID for invited user") } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 3a81f49b8..17b44d25b 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -78,7 +78,11 @@ func (r *Leaver) performLeaveRoomByID( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam ) ([]api.OutputEvent, error) { - leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver) + roomID, err := spec.NewRoomID(req.RoomID) + if err != nil { + return nil, err + } + leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver) if err != nil { return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) } diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 1aaa42c94..bcbd1546b 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -54,7 +54,11 @@ func (r *Upgrader) performRoomUpgrade( return "", err } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID) + fullRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return "", err + } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") return "", err diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index bba080a14..d0d99c0eb 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -271,15 +271,15 @@ func (r *Queryer) QueryMembershipForUser( request *api.QueryMembershipForUserRequest, response *api.QueryMembershipForUserResponse, ) error { - senderID, err := r.QuerySenderIDForUser(ctx, request.RoomID, request.UserID) - if err != nil { - return err - } - roomID, err := spec.NewRoomID(request.RoomID) if err != nil { return err } + senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID) + if err != nil { + return err + } + return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) } @@ -989,15 +989,19 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) } -func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { - version, err := r.DB.GetRoomVersion(ctx, roomID) +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + version, err := r.DB.GetRoomVersion(ctx, roomID.String()) if err != nil { return "", err } switch version { case gomatrixserverlib.RoomVersionPseudoIDs: - return r.DB.GetSenderIDForUser(ctx, roomID, userID) + key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) + if err != nil { + return "", err + } + return spec.SenderID(spec.Base64Bytes(key).Encode()), nil default: return spec.SenderID(userID.String()), nil } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index e6efdea5b..19cbdb346 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -170,8 +170,6 @@ type Database interface { GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownUsers tries to obtain the current mxid for a given user. GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) - // GetKnownUsers tries to obtain the current senderID for a given user. - GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -205,6 +203,8 @@ type UserRoomKeys interface { InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) // SelectUserRoomPrivateKey selects the private key for the given user and room combination SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) + // SelectUserRoomPublicKey selects the public key for the given user and room combination + SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) // SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID. // If a senderKey can't be found, it is omitted in the result. SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index 22f978bf0..dbb4af34a 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` +const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt + selectUserRoomPublicKeyStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt } @@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL}, }.Prepare(db) } @@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( return result, err } +func (s *userRoomKeysStatements) SelectUserRoomPublicKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PublicKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt) + var result ed25519.PublicKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 9691c5c73..da8732f64 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1551,11 +1551,6 @@ func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, sender return spec.NewUserID(string(senderID), true) } -func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { - // TODO: Use real logic once DB for pseudoIDs is in place - return spec.SenderID(userID.String()), nil -} - // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) @@ -1714,6 +1709,35 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use return } +// SelectUserRoomPublicKey queries the users room public key. +// If no key exists, returns no key and no error. Otherwise returns +// the key and a database error, if any. +func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return nil + } + + key, sErr = d.UserRoomKeyTable.SelectUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) + if !errors.Is(sErr, sql.ErrNoRows) { + return sErr + } + return nil + }) + return +} + // SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { result = make(map[spec.RoomID]map[string]string, len(publicKeys)) diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 4fa451bcc..8f83af3bb 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -162,12 +162,17 @@ func TestUserRoomKeys(t *testing.T) { gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) assert.NoError(t, err) assert.Equal(t, key, gotKey) + pubKey, err := db.SelectUserRoomPublicKey(context.Background(), *userID, *roomID) + assert.NoError(t, err) + assert.Equal(t, key.Public(), pubKey) // Key doesn't exist, we shouldn't get anything back - assert.NoError(t, err) gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) assert.NoError(t, err) assert.Nil(t, gotKey) + pubKey, err = db.SelectUserRoomPublicKey(context.Background(), *userID, *doesNotExist) + assert.NoError(t, err) + assert.Nil(t, pubKey) queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ *roomID: {key.Public().(ed25519.PublicKey)}, diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 8af57ea0e..84c8b54ec 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` +const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt + selectUserRoomPublicKeyStmt *sql.Stmt //selectUserNIDsStmt *sql.Stmt //prepared at runtime } @@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime }.Prepare(db) } @@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( return result, err } +func (s *userRoomKeysStatements) SelectUserRoomPublicKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PublicKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt) + var result ed25519.PublicKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { roomNIDs := make([]any, 0, len(senderKeys)) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index cd0e51686..445c1223f 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -193,6 +193,8 @@ type UserRoomKeys interface { InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) // SelectUserRoomPrivateKey selects the private key for the given user and room combination SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) + // SelectUserRoomPublicKey selects the public key for the given user and room combination + SelectUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PublicKey, error) // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // If a senderKey can't be found, it is omitted in the result. BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 284309481..8802a3c6e 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -50,6 +50,7 @@ func TestUserRoomKeysTable(t *testing.T) { err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { var gotKey, key2, key3 ed25519.PrivateKey + var pubKey ed25519.PublicKey gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) @@ -71,6 +72,9 @@ func TestUserRoomKeysTable(t *testing.T) { gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) assert.NoError(t, err) assert.Equal(t, key, gotKey) + pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, roomNID) + assert.NoError(t, err) + assert.Equal(t, key.Public(), pubKey) // try to update an existing key, this should only be done for users NOT on this homeserver var gotPubKey ed25519.PublicKey @@ -82,6 +86,9 @@ func TestUserRoomKeysTable(t *testing.T) { gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) assert.NoError(t, err) assert.Nil(t, gotKey) + pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, pubKey) // query user NIDs for senderKeys var gotKeys map[string]types.UserRoomKeyPair diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index e32d6a9f2..05d21b5a4 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -529,7 +529,7 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str return spec.NewUserID(string(senderID), true) } -func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { +func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { return spec.SenderID(userID.String()), nil } diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index ab1a7f83d..ce6846ca4 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -139,7 +139,11 @@ func ApplyHistoryVisibilityFilter( if err != nil { return nil, err } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user) + roomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user) if err == nil { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) { eventsFiltered = append(eventsFiltered, ev) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 799e3d166..1827218b6 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -114,7 +114,14 @@ func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Dev }).WithError(err).Warnf("Failed to add transaction ID to event") continue } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID) + roomID, err := spec.NewRoomID(in[i].RoomID()) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Room ID is invalid") + continue + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { logrus.WithFields(logrus.Fields{ "event_id": out[i].EventID(), @@ -515,7 +522,11 @@ func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userI if err != nil { return "", "" } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser) + roomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return "", "" + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *fullUser) if err != nil { return "", "" } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index b2dc477aa..f845d9b0a 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -818,7 +818,13 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) return nil, err } - localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID) + roomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + logger.WithError(err).Errorf("event roomID is invalid %s", event.RoomID()) + return nil, err + } + + localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) return nil, err