Fix userapi issues

This commit is contained in:
Till Faelligen 2022-05-04 14:33:51 +02:00
parent cd7a7606a1
commit 4d5feb2544
3 changed files with 33 additions and 51 deletions

View file

@ -466,7 +466,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
}
err := s.QueryPolicyVersion(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),

View file

@ -140,7 +140,12 @@ func (s *accountsStatements) InsertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
_, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
var err error
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
}
@ -214,10 +219,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
return
}
@ -226,11 +229,11 @@ func (s *accountsStatements) SelectPrivacyPolicy(
func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
rows, err := stmt.QueryContext(ctx, policyVersion)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
for rows.Next() {
var userID string
@ -250,9 +253,7 @@ func (s *accountsStatements) UpdatePolicyVersion(
if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt
}
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt = sqlutil.TxStmt(txn, stmt)
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err
}
@ -261,31 +262,23 @@ func (s *accountsStatements) UpdatePolicyVersion(
func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
stmt := s.selectServerNoticeRoomStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull)
if err != nil {
if err != nil && err != sql.ErrNoRows {
return "", err
}
if roomIDNull.Valid {
return roomIDNull.String, nil
}
return "", nil
// roomIDNull.String is either the roomID or an empty string
return roomIDNull.String, nil
}
// UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
stmt := s.updateServerNoticeRoomStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
_, err = stmt.ExecContext(ctx, roomID, localpart)
return
}

View file

@ -142,7 +142,12 @@ func (s *accountsStatements) InsertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
_, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
var err error
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
}
@ -220,10 +225,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
return
}
@ -232,10 +234,7 @@ func (s *accountsStatements) SelectPrivacyPolicy(
func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
if err != nil {
return nil, err
@ -259,9 +258,7 @@ func (s *accountsStatements) UpdatePolicyVersion(
if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt
}
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt = sqlutil.TxStmt(txn, stmt)
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err
}
@ -270,31 +267,23 @@ func (s *accountsStatements) UpdatePolicyVersion(
func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
stmt := s.selectServerNoticeRoomStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull)
if err != nil {
if err != nil && err != sql.ErrNoRows {
return "", err
}
if roomIDNull.Valid {
return roomIDNull.String, nil
}
return "", nil
// roomIDNull.String is either the roomID or an empty string
return roomIDNull.String, nil
}
// UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
stmt := s.updateServerNoticeRoomStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
_, err = stmt.ExecContext(ctx, roomID, localpart)
return
}