diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 69c145a49..3c103fd72 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -15,10 +15,11 @@ package routing import ( + "encoding/json" + "io/ioutil" "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/keyserver/api" @@ -28,51 +29,37 @@ import ( "github.com/matrix-org/util" ) -type crossSigningRequest struct { - api.PerformUploadDeviceKeysRequest - Auth newPasswordAuth `json:"auth"` -} - func UploadCrossSigningDeviceKeys( - req *http.Request, keyserverAPI api.KeyInternalAPI, device *userapi.Device, + req *http.Request, userInteractiveAuth *auth.UserInteractive, + keyserverAPI api.KeyInternalAPI, device *userapi.Device, accountDB accounts.Database, cfg *config.ClientAPI, ) util.JSONResponse { - uploadReq := &crossSigningRequest{} + uploadReq := &api.PerformUploadDeviceKeysRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} - resErr := httputil.UnmarshalJSONRequest(req, &uploadReq) - if resErr != nil { - return *resErr - } - sessionID := uploadReq.Auth.Session - if sessionID == "" { - sessionID = util.RandomString(sessionIDLength) - } - if uploadReq.Auth.Type != authtypes.LoginTypePassword { + ctx := req.Context() + defer req.Body.Close() // nolint:errcheck + bodyBytes, err := ioutil.ReadAll(req.Body) + if err != nil { return util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: newUserInteractiveResponse( - sessionID, - []authtypes.Flow{ - { - Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, - }, - }, - nil, - ), + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), } } - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountDB.GetAccountByPassword, - Config: cfg, + + if _, err := userInteractiveAuth.Verify(ctx, bodyBytes, device); err != nil { + return *err } - if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { - return *authErr + + if err = json.Unmarshal(bodyBytes, &uploadReq); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The request body could not be unmarshalled: " + err.Error()), + } } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID - keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) + keyserverAPI.PerformUploadDeviceKeys(req.Context(), uploadReq, uploadRes) if err := uploadRes.Error; err != nil { switch { diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 94f2583c7..bfcb2a46e 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -1,3 +1,17 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package internal import ( @@ -40,6 +54,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos for _, usage := range key.Usage { if usage == purpose { useful = true + break } } if !useful { @@ -50,6 +65,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos } func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { + var masterKey gomatrixserverlib.Base64Bytes hasMasterKey := false if len(req.MasterKey.Keys) > 0 { @@ -60,6 +76,9 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P return } hasMasterKey = true + for _, keyData := range req.MasterKey.Keys { // iterates once, because sanityCheckKey requires one key + masterKey = keyData + } } if len(req.SelfSigningKey.Keys) > 0 { @@ -82,29 +101,29 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P // If the user hasn't given a new master key, then let's go and get their // existing keys from the database. - var masterKey gomatrixserverlib.Base64Bytes if !hasMasterKey { existingKeys, err := a.DB.CrossSigningKeysForUser(ctx, req.UserID) if err != nil { res.Error = &api.KeyError{ - Err: "User-signing key sanity check failed: " + err.Error(), + Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } return } masterKey, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster] - if !hasMasterKey { - res.Error = &api.KeyError{ - Err: "No master key was found, either in the database or in the request!", - IsMissingParam: true, - } - return - } - } else { - for _, keyData := range req.MasterKey.Keys { // iterates once, see sanityCheckKey - masterKey = keyData - } } + + // If we still don't have a master key at this point then there's nothing else + // we can do - we've checked both the request and the database. + if !hasMasterKey { + res.Error = &api.KeyError{ + Err: "No master key was found, either in the database or in the request!", + IsMissingParam: true, + } + return + } + + // The key ID is basically the key itself. masterKeyID := gomatrixserverlib.KeyID(fmt.Sprintf("ed25519:%s", masterKey.Encode())) // Work out which things we need to verify the signatures for. @@ -116,7 +135,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P if len(req.SelfSigningKey.Keys) > 0 { toVerify[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey } - if len(req.SelfSigningKey.Keys) > 0 { + if len(req.UserSigningKey.Keys) > 0 { toVerify[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey } for purpose, key := range toVerify { @@ -173,7 +192,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req } selfSignatures[userID][keyID] = keyOrDevice } else { - if _, ok := selfSignatures[userID]; !ok { + if _, ok := otherSignatures[userID]; !ok { otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} } otherSignatures[userID][keyID] = keyOrDevice @@ -186,7 +205,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req } selfSignatures[userID][keyID] = keyOrDevice } else { - if _, ok := selfSignatures[userID]; !ok { + if _, ok := otherSignatures[userID]; !ok { otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} } otherSignatures[userID][keyID] = keyOrDevice diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index a4944deb0..0126fa066 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -444,6 +444,9 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( continue } } + if len(devKeys) == 0 { + return + } queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/keyserver/storage/postgres/cross_signing_keys_table.go index 2440977fa..8f3f7054e 100644 --- a/keyserver/storage/postgres/cross_signing_keys_table.go +++ b/keyserver/storage/postgres/cross_signing_keys_table.go @@ -39,7 +39,7 @@ const selectCrossSigningKeysForUserSQL = "" + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" -const insertCrossSigningKeysForUserSQL = "" + +const upsertCrossSigningKeysForUserSQL = "" + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + " VALUES($1, $2, $3)" + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" @@ -47,7 +47,7 @@ const insertCrossSigningKeysForUserSQL = "" + type crossSigningKeysStatements struct { db *sql.DB selectCrossSigningKeysForUserStmt *sql.Stmt - insertCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -58,13 +58,10 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro if err != nil { return nil, err } - if s.selectCrossSigningKeysForUserStmt, err = db.Prepare(selectCrossSigningKeysForUserSQL); err != nil { - return nil, err - } - if s.insertCrossSigningKeysForUserStmt, err = db.Prepare(insertCrossSigningKeysForUserSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + }.Prepare(db) } func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( @@ -87,11 +84,11 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( return } -func (s *crossSigningKeysStatements) InsertCrossSigningKeysForUser( +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, ) error { - if _, err := sqlutil.TxStmt(txn, s.insertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyType, keyData); err != nil { - return fmt.Errorf("s.insertCrossSigningKeysForUserStmt: %w", err) + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyType, keyData); err != nil { + return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) } return nil } diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index 9d69334c6..cbdd5a5d1 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -41,7 +41,7 @@ const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + " WHERE target_user_id = $1 AND target_key_id = $2" -const insertCrossSigningSigsForTargetSQL = "" + +const upsertCrossSigningSigsForTargetSQL = "" + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + " VALUES($1, $2, $3, $4, $5)" + " ON CONFLICT (origin_user_id, target_user_id, target_key_id) DO UPDATE SET (origin_key_id, signature) = ($2, $5)" @@ -49,7 +49,7 @@ const insertCrossSigningSigsForTargetSQL = "" + type crossSigningSigsStatements struct { db *sql.DB selectCrossSigningSigsForTargetStmt *sql.Stmt - insertCrossSigningSigsForTargetStmt *sql.Stmt + upsertCrossSigningSigsForTargetStmt *sql.Stmt } func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { @@ -60,13 +60,10 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro if err != nil { return nil, err } - if s.selectCrossSigningSigsForTargetStmt, err = db.Prepare(selectCrossSigningSigsForTargetSQL); err != nil { - return nil, err - } - if s.insertCrossSigningSigsForTargetStmt, err = db.Prepare(insertCrossSigningSigsForTargetSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, + {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + }.Prepare(db) } func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( @@ -93,14 +90,14 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( return } -func (s *crossSigningSigsStatements) InsertCrossSigningSigsForTarget( +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes, ) error { - if _, err := sqlutil.TxStmt(txn, s.insertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("s.insertCrossSigningSigsForTargetStmt: %w", err) + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) } return nil } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index f191cb623..0d01689c9 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -171,7 +171,7 @@ func (d *Database) CrossSigningSigsForTarget(ctx context.Context, targetUserID s func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap api.CrossSigningKeyMap) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for keyType, keyData := range keyMap { - if err := d.CrossSigningKeysTable.InsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) } } @@ -187,7 +187,7 @@ func (d *Database) StoreCrossSigningSigsForTarget( signature gomatrixserverlib.Base64Bytes, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) } return nil diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/keyserver/storage/sqlite3/cross_signing_keys_table.go index 34e7079eb..03b26d3c7 100644 --- a/keyserver/storage/sqlite3/cross_signing_keys_table.go +++ b/keyserver/storage/sqlite3/cross_signing_keys_table.go @@ -39,14 +39,14 @@ const selectCrossSigningKeysForUserSQL = "" + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + " WHERE user_id = $1" -const insertCrossSigningKeysForUserSQL = "" + +const upsertCrossSigningKeysForUserSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + " VALUES($1, $2, $3)" type crossSigningKeysStatements struct { db *sql.DB selectCrossSigningKeysForUserStmt *sql.Stmt - insertCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt } func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -57,13 +57,10 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) if err != nil { return nil, err } - if s.selectCrossSigningKeysForUserStmt, err = db.Prepare(selectCrossSigningKeysForUserSQL); err != nil { - return nil, err - } - if s.insertCrossSigningKeysForUserStmt, err = db.Prepare(insertCrossSigningKeysForUserSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + }.Prepare(db) } func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( @@ -86,11 +83,11 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( return } -func (s *crossSigningKeysStatements) InsertCrossSigningKeysForUser( +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, ) error { - if _, err := sqlutil.TxStmt(txn, s.insertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyType, keyData); err != nil { - return fmt.Errorf("s.insertCrossSigningKeysForUserStmt: %w", err) + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyType, keyData); err != nil { + return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) } return nil } diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index b839096bf..120921c42 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -41,14 +41,14 @@ const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + " WHERE target_user_id = $1 AND target_key_id = $2" -const insertCrossSigningSigsForTargetSQL = "" + +const upsertCrossSigningSigsForTargetSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + " VALUES($1, $2, $3, $4, $5)" type crossSigningSigsStatements struct { db *sql.DB selectCrossSigningSigsForTargetStmt *sql.Stmt - insertCrossSigningSigsForTargetStmt *sql.Stmt + upsertCrossSigningSigsForTargetStmt *sql.Stmt } func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { @@ -59,13 +59,10 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) if err != nil { return nil, err } - if s.selectCrossSigningSigsForTargetStmt, err = db.Prepare(selectCrossSigningSigsForTargetSQL); err != nil { - return nil, err - } - if s.insertCrossSigningSigsForTargetStmt, err = db.Prepare(insertCrossSigningSigsForTargetSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, + {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + }.Prepare(db) } func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( @@ -92,14 +89,14 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( return } -func (s *crossSigningSigsStatements) InsertCrossSigningSigsForTarget( +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes, ) error { - if _, err := sqlutil.TxStmt(txn, s.insertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("s.insertCrossSigningSigsForTargetStmt: %w", err) + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) } return nil } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 2179215c9..e5e253877 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -55,12 +55,10 @@ type StaleDeviceLists interface { type CrossSigningKeys interface { SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r api.CrossSigningKeyMap, err error) - InsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error } type CrossSigningSigs interface { SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r api.CrossSigningSigMap, err error) - InsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error } - -type CrossSigningStreams interface{}