diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index 93345f4b9..717ab8ae3 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -85,6 +86,10 @@ func VerifyUserFromRequest( JSON: jsonerror.UnknownToken("Unknown token"), } } + + // At this point, we should be certain we've got an actual UserID + _, res.Device.ServerName, _ = gomatrixserverlib.SplitID('@', res.Device.UserID) + return res.Device, nil } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 0685c7352..0da75f83e 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -142,8 +142,9 @@ func SetAvatarURL( setRes := &userapi.PerformSetAvatarURLResponse{} if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ - Localpart: localpart, - AvatarURL: r.AvatarURL, + Localpart: localpart, + ServerName: device.ServerName, + AvatarURL: r.AvatarURL, }, setRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() @@ -271,6 +272,7 @@ func SetDisplayName( err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ Localpart: localpart, + ServerName: device.ServerName, DisplayName: r.DisplayName, }, &struct{}{}) if err != nil { diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 9edeed2f7..b62fb229a 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -293,8 +293,9 @@ func getSenderDevice( // set the avatarurl for the user res := &userapi.PerformSetAvatarURLResponse{} if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ - Localpart: cfg.Matrix.ServerNotices.LocalPart, - AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, + Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, + AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, }, res); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") return nil, err diff --git a/userapi/api/api.go b/userapi/api/api.go index df9408acb..2fb98b714 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -414,8 +414,9 @@ type QueryOpenIDTokenResponse struct { // Device represents a client's device (mobile, web, etc) type Device struct { - ID string - UserID string + ID string + UserID string + ServerName gomatrixserverlib.ServerName // The access_token granted to this device. // This uniquely identifies the device from all other devices and clients. AccessToken string @@ -577,6 +578,7 @@ type Notification struct { type PerformSetAvatarURLRequest struct { Localpart, AvatarURL string + ServerName gomatrixserverlib.ServerName } type PerformSetAvatarURLResponse struct{} @@ -603,6 +605,7 @@ type QueryAccountByPasswordResponse struct { type PerformUpdateDisplayNameRequest struct { Localpart, DisplayName string + ServerName gomatrixserverlib.ServerName } type QueryLocalpartForThreePIDRequest struct { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 9d2f63c72..582cbc740 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -103,7 +103,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if err = a.DB.SetDisplayName(ctx, req.Localpart, a.ServerName, req.Localpart); err != nil { return err } @@ -275,7 +275,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if domain != a.ServerName { return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) } - prof, err := a.DB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) if err != nil { if err == sql.ErrNoRows { return nil @@ -770,7 +770,7 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + return a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL) } func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { @@ -803,7 +803,7 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error { - return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + return a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName) } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { diff --git a/userapi/storage/tables/profile_table_test.go b/userapi/storage/tables/profile_table_test.go index d8aebd279..316d34165 100644 --- a/userapi/storage/tables/profile_table_test.go +++ b/userapi/storage/tables/profile_table_test.go @@ -72,7 +72,7 @@ func TestProfileTable(t *testing.T) { t.Fatalf("failed to set avatar url: %v", err) } - // Verify dummy1 on serverName2 is as expected, just to test the function + // Verify dummy1 on serverName2 is as expected dummy1, err := tab.SelectProfileByLocalpart(ctx, "dummy1", serverName2) if err != nil { t.Fatalf("failed to query profile by localpart: %v", err) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 40e37c5d6..1dc78f063 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -79,10 +79,10 @@ func TestQueryProfile(t *testing.T) { if err != nil { t.Fatalf("failed to make account: %s", err) } - if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + if err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { t.Fatalf("failed to set avatar url: %s", err) } - if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + if err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { t.Fatalf("failed to set display name: %s", err) }