Review comments

This commit is contained in:
Neil Alexander 2022-11-11 16:07:50 +00:00
parent 3b76b701e6
commit 2970bfd8ff
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 15 additions and 4 deletions

View file

@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
serverName := cfg.Matrix.ServerName
localpart, ok := vars["localpart"] localpart, ok := vars["localpart"]
if !ok { if !ok {
return util.JSONResponse{ return util.JSONResponse{
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting user localpart."), JSON: jsonerror.MissingArgument("Expecting user localpart."),
} }
} }
if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil {
localpart, serverName = l, s
}
request := struct { request := struct {
Password string `json:"password"` Password string `json:"password"`
}{} }{}
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
} }
updateReq := &userapi.PerformPasswordUpdateRequest{ updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
Password: request.Password, Password: request.Password,
LogoutDevices: true, LogoutDevices: true,
} }

View file

@ -157,7 +157,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}", dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI) return AdminResetPassword(req, cfg, device, userAPI)
}), }),

View file

@ -175,6 +175,9 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
if serverName == "" { if serverName == "" {
serverName = a.Config.Matrix.ServerName serverName = a.Config.Matrix.ServerName
} }
if !a.Config.Matrix.IsLocalServerName(serverName) {
return fmt.Errorf("server name %s is not local", serverName)
}
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
@ -226,6 +229,9 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
} }
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
return fmt.Errorf("server name %s is not local", req.ServerName)
}
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil { if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
return err return err
} }
@ -354,6 +360,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err return err
} }
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("server name %s is not local", domain)
}
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID) dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
res.DeviceExists = false res.DeviceExists = false
@ -362,9 +371,6 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
return err return err
} }
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("server name %s is not local", domain)
}
res.DeviceExists = true res.DeviceExists = true
if dev.UserID != req.RequestingUserID { if dev.UserID != req.RequestingUserID {