Connect senderID query to user room keys table

This commit is contained in:
Devon Hudson 2023-06-13 12:55:16 +01:00
parent 80652422c3
commit 2aec9de269
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
28 changed files with 217 additions and 44 deletions

View file

@ -233,7 +233,15 @@ func RemoveLocalAlias(
} }
} }
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) validRoomID, err := spec.NewRoomID(roomIDRes.RoomID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -321,7 +329,15 @@ func SetVisibility(
JSON: spec.BadJSON("userID for this device is invalid"), JSON: spec.BadJSON("userID for this device is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

View file

@ -64,7 +64,14 @@ func SendBan(
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -155,7 +162,14 @@ func SendKick(
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -428,7 +442,11 @@ func buildMembershipEvent(
if err != nil { if err != nil {
return nil, err return nil, err
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -437,7 +455,7 @@ func buildMembershipEvent(
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID) targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -368,7 +368,11 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -54,7 +54,14 @@ func SendRedaction(
JSON: spec.Forbidden("userID doesn't have power level to redact"), JSON: spec.Forbidden("userID doesn't have power level to redact"),
} }
} }
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if queryErr != nil { if queryErr != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -103,8 +110,8 @@ func SendRedaction(
JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."),
} }
} }
pl, err := plEvent.PowerLevels() pl, plErr := plEvent.PowerLevels()
if err != nil { if plErr != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 403, Code: 403,
JSON: spec.Forbidden( JSON: spec.Forbidden(
@ -134,7 +141,7 @@ func SendRedaction(
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Redacts: eventID, Redacts: eventID,
} }
err := proto.SetContent(r) err = proto.SetContent(r)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -273,7 +273,14 @@ func generateSendEvent(
JSON: spec.BadJSON("Bad userID"), JSON: spec.BadJSON("Bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,

View file

@ -359,7 +359,11 @@ func emit3PIDInviteEvent(
if err != nil { if err != nil {
return err return err
} }
sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return err
}
sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -40,7 +40,7 @@ func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID stri
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return spec.SenderID(userID.String()), nil return spec.SenderID(userID.String()), nil
} }

View file

@ -456,7 +456,11 @@ func (r *FederationInternalAPI) PerformLeave(
// Set all the fields to be what they should be, this should be a no-op // Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd" // but it's possible that the remote server returned us something "odd"
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID) roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -98,7 +98,7 @@ func MakeJoin(
Roomserver: rsAPI, Roomserver: rsAPI,
} }
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -87,7 +87,7 @@ func MakeLeave(
return event, stateEvents, nil return event, stateEvents, nil
} }
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{ return util.JSONResponse{

View file

@ -77,7 +77,7 @@ type InputRoomEventsAPI interface {
} }
type QuerySenderIDAPI interface { type QuerySenderIDAPI interface {
QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
} }

View file

@ -293,7 +293,11 @@ func (r *Admin) PerformAdminDownloadState(
stateIDs = append(stateIDs, stateEvent.EventID()) stateIDs = append(stateIDs, stateEvent.EventID())
} }
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return err
}
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -431,7 +431,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID) inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID)
if queryErr != nil { if queryErr != nil {
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed") util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{

View file

@ -149,7 +149,7 @@ func (r *Inviter) PerformInvite(
return err return err
} }
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser) invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser)
if err != nil { if err != nil {
return fmt.Errorf("failed looking up senderID for invited user") return fmt.Errorf("failed looking up senderID for invited user")
} }

View file

@ -78,7 +78,11 @@ func (r *Leaver) performLeaveRoomByID(
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver) roomID, err := spec.NewRoomID(req.RoomID)
if err != nil {
return nil, err
}
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
if err != nil { if err != nil {
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
} }

View file

@ -54,7 +54,11 @@ func (r *Upgrader) performRoomUpgrade(
return "", err return "", err
} }
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID) fullRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return "", err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
return "", err return "", err

View file

@ -271,15 +271,15 @@ func (r *Queryer) QueryMembershipForUser(
request *api.QueryMembershipForUserRequest, request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse, response *api.QueryMembershipForUserResponse,
) error { ) error {
senderID, err := r.QuerySenderIDForUser(ctx, request.RoomID, request.UserID)
if err != nil {
return err
}
roomID, err := spec.NewRoomID(request.RoomID) roomID, err := spec.NewRoomID(request.RoomID)
if err != nil { if err != nil {
return err return err
} }
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
} }
@ -989,15 +989,19 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
} }
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
version, err := r.DB.GetRoomVersion(ctx, roomID) version, err := r.DB.GetRoomVersion(ctx, roomID.String())
if err != nil { if err != nil {
return "", err return "", err
} }
switch version { switch version {
case gomatrixserverlib.RoomVersionPseudoIDs: case gomatrixserverlib.RoomVersionPseudoIDs:
return r.DB.GetSenderIDForUser(ctx, roomID, userID) key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
if err != nil {
return "", err
}
return spec.SenderID(spec.Base64Bytes(key).Encode()), nil
default: default:
return spec.SenderID(userID.String()), nil return spec.SenderID(userID.String()), nil
} }

View file

@ -170,8 +170,6 @@ type Database interface {
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// GetKnownUsers tries to obtain the current mxid for a given user. // GetKnownUsers tries to obtain the current mxid for a given user.
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
// GetKnownUsers tries to obtain the current senderID for a given user.
GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error) GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
@ -205,6 +203,8 @@ type UserRoomKeys interface {
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
// SelectUserRoomPrivateKey selects the private key for the given user and room combination // SelectUserRoomPrivateKey selects the private key for the given user and room combination
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
// SelectUserRoomPublicKey selects the public key for the given user and room combination
SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error)
// SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID. // SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID.
// If a senderKey can't be found, it is omitted in the result. // If a senderKey can't be found, it is omitted in the result.
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error)

