This commit is contained in:
Neil Alexander 2022-11-07 13:22:35 +00:00
parent 76ac6dbdf1
commit c72c58dab9
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
15 changed files with 87 additions and 69 deletions

View file

@ -124,6 +124,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{ pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
SessionID: device.SessionID, SessionID: device.SessionID,
} }
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {

View file

@ -31,13 +31,14 @@ func GetPushers(
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
var queryRes userapi.QueryPushersResponse var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
@ -59,7 +60,7 @@ func SetPusher(
req *http.Request, device *userapi.Device, req *http.Request, device *userapi.Device,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -93,6 +94,7 @@ func SetPusher(
} }
body.Localpart = localpart body.Localpart = localpart
body.ServerName = domain
body.SessionID = device.SessionID body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil { if err != nil {

View file

@ -520,6 +520,7 @@ const (
type QueryPushersRequest struct { type QueryPushersRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
} }
type QueryPushersResponse struct { type QueryPushersResponse struct {
@ -529,11 +530,13 @@ type QueryPushersResponse struct {
type PerformPusherSetRequest struct { type PerformPusherSetRequest struct {
Pusher // Anonymous field because that's how clientapi unmarshals it. Pusher // Anonymous field because that's how clientapi unmarshals it.
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
Append bool `json:"append"` Append bool `json:"append"`
} }
type PerformPusherDeletionRequest struct { type PerformPusherDeletionRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
SessionID int64 SessionID int64
} }

View file

@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
if !updated { if !updated {
return true return true
} }
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false return false
} }

View file

@ -508,7 +508,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
return nil return nil
} }
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
if err != nil { if err != nil {
return err return err
} }
@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
} }
if len(rejected) > 0 { if len(rejected) > 0 {
s.deleteRejectedPushers(ctx, rejected, mem.Localpart) s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
} }
}() }()
@ -693,8 +693,8 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local // localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format]. // user. The map keys are [url][format].
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
} }
// deleteRejectedPushers deletes the pushers associated with the given devices. // deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": localpart, "localpart": localpart,
"app_id0": devices[0].AppID, "app_id0": devices[0].AppID,
@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
}).Warnf("Deleting pushers rejected by the HTTP push gateway") }).Warnf("Deleting pushers rejected by the HTTP push gateway")
for _, d := range devices { for _, d := range devices {
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": localpart, "localpart": localpart,
}).WithError(err).Errorf("Unable to delete rejected pusher") }).WithError(err).Errorf("Unable to delete rejected pusher")

View file

@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil return nil
} }
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil { if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err return err
} }
@ -817,23 +817,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
} }
} }
if req.Pusher.Kind == "" { if req.Pusher.Kind == "" {
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
} }
if req.Pusher.PushKeyTS == 0 { if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = int64(time.Now().Unix()) req.Pusher.PushKeyTS = int64(time.Now().Unix())
} }
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
} }
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
pushers, err := a.DB.GetPushers(ctx, req.Localpart) pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
if err != nil { if err != nil {
return err return err
} }
for i := range pushers { for i := range pushers {
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
if pushers[i].SessionID != req.SessionID { if pushers[i].SessionID != req.SessionID {
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -844,7 +844,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
var err error var err error
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
return err return err
} }

View file

@ -107,9 +107,9 @@ type OpenID interface {
} }
type Pusher interface { type Pusher interface {
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
RemovePushers(ctx context.Context, appid, pushkey string) error RemovePushers(ctx context.Context, appid, pushkey string) error
} }

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -50,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart. -- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app. -- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
` `
const insertPusherSQL = "" + const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
const selectPushersSQL = "" + const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" + const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" + const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@ -96,7 +97,8 @@ type pushersStatements struct {
// Returns nil error success. // Returns nil error success.
func (s *pushersStatements) InsertPusher( func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64, ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id) logrus.Debugf("Created pusher %d", session_id)
@ -104,7 +106,8 @@ func (s *pushersStatements) InsertPusher(
} }
func (s *pushersStatements) SelectPushers( func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) { ) ([]api.Pusher, error) {
pushers := []api.Pusher{} pushers := []api.Pusher{}
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
@ -144,9 +147,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart. // deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher( func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, ctx context.Context, txn *sql.Tx, appid, pushkey,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err return err
} }

View file

