Fix userapi issues
This commit is contained in:
parent
cd7a7606a1
commit
4d5feb2544
|
@ -466,7 +466,7 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
}
|
||||
err := s.QueryPolicyVersion(req.Context(), &request, &response)
|
||||
if err != nil {
|
||||
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
|
|
|
@ -140,7 +140,12 @@ func (s *accountsStatements) InsertAccount(
|
|||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
_, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -214,10 +219,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, localPart string,
|
||||
) (policy string, err error) {
|
||||
stmt := s.selectPrivacyPolicyStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
|
||||
|
||||
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
|
||||
return
|
||||
}
|
||||
|
@ -226,11 +229,11 @@ func (s *accountsStatements) SelectPrivacyPolicy(
|
|||
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||
) (userIDs []string, err error) {
|
||||
stmt := s.batchSelectPrivacyPolicyStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
|
||||
rows, err := stmt.QueryContext(ctx, policyVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
|
@ -250,9 +253,7 @@ func (s *accountsStatements) UpdatePolicyVersion(
|
|||
if serverNotice {
|
||||
stmt = s.updatePolicyVersionServerNoticeStmt
|
||||
}
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||
return err
|
||||
}
|
||||
|
@ -261,31 +262,23 @@ func (s *accountsStatements) UpdatePolicyVersion(
|
|||
func (s *accountsStatements) SelectServerNoticeRoomID(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) (roomID string, err error) {
|
||||
stmt := s.selectServerNoticeRoomStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
|
||||
|
||||
roomIDNull := sql.NullString{}
|
||||
row := stmt.QueryRowContext(ctx, localpart)
|
||||
err = row.Scan(&roomIDNull)
|
||||
if err != nil {
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
if roomIDNull.Valid {
|
||||
return roomIDNull.String, nil
|
||||
}
|
||||
return "", nil
|
||||
// roomIDNull.String is either the roomID or an empty string
|
||||
return roomIDNull.String, nil
|
||||
}
|
||||
|
||||
// UpdateServerNoticeRoomID sets the server notice room ID.
|
||||
func (s *accountsStatements) UpdateServerNoticeRoomID(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID string,
|
||||
) (err error) {
|
||||
stmt := s.updateServerNoticeRoomStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomID, localpart)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -142,7 +142,12 @@ func (s *accountsStatements) InsertAccount(
|
|||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
_, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -220,10 +225,7 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
func (s *accountsStatements) SelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, localPart string,
|
||||
) (policy string, err error) {
|
||||
stmt := s.selectPrivacyPolicyStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt)
|
||||
err = stmt.QueryRowContext(ctx, localPart).Scan(&policy)
|
||||
return
|
||||
}
|
||||
|
@ -232,10 +234,7 @@ func (s *accountsStatements) SelectPrivacyPolicy(
|
|||
func (s *accountsStatements) BatchSelectPrivacyPolicy(
|
||||
ctx context.Context, txn *sql.Tx, policyVersion string,
|
||||
) (userIDs []string, err error) {
|
||||
stmt := s.batchSelectPrivacyPolicyStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt)
|
||||
rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -259,9 +258,7 @@ func (s *accountsStatements) UpdatePolicyVersion(
|
|||
if serverNotice {
|
||||
stmt = s.updatePolicyVersionServerNoticeStmt
|
||||
}
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
_, err = stmt.ExecContext(ctx, policyVersion, localpart)
|
||||
return err
|
||||
}
|
||||
|
@ -270,31 +267,23 @@ func (s *accountsStatements) UpdatePolicyVersion(
|
|||
func (s *accountsStatements) SelectServerNoticeRoomID(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
) (roomID string, err error) {
|
||||
stmt := s.selectServerNoticeRoomStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt)
|
||||
|
||||
roomIDNull := sql.NullString{}
|
||||
row := stmt.QueryRowContext(ctx, localpart)
|
||||
err = row.Scan(&roomIDNull)
|
||||
if err != nil {
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
if roomIDNull.Valid {
|
||||
return roomIDNull.String, nil
|
||||
}
|
||||
return "", nil
|
||||
// roomIDNull.String is either the roomID or an empty string
|
||||
return roomIDNull.String, nil
|
||||
}
|
||||
|
||||
// UpdateServerNoticeRoomID sets the server notice room ID.
|
||||
func (s *accountsStatements) UpdateServerNoticeRoomID(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID string,
|
||||
) (err error) {
|
||||
stmt := s.updateServerNoticeRoomStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomID, localpart)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue