Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/userapitests

This commit is contained in:
Till Faelligen 2022-04-27 14:03:01 +02:00
commit 819d275a1f
26 changed files with 266 additions and 68 deletions

View file

@ -75,7 +75,7 @@ func (r *FederationInternalAPI) PerformJoin(
seenSet := make(map[gomatrixserverlib.ServerName]bool) seenSet := make(map[gomatrixserverlib.ServerName]bool)
var uniqueList []gomatrixserverlib.ServerName var uniqueList []gomatrixserverlib.ServerName
for _, srv := range request.ServerNames { for _, srv := range request.ServerNames {
if seenSet[srv] { if seenSet[srv] || srv == r.cfg.Matrix.ServerName {
continue continue
} }
seenSet[srv] = true seenSet[srv] = true

View file

@ -362,6 +362,13 @@ func (a *KeyInternalAPI) processSelfSignatures(
for targetKeyID, signature := range forTargetUserID { for targetKeyID, signature := range forTargetUserID {
switch sig := signature.CrossSigningBody.(type) { switch sig := signature.CrossSigningBody.(type) {
case *gomatrixserverlib.CrossSigningKey: case *gomatrixserverlib.CrossSigningKey:
for keyID := range sig.Keys {
split := strings.SplitN(string(keyID), ":", 2)
if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID {
targetKeyID = keyID // contains the ed25519: or other scheme
break
}
}
for originUserID, forOriginUserID := range sig.Signatures { for originUserID, forOriginUserID := range sig.Signatures {
for originKeyID, originSig := range forOriginUserID { for originKeyID, originSig := range forOriginUserID {
if err := a.DB.StoreCrossSigningSigsForTarget( if err := a.DB.StoreCrossSigningSigsForTarget(

View file

@ -33,8 +33,10 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
target_user_id TEXT NOT NULL, target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL, target_key_id TEXT NOT NULL,
signature TEXT NOT NULL, signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, target_user_id, target_key_id) PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
); );
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
` `
const selectCrossSigningSigsForTargetSQL = "" + const selectCrossSigningSigsForTargetSQL = "" +
@ -44,7 +46,7 @@ const selectCrossSigningSigsForTargetSQL = "" +
const upsertCrossSigningSigsForTargetSQL = "" + const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
" VALUES($1, $2, $3, $4, $5)" + " VALUES($1, $2, $3, $4, $5)" +
" ON CONFLICT (origin_user_id, target_user_id, target_key_id) DO UPDATE SET (origin_key_id, signature) = ($2, $5)" " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
const deleteCrossSigningSigsForTargetSQL = "" + const deleteCrossSigningSigsForTargetSQL = "" +
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"

View file

@ -0,0 +1,52 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) {
m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes)
}
func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -54,6 +54,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadRefactorKeyChanges(m) deltas.LoadRefactorKeyChanges(m)
deltas.LoadFixCrossSigningSignatureIndexes(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }

View file

@ -33,8 +33,10 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
target_user_id TEXT NOT NULL, target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL, target_key_id TEXT NOT NULL,
signature TEXT NOT NULL, signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, target_user_id, target_key_id) PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
); );
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
` `
const selectCrossSigningSigsForTargetSQL = "" + const selectCrossSigningSigsForTargetSQL = "" +

View file

@ -0,0 +1,76 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) {
m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes)
}
func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL,
signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
);
INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
DROP TABLE keyserver_cross_signing_sigs;
ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL,
signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
);
INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
DROP TABLE keyserver_cross_signing_sigs;
ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -53,6 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadRefactorKeyChanges(m) deltas.LoadRefactorKeyChanges(m)
deltas.LoadFixCrossSigningSignatureIndexes(m)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }

View file

@ -88,6 +88,11 @@ func (s *PresenceConsumer) Start() error {
} }
return return
} }
if presence == nil {
presence = &types.PresenceInternal{
UserID: userID,
}
}
deviceRes := api.QueryDevicesResponse{} deviceRes := api.QueryDevicesResponse{}
if err = s.deviceAPI.QueryDevices(s.ctx, &api.QueryDevicesRequest{UserID: userID}, &deviceRes); err != nil { if err = s.deviceAPI.QueryDevices(s.ctx, &api.QueryDevicesRequest{UserID: userID}, &deviceRes); err != nil {
@ -106,7 +111,9 @@ func (s *PresenceConsumer) Start() error {
m.Header.Set(jetstream.UserID, presence.UserID) m.Header.Set(jetstream.UserID, presence.UserID)
m.Header.Set("presence", presence.ClientFields.Presence) m.Header.Set("presence", presence.ClientFields.Presence)
m.Header.Set("status_msg", *presence.ClientFields.StatusMsg) if presence.ClientFields.StatusMsg != nil {
m.Header.Set("status_msg", *presence.ClientFields.StatusMsg)
}
m.Header.Set("last_active_ts", strconv.Itoa(int(presence.LastActiveTS))) m.Header.Set("last_active_ts", strconv.Itoa(int(presence.LastActiveTS)))
if err = msg.RespondMsg(m); err != nil { if err = msg.RespondMsg(m); err != nil {

View file

@ -44,8 +44,8 @@ func GetFilter(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
filter, err := syncDB.GetFilter(req.Context(), localpart, filterID) filter := gomatrixserverlib.DefaultFilter()
if err != nil { if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterID); err != nil {
//TODO better error handling. This error message is *probably* right, //TODO better error handling. This error message is *probably* right,
// but if there are obscure db errors, this will also be returned, // but if there are obscure db errors, this will also be returned,
// even though it is not correct. // even though it is not correct.

View file

@ -81,7 +81,7 @@ type Database interface {
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
// UpsertAccountData keeps track of new or updated account data, by saving the type // UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (empty) // of the new/updated data, and the user ID and room ID the data is related to (empty)
// room ID means the data isn't specific to any room) // room ID means the data isn't specific to any room)
@ -125,10 +125,10 @@ type Database interface {
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
// from position, preventing the send-to-device table from growing indefinitely. // from position, preventing the send-to-device table from growing indefinitely.
CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error) CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
// GetFilter looks up the filter associated with a given local user and filter ID. // GetFilter looks up the filter associated with a given local user and filter ID
// Returns a filter structure. Otherwise returns an error if no such filter exists // and populates the target filter. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database. // or if there was an error talking to the database.
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
// PutFilter puts the passed filter into the database. // PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something // Returns the filterID as a string. Otherwise returns an error if something
// goes wrong. // goes wrong.

View file

@ -57,7 +57,7 @@ const insertAccountDataSQL = "" +
" RETURNING id" " RETURNING id"
const selectAccountDataInRangeSQL = "" + const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" + "SELECT id, room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" + " WHERE user_id = $1 AND id > $2 AND id <= $3" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
@ -103,8 +103,9 @@ func (s *accountDataStatements) SelectAccountDataInRange(
userID string, userID string,
r types.Range, r types.Range,
accountDataEventFilter *gomatrixserverlib.EventFilter, accountDataEventFilter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) { ) (data map[string][]string, pos types.StreamPosition, err error) {
data = make(map[string][]string) data = make(map[string][]string)
pos = r.Low()
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)),
@ -116,11 +117,12 @@ func (s *accountDataStatements) SelectAccountDataInRange(
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
for rows.Next() { var dataType string
var dataType string var roomID string
var roomID string var id types.StreamPosition
if err = rows.Scan(&roomID, &dataType); err != nil { for rows.Next() {
if err = rows.Scan(&id, &roomID, &dataType); err != nil {
return return
} }
@ -129,8 +131,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
} else { } else {
data[roomID] = []string{dataType} data[roomID] = []string{dataType}
} }
if id > pos {
pos = id
}
} }
return data, rows.Err() return data, pos, rows.Err()
} }
func (s *accountDataStatements) SelectMaxAccountDataID( func (s *accountDataStatements) SelectMaxAccountDataID(

View file

@ -73,21 +73,20 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
} }
func (s *filterStatements) SelectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) error {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
var filterData []byte var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil { if err != nil {
return nil, err return err
} }
// Unmarshal JSON into Filter struct // Unmarshal JSON into Filter struct
filter := gomatrixserverlib.DefaultFilter() if err = json.Unmarshal(filterData, &target); err != nil {
if err = json.Unmarshal(filterData, &filter); err != nil { return err
return nil, err
} }
return &filter, nil return nil
} }
func (s *filterStatements) InsertFilter( func (s *filterStatements) InsertFilter(

View file

@ -127,6 +127,9 @@ func (p *presenceStatements) GetPresenceForUser(
} }
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
if err == sql.ErrNoRows {
return nil, nil
}
result.ClientFields.Presence = result.Presence.String() result.ClientFields.Presence = result.Presence.String()
return result, err return result, err
} }

View file

@ -265,7 +265,7 @@ func (d *Database) DeletePeeks(
func (d *Database) GetAccountDataInRange( func (d *Database) GetAccountDataInRange(
ctx context.Context, userID string, r types.Range, ctx context.Context, userID string, r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter, accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, error) { ) (map[string][]string, types.StreamPosition, error) {
return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart) return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart)
} }
@ -513,9 +513,9 @@ func (d *Database) StreamToTopologicalPosition(
} }
func (d *Database) GetFilter( func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) error {
return d.Filter.SelectFilter(ctx, localpart, filterID) return d.Filter.SelectFilter(ctx, target, localpart, filterID)
} }
func (d *Database) PutFilter( func (d *Database) PutFilter(

View file

@ -43,7 +43,7 @@ const insertAccountDataSQL = "" +
// further parameters are added by prepareWithFilters // further parameters are added by prepareWithFilters
const selectAccountDataInRangeSQL = "" + const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" + "SELECT id, room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" " WHERE user_id = $1 AND id > $2 AND id <= $3"
const selectMaxAccountDataIDSQL = "" + const selectMaxAccountDataIDSQL = "" +
@ -95,7 +95,8 @@ func (s *accountDataStatements) SelectAccountDataInRange(
userID string, userID string,
r types.Range, r types.Range,
filter *gomatrixserverlib.EventFilter, filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) { ) (data map[string][]string, pos types.StreamPosition, err error) {
pos = r.Low()
data = make(map[string][]string) data = make(map[string][]string)
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL, s.db, nil, selectAccountDataInRangeSQL,
@ -112,11 +113,12 @@ func (s *accountDataStatements) SelectAccountDataInRange(
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
for rows.Next() { var dataType string
var dataType string var roomID string
var roomID string var id types.StreamPosition
if err = rows.Scan(&roomID, &dataType); err != nil { for rows.Next() {
if err = rows.Scan(&id, &roomID, &dataType); err != nil {
return return
} }
@ -125,9 +127,12 @@ func (s *accountDataStatements) SelectAccountDataInRange(
} else { } else {
data[roomID] = []string{dataType} data[roomID] = []string{dataType}
} }
if id > pos {
pos = id
}
} }
return data, nil return data, pos, nil
} }
func (s *accountDataStatements) SelectMaxAccountDataID( func (s *accountDataStatements) SelectMaxAccountDataID(

View file

@ -77,21 +77,20 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
} }
func (s *filterStatements) SelectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) error {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
var filterData []byte var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil { if err != nil {
return nil, err return err
} }
// Unmarshal JSON into Filter struct // Unmarshal JSON into Filter struct
filter := gomatrixserverlib.DefaultFilter() if err = json.Unmarshal(filterData, &target); err != nil {
if err = json.Unmarshal(filterData, &filter); err != nil { return err
return nil, err
} }
return &filter, nil return nil
} }
func (s *filterStatements) InsertFilter( func (s *filterStatements) InsertFilter(

View file

@ -142,6 +142,9 @@ func (p *presenceStatements) GetPresenceForUser(
} }
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
if err == sql.ErrNoRows {
return nil, nil
}
result.ClientFields.Presence = result.Presence.String() result.ClientFields.Presence = result.Presence.String()
return result, err return result, err
} }