@ -768,7 +768,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
} }
func (d *Database) UpsertPusher( func (d *Database) UpsertPusher(
ctx context.Context, p api.Pusher, localpart string, ctx context.Context, p api.Pusher,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
data, err := json.Marshal(p.Data) data, err := json.Marshal(p.Data)
if err != nil { if err != nil {
@ -787,25 +788,26 @@ func (d *Database) UpsertPusher(
p.ProfileTag, p.ProfileTag,
p.Language, p.Language,
string(data), string(data),
localpart) localpart,
serverName)
}) })
} }
// GetPushers returns the pushers matching the given localpart. // GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers( func (d *Database) GetPushers(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) { ) ([]api.Pusher, error) {
return d.Pushers.SelectPushers(ctx, nil, localpart) return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
} }
// RemovePusher deletes one pusher // RemovePusher deletes one pusher
// Invoked when `append` is true and `kind` is null in // Invoked when `append` is true and `kind` is null in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePusher( func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string, ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
} }

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -50,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart. -- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app. -- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
` `
const insertPusherSQL = "" + const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
const selectPushersSQL = "" + const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" + const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" + const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@ -96,18 +97,20 @@ type pushersStatements struct {
// Returns nil error success. // Returns nil error success.
func (s *pushersStatements) InsertPusher( func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64, ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id) logrus.Debugf("Created pusher %d", session_id)
return err return err
} }
func (s *pushersStatements) SelectPushers( func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) { ) ([]api.Pusher, error) {
pushers := []api.Pusher{} pushers := []api.Pusher{}
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return pushers, err return pushers, err
@ -144,9 +147,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart. // deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher( func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, ctx context.Context, txn *sql.Tx, appid, pushkey,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err return err
} }

View file

@ -410,7 +410,7 @@ func Test_Profile(t *testing.T) {
func Test_Pusher(t *testing.T) { func Test_Pusher(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -432,11 +432,11 @@ func Test_Pusher(t *testing.T) {
ProfileTag: util.RandomString(8), ProfileTag: util.RandomString(8),
Language: util.RandomString(2), Language: util.RandomString(2),
} }
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart) err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to upsert pusher") assert.NoError(t, err, "unable to upsert pusher")
// check it was actually persisted // check it was actually persisted
gotPushers, err = db.GetPushers(ctx, aliceLocalpart) gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers") assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, i+1, len(gotPushers)) assert.Equal(t, i+1, len(gotPushers))
assert.Equal(t, wantPusher, gotPushers[i]) assert.Equal(t, wantPusher, gotPushers[i])
@ -444,16 +444,16 @@ func Test_Pusher(t *testing.T) {
} }
// remove single pusher // remove single pusher
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart) err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to remove pusher") assert.NoError(t, err, "unable to remove pusher")
gotPushers, err := db.GetPushers(ctx, aliceLocalpart) gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers") assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 1, len(gotPushers)) assert.Equal(t, 1, len(gotPushers))
// remove last pusher // remove last pusher
err = db.RemovePushers(ctx, appID, pushKeys[1]) err = db.RemovePushers(ctx, appID, pushKeys[1])
assert.NoError(t, err, "unable to remove pusher") assert.NoError(t, err, "unable to remove pusher")
gotPushers, err = db.GetPushers(ctx, aliceLocalpart) gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers") assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 0, len(gotPushers)) assert.Equal(t, 0, len(gotPushers))
}) })

View file

@ -99,9 +99,9 @@ type ThreePIDTable interface {
} }
type PusherTable interface { type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
} }

View file

@ -90,7 +90,7 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16) appServiceID = util.RandomString(16)
} }
_, err := accDB.InsertAccount(ctx, nil, localpart, "localhost", "", appServiceID, accType) _, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
if err != nil { if err != nil {
t.Fatalf("unable to create account: %v", err) t.Fatalf("unable to create account: %v", err)
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -17,8 +18,8 @@ type PusherDevice struct {
} }
// GetPushDevices pushes to the configured devices of a local user. // GetPushDevices pushes to the configured devices of a local user.
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart) pushers, err := db.GetPushers(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -8,6 +8,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -16,8 +17,8 @@ import (
// a single goroutine is started when talking to the Push // a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has // gateways. There is no way to know when the background goroutine has
// finished. // finished.
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
if err != nil { if err != nil {
return err return err
} }