Update structs to include servername

This commit is contained in:
Till Faelligen 2022-06-09 08:23:20 +02:00
parent 093cbe483a
commit 3fea5d791f
7 changed files with 24 additions and 13 deletions

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -85,6 +86,10 @@ func VerifyUserFromRequest(
JSON: jsonerror.UnknownToken("Unknown token"), JSON: jsonerror.UnknownToken("Unknown token"),
} }
} }
// At this point, we should be certain we've got an actual UserID
_, res.Device.ServerName, _ = gomatrixserverlib.SplitID('@', res.Device.UserID)
return res.Device, nil return res.Device, nil
} }

View file

@ -142,8 +142,9 @@ func SetAvatarURL(
setRes := &userapi.PerformSetAvatarURLResponse{} setRes := &userapi.PerformSetAvatarURLResponse{}
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
Localpart: localpart, Localpart: localpart,
AvatarURL: r.AvatarURL, ServerName: device.ServerName,
AvatarURL: r.AvatarURL,
}, setRes); err != nil { }, setRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -271,6 +272,7 @@ func SetDisplayName(
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
Localpart: localpart, Localpart: localpart,
ServerName: device.ServerName,
DisplayName: r.DisplayName, DisplayName: r.DisplayName,
}, &struct{}{}) }, &struct{}{})
if err != nil { if err != nil {

View file

@ -293,8 +293,9 @@ func getSenderDevice(
// set the avatarurl for the user // set the avatarurl for the user
res := &userapi.PerformSetAvatarURLResponse{} res := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, ServerName: cfg.Matrix.ServerName,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, res); err != nil { }, res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
return nil, err return nil, err

View file

@ -414,8 +414,9 @@ type QueryOpenIDTokenResponse struct {
// Device represents a client's device (mobile, web, etc) // Device represents a client's device (mobile, web, etc)
type Device struct { type Device struct {
ID string ID string
UserID string UserID string
ServerName gomatrixserverlib.ServerName
// The access_token granted to this device. // The access_token granted to this device.
// This uniquely identifies the device from all other devices and clients. // This uniquely identifies the device from all other devices and clients.
AccessToken string AccessToken string
@ -577,6 +578,7 @@ type Notification struct {
type PerformSetAvatarURLRequest struct { type PerformSetAvatarURLRequest struct {
Localpart, AvatarURL string Localpart, AvatarURL string
ServerName gomatrixserverlib.ServerName
} }
type PerformSetAvatarURLResponse struct{} type PerformSetAvatarURLResponse struct{}
@ -603,6 +605,7 @@ type QueryAccountByPasswordResponse struct {
type PerformUpdateDisplayNameRequest struct { type PerformUpdateDisplayNameRequest struct {
Localpart, DisplayName string Localpart, DisplayName string
ServerName gomatrixserverlib.ServerName
} }
type QueryLocalpartForThreePIDRequest struct { type QueryLocalpartForThreePIDRequest struct {

View file

@ -103,7 +103,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil return nil
} }
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if err = a.DB.SetDisplayName(ctx, req.Localpart, a.ServerName, req.Localpart); err != nil {
return err return err
} }
@ -275,7 +275,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName { if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
} }
prof, err := a.DB.GetProfileByLocalpart(ctx, local) prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -770,7 +770,7 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
} }
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) return a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
} }
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
@ -803,7 +803,7 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
} }
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error { func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error {
return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) return a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
} }
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {

View file

@ -72,7 +72,7 @@ func TestProfileTable(t *testing.T) {
t.Fatalf("failed to set avatar url: %v", err) t.Fatalf("failed to set avatar url: %v", err)
} }
// Verify dummy1 on serverName2 is as expected, just to test the function // Verify dummy1 on serverName2 is as expected
dummy1, err := tab.SelectProfileByLocalpart(ctx, "dummy1", serverName2) dummy1, err := tab.SelectProfileByLocalpart(ctx, "dummy1", serverName2)
if err != nil { if err != nil {
t.Fatalf("failed to query profile by localpart: %v", err) t.Fatalf("failed to query profile by localpart: %v", err)

View file

@ -79,10 +79,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { if err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err) t.Fatalf("failed to set avatar url: %s", err)
} }
if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { if err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err) t.Fatalf("failed to set display name: %s", err)
} }