Pushers
This commit is contained in:
parent
76ac6dbdf1
commit
c72c58dab9
|
@ -123,8 +123,9 @@ func Password(
|
|||
}
|
||||
|
||||
pushersReq := &api.PerformPusherDeletionRequest{
|
||||
Localpart: localpart,
|
||||
SessionID: device.SessionID,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
SessionID: device.SessionID,
|
||||
}
|
||||
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
||||
|
|
|
@ -31,13 +31,14 @@ func GetPushers(
|
|||
userAPI userapi.ClientUserAPI,
|
||||
) util.JSONResponse {
|
||||
var queryRes userapi.QueryPushersResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
|
||||
Localpart: localpart,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
|
||||
|
@ -59,7 +60,7 @@ func SetPusher(
|
|||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
) util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
|
@ -93,6 +94,7 @@ func SetPusher(
|
|||
|
||||
}
|
||||
body.Localpart = localpart
|
||||
body.ServerName = domain
|
||||
body.SessionID = device.SessionID
|
||||
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
|
||||
if err != nil {
|
||||
|
|
|
@ -519,7 +519,8 @@ const (
|
|||
)
|
||||
|
||||
type QueryPushersRequest struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryPushersResponse struct {
|
||||
|
@ -527,14 +528,16 @@ type QueryPushersResponse struct {
|
|||
}
|
||||
|
||||
type PerformPusherSetRequest struct {
|
||||
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
||||
Localpart string
|
||||
Append bool `json:"append"`
|
||||
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
Append bool `json:"append"`
|
||||
}
|
||||
|
||||
type PerformPusherDeletionRequest struct {
|
||||
Localpart string
|
||||
SessionID int64
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
SessionID int64
|
||||
}
|
||||
|
||||
// Pusher represents a push notification subscriber
|
||||
|
|
|
@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
|
|||
if !updated {
|
||||
return true
|
||||
}
|
||||
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
|
||||
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
|
||||
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -508,7 +508,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
|||
return nil
|
||||
}
|
||||
|
||||
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
|
||||
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
|||
}
|
||||
|
||||
if len(rejected) > 0 {
|
||||
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
|
||||
s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -693,8 +693,8 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
|
|||
|
||||
// localPushDevices pushes to the configured devices of a local
|
||||
// user. The map keys are [url][format].
|
||||
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
||||
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
|
||||
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
||||
pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
|
|||
}
|
||||
|
||||
// deleteRejectedPushers deletes the pushers associated with the given devices.
|
||||
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
|
||||
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id0": devices[0].AppID,
|
||||
|
@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
|
|||
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
|
||||
|
||||
for _, d := range devices {
|
||||
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
|
||||
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
}).WithError(err).Errorf("Unable to delete rejected pusher")
|
||||
|
|
|
@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
|||
return nil
|
||||
}
|
||||
|
||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
|
||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
|
||||
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
|
||||
return err
|
||||
}
|
||||
|
@ -817,23 +817,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
|
|||
}
|
||||
}
|
||||
if req.Pusher.Kind == "" {
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
|
||||
}
|
||||
if req.Pusher.PushKeyTS == 0 {
|
||||
req.Pusher.PushKeyTS = int64(time.Now().Unix())
|
||||
}
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range pushers {
|
||||
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
|
||||
if pushers[i].SessionID != req.SessionID {
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -844,7 +844,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
|
|||
|
||||
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
|
||||
var err error
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -107,9 +107,9 @@ type OpenID interface {
|
|||
}
|
||||
|
||||
type Pusher interface {
|
||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
|
||||
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
|
||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
|
||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
|
||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
RemovePushers(ctx context.Context, appid, pushkey string) error
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
|
@ -50,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
|||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
|
@ -96,7 +97,8 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
|
@ -104,7 +106,8 @@ func (s *pushersStatements) InsertPusher(
|
|||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
|
||||
|
@ -144,9 +147,10 @@ func (s *pushersStatements) SelectPushers(
|
|||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -768,7 +768,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (d *Database) UpsertPusher(
|
||||
ctx context.Context, p api.Pusher, localpart string,
|
||||
ctx context.Context, p api.Pusher,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
data, err := json.Marshal(p.Data)
|
||||
if err != nil {
|
||||
|
@ -787,25 +788,26 @@ func (d *Database) UpsertPusher(
|
|||
p.ProfileTag,
|
||||
p.Language,
|
||||
string(data),
|
||||
localpart)
|
||||
localpart,
|
||||
serverName)
|
||||
})
|
||||
}
|
||||
|
||||
// GetPushers returns the pushers matching the given localpart.
|
||||
func (d *Database) GetPushers(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart)
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
|
||||
}
|
||||
|
||||
// RemovePusher deletes one pusher
|
||||
// Invoked when `append` is true and `kind` is null in
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
|
||||
func (d *Database) RemovePusher(
|
||||
ctx context.Context, appid, pushkey, localpart string,
|
||||
ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
|
@ -50,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
|||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
|
@ -96,18 +97,20 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
|
||||
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
|
@ -144,9 +147,10 @@ func (s *pushersStatements) SelectPushers(
|
|||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -410,7 +410,7 @@ func Test_Profile(t *testing.T) {
|
|||
|
||||
func Test_Pusher(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
@ -432,11 +432,11 @@ func Test_Pusher(t *testing.T) {
|
|||
ProfileTag: util.RandomString(8),
|
||||
Language: util.RandomString(2),
|
||||
}
|
||||
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
|
||||
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to upsert pusher")
|
||||
|
||||
// check it was actually persisted
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, i+1, len(gotPushers))
|
||||
assert.Equal(t, wantPusher, gotPushers[i])
|
||||
|
@ -444,16 +444,16 @@ func Test_Pusher(t *testing.T) {
|
|||
}
|
||||
|
||||
// remove single pusher
|
||||
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
|
||||
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to remove pusher")
|
||||
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, 1, len(gotPushers))
|
||||
|
||||
// remove last pusher
|
||||
err = db.RemovePushers(ctx, appID, pushKeys[1])
|
||||
assert.NoError(t, err, "unable to remove pusher")
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, 0, len(gotPushers))
|
||||
})
|
||||
|
|
|
@ -99,9 +99,9 @@ type ThreePIDTable interface {
|
|||
}
|
||||
|
||||
type PusherTable interface {
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
|
||||
}
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ func mustMakeAccountAndDevice(
|
|||
appServiceID = util.RandomString(16)
|
||||
}
|
||||
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "localhost", "", appServiceID, accType)
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create account: %v", err)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -17,8 +18,8 @@ type PusherDevice struct {
|
|||
}
|
||||
|
||||
// GetPushDevices pushes to the configured devices of a local user.
|
||||
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||
pushers, err := db.GetPushers(ctx, localpart)
|
||||
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||
pushers, err := db.GetPushers(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -16,8 +17,8 @@ import (
|
|||
// a single goroutine is started when talking to the Push
|
||||
// gateways. There is no way to know when the background goroutine has
|
||||
// finished.
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
|
||||
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
|
||||
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue