This commit is contained in:
Neil Alexander 2022-11-07 13:22:35 +00:00
parent 76ac6dbdf1
commit c72c58dab9
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
15 changed files with 87 additions and 69 deletions

View file

@ -124,6 +124,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart,
ServerName: domain,
SessionID: device.SessionID,
}
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {

View file

@ -31,13 +31,14 @@ func GetPushers(
userAPI userapi.ClientUserAPI,
) util.JSONResponse {
var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart,
ServerName: domain,
}, &queryRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
@ -59,7 +60,7 @@ func SetPusher(
req *http.Request, device *userapi.Device,
userAPI userapi.ClientUserAPI,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
@ -93,6 +94,7 @@ func SetPusher(
}
body.Localpart = localpart
body.ServerName = domain
body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil {

View file

@ -520,6 +520,7 @@ const (
type QueryPushersRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryPushersResponse struct {
@ -529,11 +530,13 @@ type QueryPushersResponse struct {
type PerformPusherSetRequest struct {
Pusher // Anonymous field because that's how clientapi unmarshals it.
Localpart string
ServerName gomatrixserverlib.ServerName
Append bool `json:"append"`
}
type PerformPusherDeletionRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
SessionID int64
}

View file

@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
if !updated {
return true
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}

View file

@ -508,7 +508,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
return nil
}
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
if err != nil {
return err
}
@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
}
if len(rejected) > 0 {
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
}
}()
@ -693,8 +693,8 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format].
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
if err != nil {
return nil, "", err
}
@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
}
// deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": devices[0].AppID,
@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
for _, d := range devices {
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
log.WithFields(log.Fields{
"localpart": localpart,
}).WithError(err).Errorf("Unable to delete rejected pusher")

View file

@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil
}
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err
}
@ -817,23 +817,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
}
}
if req.Pusher.Kind == "" {
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
}
if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = int64(time.Now().Unix())
}
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
}
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
if err != nil {
return err
}
for i := range pushers {
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
if pushers[i].SessionID != req.SessionID {
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
if err != nil {
return err
}
@ -844,7 +844,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
var err error
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
return err
}

View file

@ -107,9 +107,9 @@ type OpenID interface {
}
type Pusher interface {
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
RemovePushers(ctx context.Context, appid, pushkey string) error
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -50,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
`
const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@ -96,7 +97,8 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
@ -104,7 +106,8 @@ func (s *pushersStatements) InsertPusher(
}
func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
@ -144,9 +147,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
ctx context.Context, txn *sql.Tx, appid, pushkey,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err
}

View file

@ -768,7 +768,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
}
func (d *Database) UpsertPusher(
ctx context.Context, p api.Pusher, localpart string,
ctx context.Context, p api.Pusher,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
data, err := json.Marshal(p.Data)
if err != nil {
@ -787,25 +788,26 @@ func (d *Database) UpsertPusher(
p.ProfileTag,
p.Language,
string(data),
localpart)
localpart,
serverName)
})
}
// GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
return d.Pushers.SelectPushers(ctx, nil, localpart)
return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
}
// RemovePusher deletes one pusher
// Invoked when `append` is true and `kind` is null in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string,
ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
if err == sql.ErrNoRows {
return nil
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -50,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
`
const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@ -96,18 +97,20 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err
}
func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return pushers, err
@ -144,9 +147,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
ctx context.Context, txn *sql.Tx, appid, pushkey,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err
}

View file

@ -410,7 +410,7 @@ func Test_Profile(t *testing.T) {
func Test_Pusher(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -432,11 +432,11 @@ func Test_Pusher(t *testing.T) {
ProfileTag: util.RandomString(8),
Language: util.RandomString(2),
}
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to upsert pusher")
// check it was actually persisted
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, i+1, len(gotPushers))
assert.Equal(t, wantPusher, gotPushers[i])
@ -444,16 +444,16 @@ func Test_Pusher(t *testing.T) {
}
// remove single pusher
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to remove pusher")
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 1, len(gotPushers))
// remove last pusher
err = db.RemovePushers(ctx, appID, pushKeys[1])
assert.NoError(t, err, "unable to remove pusher")
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 0, len(gotPushers))
})

View file

@ -99,9 +99,9 @@ type ThreePIDTable interface {
}
type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
}

View file

@ -90,7 +90,7 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16)
}
_, err := accDB.InsertAccount(ctx, nil, localpart, "localhost", "", appServiceID, accType)
_, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
if err != nil {
t.Fatalf("unable to create account: %v", err)
}

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@ -17,8 +18,8 @@ type PusherDevice struct {
}
// GetPushDevices pushes to the configured devices of a local user.
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart)
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart, serverName)
if err != nil {
return nil, err
}

View file

@ -8,6 +8,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@ -16,8 +17,8 @@ import (
// a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has
// finished.
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
if err != nil {
return err
}