Fix userapi issues

This commit is contained in:
Till Faelligen 2022-05-04 14:33:51 +02:00
parent cd7a7606a1
commit 4d5feb2544
3 changed files with 33 additions and 51 deletions

View file

@ -466,7 +466,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
} }
err := s.QueryPolicyVersion(req.Context(), &request, &response) err := s.QueryPolicyVersion(req.Context(), &request, &response)
if err != nil { if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} return util.ErrorResponse(err)
} }
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),

View file

@ -140,7 +140,12 @@ func (s *accountsStatements) InsertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
_, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) var err error
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -214,10 +219,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
func (s *accountsStatements) SelectPrivacyPolicy( func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string, ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) { ) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy) err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
return return
} }
@ -226,11 +229,11 @@ func (s *accountsStatements) SelectPrivacyPolicy(
func (s *accountsStatements) BatchSelectPrivacyPolicy( func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string, ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) { ) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
rows, err := stmt.QueryContext(ctx, policyVersion) rows, err := stmt.QueryContext(ctx, policyVersion)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
for rows.Next() { for rows.Next() {
var userID string var userID string
@ -250,9 +253,7 @@ func (s *accountsStatements) UpdatePolicyVersion(
if serverNotice { if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt stmt = s.updatePolicyVersionServerNoticeStmt
} }
if txn != nil { stmt = sqlutil.TxStmt(txn, stmt)
stmt = sqlutil.TxStmt(txn, stmt)
}
_, err = stmt.ExecContext(ctx, policyVersion, localpart) _, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err return err
} }
@ -261,31 +262,23 @@ func (s *accountsStatements) UpdatePolicyVersion(
func (s *accountsStatements) SelectServerNoticeRoomID( func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) { ) (roomID string, err error) {
stmt := s.selectServerNoticeRoomStmt stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
roomIDNull := sql.NullString{} roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart) row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull) err = row.Scan(&roomIDNull)
if err != nil { if err != nil && err != sql.ErrNoRows {
return "", err return "", err
} }
if roomIDNull.Valid { // roomIDNull.String is either the roomID or an empty string
return roomIDNull.String, nil return roomIDNull.String, nil
}
return "", nil
} }
// UpdateServerNoticeRoomID sets the server notice room ID. // UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID( func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string, ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) { ) (err error) {
stmt := s.updateServerNoticeRoomStmt stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
_, err = stmt.ExecContext(ctx, roomID, localpart) _, err = stmt.ExecContext(ctx, roomID, localpart)
return return
} }

View file

@ -142,7 +142,12 @@ func (s *accountsStatements) InsertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
_, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion) var err error
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -220,10 +225,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
func (s *accountsStatements) SelectPrivacyPolicy( func (s *accountsStatements) SelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, localPart string, ctx context.Context, txn *sql.Tx, localPart string,
) (policy string, err error) { ) (policy string, err error) {
stmt := s.selectPrivacyPolicyStmt stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy) err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
return return
} }
@ -232,10 +234,7 @@ func (s *accountsStatements) SelectPrivacyPolicy(
func (s *accountsStatements) BatchSelectPrivacyPolicy( func (s *accountsStatements) BatchSelectPrivacyPolicy(
ctx context.Context, txn *sql.Tx, policyVersion string, ctx context.Context, txn *sql.Tx, policyVersion string,
) (userIDs []string, err error) { ) (userIDs []string, err error) {
stmt := s.batchSelectPrivacyPolicyStmt stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion) rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
if err != nil { if err != nil {
return nil, err return nil, err
@ -259,9 +258,7 @@ func (s *accountsStatements) UpdatePolicyVersion(
if serverNotice { if serverNotice {
stmt = s.updatePolicyVersionServerNoticeStmt stmt = s.updatePolicyVersionServerNoticeStmt
} }
if txn != nil { stmt = sqlutil.TxStmt(txn, stmt)
stmt = sqlutil.TxStmt(txn, stmt)
}
_, err = stmt.ExecContext(ctx, policyVersion, localpart) _, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err return err
} }
@ -270,31 +267,23 @@ func (s *accountsStatements) UpdatePolicyVersion(
func (s *accountsStatements) SelectServerNoticeRoomID( func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) { ) (roomID string, err error) {
stmt := s.selectServerNoticeRoomStmt stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
roomIDNull := sql.NullString{} roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart) row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull) err = row.Scan(&roomIDNull)
if err != nil { if err != nil && err != sql.ErrNoRows {
return "", err return "", err
} }
if roomIDNull.Valid { // roomIDNull.String is either the roomID or an empty string
return roomIDNull.String, nil return roomIDNull.String, nil
}
return "", nil
} }
// UpdateServerNoticeRoomID sets the server notice room ID. // UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID( func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string, ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) { ) (err error) {
stmt := s.updateServerNoticeRoomStmt stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
_, err = stmt.ExecContext(ctx, roomID, localpart) _, err = stmt.ExecContext(ctx, roomID, localpart)
return return
} }