diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 7ae46e4ed..9dc63af84 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -679,6 +679,7 @@ func handleGuestRegistration( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ Localpart: res.Account.Localpart, + ServerName: res.Account.ServerName, DeviceDisplayName: r.InitialDisplayName, AccessToken: token, IPAddr: req.RemoteAddr, diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 31b6a41c3..36221cae9 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) if err != nil { - return err + return fmt.Errorf("s.evaluatePushRules: %w", err) } a, tweaks, err := pushrules.ActionsToTweaks(actions) if err != nil { - return err + return fmt.Errorf("pushrules.ActionsToTweaks: %w", err) } // TODO: support coalescing. if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { @@ -510,7 +510,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks) if err != nil { - return err + return fmt.Errorf("s.localPushDevices: %w", err) } n := &api.Notification{ @@ -528,17 +528,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr TS: gomatrixserverlib.AsTimestamp(time.Now()), } if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { - return err + return fmt.Errorf("s.db.InsertNotification: %w", err) } if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { - return err + return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err) } // We do this after InsertNotification. Thus, this should always return >=1. userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) if err != nil { - return err + return fmt.Errorf("s.db.GetNotificationCount: %w", err) } log.WithFields(log.Fields{ @@ -696,7 +696,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db) if err != nil { - return nil, "", err + return nil, "", fmt.Errorf("util.GetPushDevices: %w", err) } var profileTag string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index ec21d56b8..727cfdbf6 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -214,8 +214,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.Localpart); err != nil { - return err + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { + return fmt.Errorf("a.DB.SetDisplayName: %w", err) } postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index db7125b22..477f22749 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -78,10 +78,10 @@ const selectDeviceByTokenSQL = "" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 3ddd2e06f..707b3bd2b 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -60,7 +60,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON const insertPusherSQL = "" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + - "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" @@ -100,8 +100,7 @@ func (s *pushersStatements) InsertPusher( pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err } @@ -110,7 +109,7 @@ func (s *pushersStatements) SelectPushers( localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} - rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName) if err != nil { return pushers, err diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 1a15ee126..c93f0389a 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -148,7 +148,7 @@ func (d *Database) CreateAccount( var numLocalpart int64 numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName) if err != nil { - return err + return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err) } localpart = strconv.FormatInt(numLocalpart, 10) plaintextPassword = "" @@ -181,15 +181,15 @@ func (d *Database) createAccount( return nil, sqlutil.ErrUserExists } if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil { - return nil, err + return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err) } pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) prbs, err := json.Marshal(pushRuleSets) if err != nil { - return nil, err + return nil, fmt.Errorf("json.Marshal: %w", err) } if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil { - return nil, err + return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err) } return account, nil } diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index 4a52d4497..c9d451dc5 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -60,7 +60,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON const insertPusherSQL = "" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + - "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" @@ -101,7 +101,6 @@ func (s *pushersStatements) InsertPusher( localpart string, serverName gomatrixserverlib.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) return err } diff --git a/userapi/util/devices.go b/userapi/util/devices.go index 9ff96b216..c55fc7999 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -2,6 +2,7 @@ package util import ( "context" + "fmt" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/api" @@ -21,7 +22,7 @@ type PusherDevice struct { func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { - return nil, err + return nil, fmt.Errorf("db.GetPushers: %w", err) } devices := make([]*PusherDevice, 0, len(pushers))