View file

@ -27,7 +27,7 @@ import (
type AccountData interface { type AccountData interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error)
// SelectAccountDataInRange returns a map of room ID to a list of `dataType`. // SelectAccountDataInRange returns a map of room ID to a list of `dataType`.
SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error) SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error)
SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }
@ -157,7 +157,7 @@ type SendToDevice interface {
} }
type Filter interface { type Filter interface {
SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) SelectFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
} }

View file

@ -43,7 +43,7 @@ func (p *AccountDataStreamProvider) IncrementalSync(
To: to, To: to,
} }
dataTypes, err := p.DB.GetAccountDataInRange( dataTypes, pos, err := p.DB.GetAccountDataInRange(
ctx, req.Device.UserID, r, &req.Filter.AccountData, ctx, req.Device.UserID, r, &req.Filter.AccountData,
) )
if err != nil { if err != nil {
@ -53,6 +53,12 @@ func (p *AccountDataStreamProvider) IncrementalSync(
// Iterate over the rooms // Iterate over the rooms
for roomID, dataTypes := range dataTypes { for roomID, dataTypes := range dataTypes {
// For a complete sync, make sure we're only including this room if
// that room was present in the joined rooms.
if from == 0 && roomID != "" && !req.IsRoomPresent(roomID) {
continue
}
// Request the missing data from the database // Request the missing data from the database
for _, dataType := range dataTypes { for _, dataType := range dataTypes {
dataReq := userapi.QueryAccountDataRequest{ dataReq := userapi.QueryAccountDataRequest{
@ -95,5 +101,5 @@ func (p *AccountDataStreamProvider) IncrementalSync(
} }
} }
return to return pos
} }

View file

@ -16,7 +16,6 @@ package streams
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"sync" "sync"
@ -80,11 +79,10 @@ func (p *PresenceStreamProvider) IncrementalSync(
if _, ok := presences[roomUsers[i]]; ok { if _, ok := presences[roomUsers[i]]; ok {
continue continue
} }
// Bear in mind that this might return nil, but at least populating
// a nil means that there's a map entry so we won't repeat this call.
presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i]) presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i])
if err != nil { if err != nil {
if err == sql.ErrNoRows {
continue
}
req.Log.WithError(err).Error("unable to query presence for user") req.Log.WithError(err).Error("unable to query presence for user")
return from return from
} }
@ -93,8 +91,10 @@ func (p *PresenceStreamProvider) IncrementalSync(
} }
lastPos := to lastPos := to
for i := range presences { for _, presence := range presences {
presence := presences[i] if presence == nil {
continue
}
// Ignore users we don't share a room with // Ignore users we don't share a room with
if req.Device.UserID != presence.UserID && !p.notifier.IsSharedUser(req.Device.UserID, presence.UserID) { if req.Device.UserID != presence.UserID && !p.notifier.IsSharedUser(req.Device.UserID, presence.UserID) {
continue continue

View file

@ -62,6 +62,12 @@ func (p *ReceiptStreamProvider) IncrementalSync(
} }
for roomID, receipts := range receiptsByRoom { for roomID, receipts := range receiptsByRoom {
// For a complete sync, make sure we're only including this room if
// that room was present in the joined rooms.
if from == 0 && !req.IsRoomPresent(roomID) {
continue
}
jr := *types.NewJoinResponse() jr := *types.NewJoinResponse()
if existing, ok := req.Response.Rooms.Join[roomID]; ok { if existing, ok := req.Response.Rooms.Join[roomID]; ok {
jr = existing jr = existing

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -47,6 +48,13 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
} }
// TODO: read from stored filters too // TODO: read from stored filters too
filter := gomatrixserverlib.DefaultFilter() filter := gomatrixserverlib.DefaultFilter()
if since.IsEmpty() {
// Send as much account data down for complete syncs as possible
// by default, otherwise clients do weird things while waiting
// for the rest of the data to trickle down.
filter.AccountData.Limit = math.MaxInt32
filter.Room.AccountData.Limit = math.MaxInt32
}
filterQuery := req.URL.Query().Get("filter") filterQuery := req.URL.Query().Get("filter")
if filterQuery != "" { if filterQuery != "" {
if filterQuery[0] == '{' { if filterQuery[0] == '{' {
@ -61,11 +69,9 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows { if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterQuery); err != nil && err != sql.ErrNoRows {
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
return nil, fmt.Errorf("syncDB.GetFilter: %w", err) return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
} else if f != nil {
filter = *f
} }
} }
} }

View file

@ -127,14 +127,23 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
if !ok { // this should almost never happen if !ok { // this should almost never happen
return return
} }
newPresence := types.PresenceInternal{ newPresence := types.PresenceInternal{
ClientFields: types.PresenceClientResponse{
Presence: presenceID.String(),
},
Presence: presenceID, Presence: presenceID,
UserID: userID, UserID: userID,
LastActiveTS: gomatrixserverlib.AsTimestamp(time.Now()), LastActiveTS: gomatrixserverlib.AsTimestamp(time.Now()),
} }
// ensure we also send the current status_msg to federated servers and not nil
dbPresence, err := db.GetPresence(context.Background(), userID)
if err != nil && err != sql.ErrNoRows {
return
}
if dbPresence != nil {
newPresence.ClientFields = dbPresence.ClientFields
}
newPresence.ClientFields.Presence = presenceID.String()
defer rp.presence.Store(userID, newPresence) defer rp.presence.Store(userID, newPresence)
// avoid spamming presence updates when syncing // avoid spamming presence updates when syncing
existingPresence, ok := rp.presence.LoadOrStore(userID, newPresence) existingPresence, ok := rp.presence.LoadOrStore(userID, newPresence)
@ -145,13 +154,7 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
} }
} }
// ensure we also send the current status_msg to federated servers and not nil if err := rp.producer.SendPresence(userID, presenceID, newPresence.ClientFields.StatusMsg); err != nil {
dbPresence, err := db.GetPresence(context.Background(), userID)
if err != nil && err != sql.ErrNoRows {
return
}
if err := rp.producer.SendPresence(userID, presenceID, dbPresence.ClientFields.StatusMsg); err != nil {
logrus.WithError(err).Error("Unable to publish presence message from sync") logrus.WithError(err).Error("Unable to publish presence message from sync")
return return
} }

View file

@ -25,6 +25,23 @@ type SyncRequest struct {
IgnoredUsers IgnoredUsers IgnoredUsers IgnoredUsers
} }
func (r *SyncRequest) IsRoomPresent(roomID string) bool {
membership, ok := r.Rooms[roomID]
if !ok {
return false
}
switch membership {
case gomatrixserverlib.Join:
return true
case gomatrixserverlib.Invite:
return true
case gomatrixserverlib.Peek:
return true
default:
return false
}
}
type StreamProvider interface { type StreamProvider interface {
Setup() Setup()

View file

@ -681,8 +681,6 @@ GET /presence/:user_id/status fetches initial status
PUT /presence/:user_id/status updates my presence PUT /presence/:user_id/status updates my presence
Presence change reports an event to myself Presence change reports an event to myself
Existing members see new members' presence Existing members see new members' presence
#Existing members see new member's presence
Newly joined room includes presence in incremental sync
Get presence for newly joined members in incremental sync Get presence for newly joined members in incremental sync
User sees their own presence in a sync User sees their own presence in a sync
User sees updates to presence from other users in the incremental sync. User sees updates to presence from other users in the incremental sync.