Fix some bugs

This commit is contained in:
Neil Alexander 2022-11-07 14:15:43 +00:00
parent c72c58dab9
commit 6ea580d9c9
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 22 additions and 22 deletions

View file

@ -679,6 +679,7 @@ func handleGuestRegistration(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
Localpart: res.Account.Localpart, Localpart: res.Account.Localpart,
ServerName: res.Account.ServerName,
DeviceDisplayName: r.InitialDisplayName, DeviceDisplayName: r.InitialDisplayName,
AccessToken: token, AccessToken: token,
IPAddr: req.RemoteAddr, IPAddr: req.RemoteAddr,

View file

@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil { if err != nil {
return err return fmt.Errorf("s.evaluatePushRules: %w", err)
} }
a, tweaks, err := pushrules.ActionsToTweaks(actions) a, tweaks, err := pushrules.ActionsToTweaks(actions)
if err != nil { if err != nil {
return err return fmt.Errorf("pushrules.ActionsToTweaks: %w", err)
} }
// TODO: support coalescing. // TODO: support coalescing.
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
@ -510,7 +510,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks) devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
if err != nil { if err != nil {
return err return fmt.Errorf("s.localPushDevices: %w", err)
} }
n := &api.Notification{ n := &api.Notification{
@ -528,17 +528,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
TS: gomatrixserverlib.AsTimestamp(time.Now()), TS: gomatrixserverlib.AsTimestamp(time.Now()),
} }
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
return err return fmt.Errorf("s.db.InsertNotification: %w", err)
} }
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
return err return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err)
} }
// We do this after InsertNotification. Thus, this should always return >=1. // We do this after InsertNotification. Thus, this should always return >=1.
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
if err != nil { if err != nil {
return err return fmt.Errorf("s.db.GetNotificationCount: %w", err)
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -696,7 +696,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db) pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
if err != nil { if err != nil {
return nil, "", err return nil, "", fmt.Errorf("util.GetPushDevices: %w", err)
} }
var profileTag string var profileTag string

View file

@ -214,8 +214,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil return nil
} }
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.Localpart); err != nil { if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
return err return fmt.Errorf("a.DB.SetDisplayName: %w", err)
} }
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)

View file

@ -78,10 +78,10 @@ const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"

View file

@ -60,7 +60,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON
const insertPusherSQL = "" + const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" + const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
@ -100,8 +100,7 @@ func (s *pushersStatements) InsertPusher(
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName, localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err return err
} }
@ -110,7 +109,7 @@ func (s *pushersStatements) SelectPushers(
localpart string, serverName gomatrixserverlib.ServerName, localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) { ) ([]api.Pusher, error) {
pushers := []api.Pusher{} pushers := []api.Pusher{}
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return pushers, err return pushers, err

View file

@ -148,7 +148,7 @@ func (d *Database) CreateAccount(
var numLocalpart int64 var numLocalpart int64
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName) numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
if err != nil { if err != nil {
return err return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err)
} }
localpart = strconv.FormatInt(numLocalpart, 10) localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = "" plaintextPassword = ""
@ -181,15 +181,15 @@ func (d *Database) createAccount(
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil { if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
return nil, err return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err)
} }
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets) prbs, err := json.Marshal(pushRuleSets)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("json.Marshal: %w", err)
} }
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil { if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err)
} }
return account, nil return account, nil
} }

View file

@ -60,7 +60,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON
const insertPusherSQL = "" + const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" + const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
@ -101,7 +101,6 @@ func (s *pushersStatements) InsertPusher(
localpart string, serverName gomatrixserverlib.ServerName, localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err return err
} }

View file

@ -2,6 +2,7 @@ package util
import ( import (
"context" "context"
"fmt"
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -21,7 +22,7 @@ type PusherDevice struct {
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart, serverName) pushers, err := db.GetPushers(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("db.GetPushers: %w", err)
} }
devices := make([]*PusherDevice, 0, len(pushers)) devices := make([]*PusherDevice, 0, len(pushers))