diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 5bf1b225d..805d05259 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -466,7 +466,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { } err := s.QueryPolicyVersion(req.Context(), &request, &response) if err != nil { - return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + return util.ErrorResponse(err) } return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 6fe3b914b..674f0cdb2 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -140,7 +140,12 @@ func (s *accountsStatements) InsertAccount( createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) - _, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) + var err error + if accountType != api.AccountTypeAppService { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + } else { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + } if err != nil { return nil, err } @@ -214,10 +219,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart( func (s *accountsStatements) SelectPrivacyPolicy( ctx context.Context, txn *sql.Tx, localPart string, ) (policy string, err error) { - stmt := s.selectPrivacyPolicyStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt) + err = stmt.QueryRowContext(ctx, localPart).Scan(&policy) return } @@ -226,11 +229,11 @@ func (s *accountsStatements) SelectPrivacyPolicy( func (s *accountsStatements) BatchSelectPrivacyPolicy( ctx context.Context, txn *sql.Tx, policyVersion string, ) (userIDs []string, err error) { - stmt := s.batchSelectPrivacyPolicyStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt) rows, err := stmt.QueryContext(ctx, policyVersion) + if err != nil { + return nil, err + } defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed") for rows.Next() { var userID string @@ -250,9 +253,7 @@ func (s *accountsStatements) UpdatePolicyVersion( if serverNotice { stmt = s.updatePolicyVersionServerNoticeStmt } - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt = sqlutil.TxStmt(txn, stmt) _, err = stmt.ExecContext(ctx, policyVersion, localpart) return err } @@ -261,31 +262,23 @@ func (s *accountsStatements) UpdatePolicyVersion( func (s *accountsStatements) SelectServerNoticeRoomID( ctx context.Context, txn *sql.Tx, localpart string, ) (roomID string, err error) { - stmt := s.selectServerNoticeRoomStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt) roomIDNull := sql.NullString{} row := stmt.QueryRowContext(ctx, localpart) err = row.Scan(&roomIDNull) - if err != nil { + if err != nil && err != sql.ErrNoRows { return "", err } - if roomIDNull.Valid { - return roomIDNull.String, nil - } - return "", nil + // roomIDNull.String is either the roomID or an empty string + return roomIDNull.String, nil } // UpdateServerNoticeRoomID sets the server notice room ID. func (s *accountsStatements) UpdateServerNoticeRoomID( ctx context.Context, txn *sql.Tx, localpart, roomID string, ) (err error) { - stmt := s.updateServerNoticeRoomStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt) _, err = stmt.ExecContext(ctx, roomID, localpart) return } diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index f9c2a1aea..6390cfc6a 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -142,7 +142,12 @@ func (s *accountsStatements) InsertAccount( createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt - _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion) + var err error + if accountType != api.AccountTypeAppService { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + } else { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + } if err != nil { return nil, err } @@ -220,10 +225,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart( func (s *accountsStatements) SelectPrivacyPolicy( ctx context.Context, txn *sql.Tx, localPart string, ) (policy string, err error) { - stmt := s.selectPrivacyPolicyStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt) err = stmt.QueryRowContext(ctx, localPart).Scan(&policy) return } @@ -232,10 +234,7 @@ func (s *accountsStatements) SelectPrivacyPolicy( func (s *accountsStatements) BatchSelectPrivacyPolicy( ctx context.Context, txn *sql.Tx, policyVersion string, ) (userIDs []string, err error) { - stmt := s.batchSelectPrivacyPolicyStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt) rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion) if err != nil { return nil, err @@ -259,9 +258,7 @@ func (s *accountsStatements) UpdatePolicyVersion( if serverNotice { stmt = s.updatePolicyVersionServerNoticeStmt } - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt = sqlutil.TxStmt(txn, stmt) _, err = stmt.ExecContext(ctx, policyVersion, localpart) return err } @@ -270,31 +267,23 @@ func (s *accountsStatements) UpdatePolicyVersion( func (s *accountsStatements) SelectServerNoticeRoomID( ctx context.Context, txn *sql.Tx, localpart string, ) (roomID string, err error) { - stmt := s.selectServerNoticeRoomStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt) roomIDNull := sql.NullString{} row := stmt.QueryRowContext(ctx, localpart) err = row.Scan(&roomIDNull) - if err != nil { + if err != nil && err != sql.ErrNoRows { return "", err } - if roomIDNull.Valid { - return roomIDNull.String, nil - } - return "", nil + // roomIDNull.String is either the roomID or an empty string + return roomIDNull.String, nil } // UpdateServerNoticeRoomID sets the server notice room ID. func (s *accountsStatements) UpdateServerNoticeRoomID( ctx context.Context, txn *sql.Tx, localpart, roomID string, ) (err error) { - stmt := s.updateServerNoticeRoomStmt - if txn != nil { - stmt = sqlutil.TxStmt(txn, stmt) - } + stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt) _, err = stmt.ExecContext(ctx, roomID, localpart) return }