View file

@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt
} }
@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err return result, err
} }
func (s *userRoomKeysStatements) SelectUserRoomPublicKey(
ctx context.Context,
txn *sql.Tx,
userNID types.EventStateKeyNID,
roomNID types.RoomNID,
) (ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt)
var result ed25519.PublicKey
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return result, err
}
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt)

View file

@ -1551,11 +1551,6 @@ func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, sender
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
// TODO: Use real logic once DB for pseudoIDs is in place
return spec.SenderID(userID.String()), nil
}
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
@ -1714,6 +1709,35 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use
return return
} }
// SelectUserRoomPublicKey queries the users room public key.
// If no key exists, returns no key and no error. Otherwise returns
// the key and a database error, if any.
func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) {
uID := userID.String()
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
if sErr != nil {
return nil, sErr
}
stateKeyNID := stateKeyNIDMap[uID]
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
if rErr != nil {
return rErr
}
if roomInfo == nil {
return nil
}
key, sErr = d.UserRoomKeyTable.SelectUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID)
if !errors.Is(sErr, sql.ErrNoRows) {
return sErr
}
return nil
})
return
}
// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID // SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
result = make(map[spec.RoomID]map[string]string, len(publicKeys)) result = make(map[spec.RoomID]map[string]string, len(publicKeys))

View file

@ -162,12 +162,17 @@ func TestUserRoomKeys(t *testing.T) {
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, key, gotKey) assert.Equal(t, key, gotKey)
pubKey, err := db.SelectUserRoomPublicKey(context.Background(), *userID, *roomID)
assert.NoError(t, err)
assert.Equal(t, key.Public(), pubKey)
// Key doesn't exist, we shouldn't get anything back // Key doesn't exist, we shouldn't get anything back
assert.NoError(t, err)
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, gotKey) assert.Nil(t, gotKey)
pubKey, err = db.SelectUserRoomPublicKey(context.Background(), *userID, *doesNotExist)
assert.NoError(t, err)
assert.Nil(t, pubKey)
queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{
*roomID: {key.Public().(ed25519.PublicKey)}, *roomID: {key.Public().(ed25519.PublicKey)},

View file

@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime //selectUserNIDsStmt *sql.Stmt //prepared at runtime
} }
@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err return result, err
} }
func (s *userRoomKeysStatements) SelectUserRoomPublicKey(
ctx context.Context,
txn *sql.Tx,
userNID types.EventStateKeyNID,
roomNID types.RoomNID,
) (ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt)
var result ed25519.PublicKey
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return result, err
}
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
roomNIDs := make([]any, 0, len(senderKeys)) roomNIDs := make([]any, 0, len(senderKeys))

View file

@ -193,6 +193,8 @@ type UserRoomKeys interface {
InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error)
// SelectUserRoomPrivateKey selects the private key for the given user and room combination // SelectUserRoomPrivateKey selects the private key for the given user and room combination
SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
// SelectUserRoomPublicKey selects the public key for the given user and room combination
SelectUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PublicKey, error)
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
// If a senderKey can't be found, it is omitted in the result. // If a senderKey can't be found, it is omitted in the result.
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)

View file

@ -50,6 +50,7 @@ func TestUserRoomKeysTable(t *testing.T) {
err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
var gotKey, key2, key3 ed25519.PrivateKey var gotKey, key2, key3 ed25519.PrivateKey
var pubKey ed25519.PublicKey
gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, gotKey, key) assert.Equal(t, gotKey, key)
@ -71,6 +72,9 @@ func TestUserRoomKeysTable(t *testing.T) {
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, key, gotKey) assert.Equal(t, key, gotKey)
pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, roomNID)
assert.NoError(t, err)
assert.Equal(t, key.Public(), pubKey)
// try to update an existing key, this should only be done for users NOT on this homeserver // try to update an existing key, this should only be done for users NOT on this homeserver
var gotPubKey ed25519.PublicKey var gotPubKey ed25519.PublicKey
@ -82,6 +86,9 @@ func TestUserRoomKeysTable(t *testing.T) {
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, gotKey) assert.Nil(t, gotKey)
pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, 2)
assert.NoError(t, err)
assert.Nil(t, pubKey)
// query user NIDs for senderKeys // query user NIDs for senderKeys
var gotKeys map[string]types.UserRoomKeyPair var gotKeys map[string]types.UserRoomKeyPair

View file

@ -529,7 +529,7 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return spec.SenderID(userID.String()), nil return spec.SenderID(userID.String()), nil
} }

View file

@ -139,7 +139,11 @@ func ApplyHistoryVisibilityFilter(
if err != nil { if err != nil {
return nil, err return nil, err
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user) roomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user)
if err == nil { if err == nil {
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
eventsFiltered = append(eventsFiltered, ev) eventsFiltered = append(eventsFiltered, ev)

View file

@ -114,7 +114,14 @@ func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Dev
}).WithError(err).Warnf("Failed to add transaction ID to event") }).WithError(err).Warnf("Failed to add transaction ID to event")
continue continue
} }
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID) roomID, err := spec.NewRoomID(in[i].RoomID())
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Room ID is invalid")
continue
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(), "event_id": out[i].EventID(),
@ -515,7 +522,11 @@ func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userI
if err != nil { if err != nil {
return "", "" return "", ""
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser) roomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return "", ""
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *fullUser)
if err != nil { if err != nil {
return "", "" return "", ""
} }

View file

@ -818,7 +818,13 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart)
return nil, err return nil, err
} }
localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID) roomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
logger.WithError(err).Errorf("event roomID is invalid %s", event.RoomID())
return nil, err
}
localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
if err != nil { if err != nil {
logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID())
return nil, err return nil, err