diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/logout.go b/src/github.com/matrix-org/dendrite/clientapi/readers/logout.go index 62aaee1c3..585527fc9 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/logout.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/logout.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -35,7 +36,11 @@ func Logout( } } - localpart := getLocalPart(device.UserID) + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return httputil.LogThenError(req, err) + } + if err := deviceDB.RemoveDevice(device.ID, localpart); err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go b/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go index dcdb14b44..6a3b4b377 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go @@ -15,14 +15,13 @@ package readers import ( - "fmt" "net/http" - "strings" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -50,7 +49,11 @@ func GetProfile( JSON: jsonerror.NotFound("Bad method"), } } - localpart := getLocalPart(userID) + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return httputil.LogThenError(req, err) + } + profile, err := accountDB.GetProfileByLocalpart(localpart) if err != nil { return httputil.LogThenError(req, err) @@ -69,7 +72,11 @@ func GetProfile( func GetAvatarURL( req *http.Request, accountDB *accounts.Database, userID string, ) util.JSONResponse { - localpart := getLocalPart(userID) + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return httputil.LogThenError(req, err) + } + profile, err := accountDB.GetProfileByLocalpart(localpart) if err != nil { return httputil.LogThenError(req, err) @@ -99,7 +106,10 @@ func SetAvatarURL( } } - localpart := getLocalPart(userID) + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return httputil.LogThenError(req, err) + } oldProfile, err := accountDB.GetProfileByLocalpart(localpart) if err != nil { @@ -124,7 +134,11 @@ func SetAvatarURL( func GetDisplayName( req *http.Request, accountDB *accounts.Database, userID string, ) util.JSONResponse { - localpart := getLocalPart(userID) + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return httputil.LogThenError(req, err) + } + profile, err := accountDB.GetProfileByLocalpart(localpart) if err != nil { return httputil.LogThenError(req, err) @@ -154,7 +168,10 @@ func SetDisplayName( } } - localpart := getLocalPart(userID) + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return httputil.LogThenError(req, err) + } oldProfile, err := accountDB.GetProfileByLocalpart(localpart) if err != nil { @@ -174,14 +191,3 @@ func SetDisplayName( JSON: struct{}{}, } } - -func getLocalPart(userID string) string { - if !strings.HasPrefix(userID, "@") { - panic(fmt.Errorf("Invalid user ID")) - } - - // Get the part before ":" - username := strings.Split(userID, ":")[0] - // Return the part after the "@" - return strings.Split(username, "@")[1] -}