diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 307fa17e6..4d01a9f42 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -21,16 +21,15 @@ import ( "io" "io/ioutil" "os" + "regexp" "strings" + "github.com/matrix-org/dendrite/setup/base" "github.com/sirupsen/logrus" - "golang.org/x/crypto/bcrypt" "golang.org/x/term" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - userdb "github.com/matrix-org/dendrite/userapi/storage" ) const usage = `Usage: %s @@ -56,13 +55,14 @@ Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - askPass = flag.Bool("ask-pass", false, "Ask for the password to use") - isAdmin = flag.Bool("admin", false, "Create an admin account") - resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Resets the password for the given username") + validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) ) func main() { @@ -78,25 +78,21 @@ func main() { os.Exit(1) } + if !validUsernameRegex.MatchString(*username) { + logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") + os.Exit(1) + } + pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) - accountDB, err := userdb.NewDatabase( - &config.DatabaseOptions{ - ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, - }, - cfg.Global.ServerName, bcrypt.DefaultCost, - cfg.UserAPI.OpenIDTokenLifetimeMS, - api.DefaultLoginTokenLifetime, - ) - if err != nil { - logrus.Fatalln("Failed to connect to the database:", err.Error()) - } + b := base.NewBaseDendrite(cfg, "create-account") + accountDB := b.CreateAccountsDB() accType := api.AccountTypeUser if *isAdmin { accType = api.AccountTypeAdmin } - + var err error if *resetPassword { err = accountDB.SetPassword(context.Background(), *username, pass) if err != nil { diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 22dbc32da..33d716d25 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -203,9 +203,9 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { return err == nil } -func prevID(streamID int) []int { +func prevID(streamID int64) []int64 { if streamID <= 1 { return nil } - return []int{streamID - 1} + return []int64{streamID - 1} } diff --git a/go.mod b/go.mod index 9957a5b6c..ec4646b6b 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index 951c0eba9..d6086d6e7 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902 h1:WHlrE8BYh/hzn1RKwq3YMAlhHivX47jQKAjZFtkJyPE= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220301141554-e124bd7d7902/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa h1:anEGvpRn4v6akmxFWqGDobB6csEt3OWmp67pufccimE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220310124155-116ed5cc1bfa/go.mod h1:+WF5InseAMgi1fTnU46JH39IDpEvLep0fDzx9LDf2Bo= github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5 h1:7viLTiLAA2MtGKY+uf14j6TjfKvvGLAMj/qdm70jJuQ= github.com/matrix-org/pinecone v0.0.0-20220308124038-cfde1f8054c5/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 54eb04f8a..d361c6222 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -70,7 +70,7 @@ type DeviceMessage struct { *DeviceKeys `json:"DeviceKeys,omitempty"` *eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` // A monotonically increasing number which represents device changes for this user. - StreamID int + StreamID int64 DeviceChangeID int64 } @@ -108,7 +108,7 @@ type DeviceKeys struct { } // WithStreamID returns a copy of this device message with the given stream ID -func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { +func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { return DeviceMessage{ DeviceKeys: k, StreamID: streamID, @@ -281,7 +281,7 @@ type QueryDeviceMessagesRequest struct { type QueryDeviceMessagesResponse struct { // The latest stream ID - StreamID int + StreamID int64 Devices []DeviceMessage Error *KeyError } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 974d0196b..4b2b8c187 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -109,7 +109,7 @@ type DeviceListUpdaterDatabase interface { StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index ff939355f..0033a5086 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -46,7 +46,7 @@ func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) erro type mockDeviceListUpdaterDatabase struct { staleUsers map[string]bool - prevIDsExist func(string, []int) bool + prevIDsExist func(string, []int64) bool storedKeys []api.DeviceMessage mu sync.Mutex // protect staleUsers } @@ -101,7 +101,7 @@ func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Contex } // PrevIDsExists returns true if all prev IDs exist for this user. -func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { +func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { return d.prevIDsExist(userID, prevIDs), nil } @@ -139,7 +139,7 @@ func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrix func TestUpdateHavePrevID(t *testing.T) { db := &mockDeviceListUpdaterDatabase{ staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int) bool { + prevIDsExist: func(string, []int64) bool { return true }, } @@ -151,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) { Deleted: false, DeviceID: "FOO", Keys: []byte(`{"key":"value"}`), - PrevID: []int{0}, + PrevID: []int64{0}, StreamID: 1, UserID: "@alice:localhost", } @@ -185,7 +185,7 @@ func TestUpdateHavePrevID(t *testing.T) { func TestUpdateNoPrevID(t *testing.T) { db := &mockDeviceListUpdaterDatabase{ staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int) bool { + prevIDsExist: func(string, []int64) bool { return false }, } @@ -226,7 +226,7 @@ func TestUpdateNoPrevID(t *testing.T) { Deleted: false, DeviceID: "another_device_id", Keys: []byte(`{"key":"value"}`), - PrevID: []int{3}, + PrevID: []int64{3}, StreamID: 4, UserID: remoteUserID, } @@ -268,7 +268,7 @@ func TestDebounce(t *testing.T) { t.Skipf("panic on closed channel on GHA") db := &mockDeviceListUpdaterDatabase{ staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int) bool { + prevIDsExist: func(string, []int64) bool { return true }, } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 0a8bef95d..cc9d3a616 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -205,7 +205,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query } return } - maxStreamID := 0 + maxStreamID := int64(0) for _, m := range msgs { if m.StreamID > maxStreamID { maxStreamID = m.StreamID diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 4dffe695c..16e034776 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -49,7 +49,7 @@ type Database interface { StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 628301cf7..ccd20cbd6 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -121,7 +121,7 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - var streamID int + var streamID int64 var displayName sql.NullString err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { @@ -138,15 +138,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] return nil } -func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { // nullable if there are no results - var nullStream sql.NullInt32 + var nullStream sql.NullInt64 err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) if err == sql.ErrNoRows { err = nil } if nullStream.Valid { - streamID = nullStream.Int32 + streamID = nullStream.Int64 } return } @@ -211,7 +211,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID } dk.UserID = userID var keyJSON string - var streamID int + var streamID int64 var displayName sql.NullString if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index f2790c8df..03215b93b 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -59,12 +59,8 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { - sids := make([]int64, len(prevIDs)) - for i := range prevIDs { - sids[i] = int64(prevIDs[i]) - } - count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids) +func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) if err != nil { return false, err } @@ -85,7 +81,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { // work out the latest stream IDs for each user - userIDToStreamID := make(map[string]int) + userIDToStreamID := make(map[string]int64) for _, k := range keys { userIDToStreamID[k.UserID] = 0 } @@ -95,7 +91,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe if err != nil { return err } - userIDToStreamID[userID] = int(streamID) + userIDToStreamID[userID] = streamID } // set the stream IDs for each key for i := range keys { diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index b461424c6..e77b49b35 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -145,7 +145,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.Type = api.TypeDeviceKeyUpdate dk.UserID = userID var keyJSON string - var streamID int + var streamID int64 var displayName sql.NullString if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err @@ -166,7 +166,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - var streamID int + var streamID int64 var displayName sql.NullString err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { @@ -183,15 +183,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] return nil } -func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { // nullable if there are no results - var nullStream sql.NullInt32 + var nullStream sql.NullInt64 err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) if err == sql.ErrNoRows { err = nil } if nullStream.Valid { - streamID = nullStream.Int32 + streamID = nullStream.Int64 } return } @@ -204,13 +204,13 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID } query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) // nullable if there are no results - var count sql.NullInt32 + var count sql.NullInt64 err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) if err != nil { return 0, err } if count.Valid { - return int(count.Int32), nil + return int(count.Int64), nil } return 0, nil } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 4d5137249..84d2098ad 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -177,7 +177,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { if err != nil { t.Fatalf("DeviceKeysForUser returned error: %s", err) } - wantStreamIDs := map[string]int{ + wantStreamIDs := map[string]int64{ "AAA": 3, "another_device": 2, } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index cd1719598..f840cd1f3 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -37,7 +37,7 @@ type OneTimeKeys interface { type DeviceKeys interface { SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error - SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error diff --git a/setup/base/base.go b/setup/base/base.go index ef3b2be29..692a77d5c 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -100,6 +100,7 @@ const ( // The componentName is used for logging purposes, and should be a friendly name // of the compontent running, e.g. "SyncAPI" func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...BaseDendriteOptions) *BaseDendrite { + platformSanityChecks() useHTTPAPIs := false cacheMetrics := true for _, opt := range options { diff --git a/setup/base/sanity_other.go b/setup/base/sanity_other.go new file mode 100644 index 000000000..48fe6e1f8 --- /dev/null +++ b/setup/base/sanity_other.go @@ -0,0 +1,8 @@ +//go:build !linux && !darwin && !netbsd && !freebsd && !openbsd && !solaris && !dragonfly && !aix +// +build !linux,!darwin,!netbsd,!freebsd,!openbsd,!solaris,!dragonfly,!aix + +package base + +func platformSanityChecks() { + // Nothing to do yet. +} diff --git a/setup/base/sanity_unix.go b/setup/base/sanity_unix.go new file mode 100644 index 000000000..0c1543e0b --- /dev/null +++ b/setup/base/sanity_unix.go @@ -0,0 +1,22 @@ +//go:build linux || darwin || netbsd || freebsd || openbsd || solaris || dragonfly || aix +// +build linux darwin netbsd freebsd openbsd solaris dragonfly aix + +package base + +import ( + "syscall" + + "github.com/sirupsen/logrus" +) + +func platformSanityChecks() { + // Dendrite needs a relatively high number of file descriptors in order + // to function properly, particularly when federating with lots of servers. + // If we run out of file descriptors, we might run into problems accessing + // PostgreSQL amongst other things. Complain at startup if we think the + // number of file descriptors is too low. + var rLimit syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err == nil && rLimit.Cur < 65535 { + logrus.Warnf("IMPORTANT: Process file descriptor limit is currently %d, it is recommended to raise the limit for Dendrite to at least 65535 to avoid issues", rLimit.Cur) + } +} diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index d07fc0c64..2412bc2ae 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -66,7 +66,7 @@ func Context( membershipRes := roomserver.QueryMembershipForUserResponse{} membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { - logrus.WithError(err).Error("unable to fo membership") + logrus.WithError(err).Error("unable to query membership") return jsonerror.InternalServerError() } @@ -158,17 +158,19 @@ func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter } newState := []*gomatrixserverlib.HeaderedEvent{} + membershipEvents := []*gomatrixserverlib.HeaderedEvent{} for _, event := range state { if event.Type() != gomatrixserverlib.MRoomMember { newState = append(newState, event) } else { // did the user send an event? if x[event.Sender()] { - newState = append(newState, event) + membershipEvents = append(membershipEvents, event) } } } - return newState + // Add the membershipEvents to the end of the list, to make Sytest happy + return append(newState, membershipEvents...) } func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index ee649c165..d646a0e41 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -77,6 +77,9 @@ const DeleteRoomStateForRoomSQL = "" + const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" +const selectRoomIDsWithAnyMembershipSQL = "" + + "SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" + const selectCurrentStateSQL = "" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" + @@ -102,14 +105,15 @@ const selectEventsWithEventIDsSQL = "" + " FROM syncapi_current_room_state WHERE event_id = ANY($1)" type currentRoomStateStatements struct { - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - DeleteRoomStateForRoomStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectCurrentStateStmt *sql.Stmt - selectJoinedUsersStmt *sql.Stmt - selectEventsWithEventIDsStmt *sql.Stmt - selectStateEventStmt *sql.Stmt + upsertRoomStateStmt *sql.Stmt + deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt + selectRoomIDsWithMembershipStmt *sql.Stmt + selectRoomIDsWithAnyMembershipStmt *sql.Stmt + selectCurrentStateStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt + selectEventsWithEventIDsStmt *sql.Stmt + selectStateEventStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -130,6 +134,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } + if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { + return nil, err + } if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { return nil, err } @@ -194,6 +201,31 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( return result, rows.Err() } +// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. +func (s *currentRoomStateStatements) SelectRoomIDsWithAnyMembership( + ctx context.Context, + txn *sql.Tx, + userID string, +) (map[string]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithAnyMembershipStmt) + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithAnyMembership: rows.close() failed") + + result := map[string]string{} + for rows.Next() { + var roomID string + var membership string + if err := rows.Scan(&roomID, &membership); err != nil { + return nil, err + } + result[roomID] = membership + } + return result, rows.Err() +} + // SelectCurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index d4cc4f3fb..26689f447 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -119,13 +119,14 @@ const selectStateInRangeSQL = "" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + - " AND ( $3::text[] IS NULL OR sender = ANY($3) )" + - " AND ( $4::text[] IS NULL OR NOT(sender = ANY($4)) )" + - " AND ( $5::text[] IS NULL OR type LIKE ANY($5) )" + - " AND ( $6::text[] IS NULL OR NOT(type LIKE ANY($6)) )" + - " AND ( $7::bool IS NULL OR contains_url = $7 )" + + " AND room_id = ANY($3)" + + " AND ( $4::text[] IS NULL OR sender = ANY($4) )" + + " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + + " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + + " AND ( $8::bool IS NULL OR contains_url = $8 )" + " ORDER BY id ASC" + - " LIMIT $8" + " LIMIT $9" const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -200,12 +201,12 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( - ctx, r.Low(), r.High(), + ctx, r.Low(), r.High(), pq.StringArray(roomIDs), pq.StringArray(stateFilter.Senders), pq.StringArray(stateFilter.NotSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 87d7c6df7..2c166eef7 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -689,10 +689,26 @@ func (d *Database) GetStateDeltas( var succeeded bool defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) + // Look up all memberships for the user. We only care about rooms that a + // user has ever interacted with — joined to, kicked/banned from, left. + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) + if err != nil { + return nil, nil, err + } + + allRoomIDs := make([]string, 0, len(memberships)) + joinedRoomIDs := make([]string, 0, len(memberships)) + for roomID, membership := range memberships { + allRoomIDs = append(allRoomIDs, roomID) + if membership == gomatrixserverlib.Join { + joinedRoomIDs = append(joinedRoomIDs, roomID) + } + } + var deltas []types.StateDelta // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { return nil, nil, err } @@ -760,10 +776,6 @@ func (d *Database) GetStateDeltas( } // Add in currently joined rooms - joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ Membership: gomatrixserverlib.Join, @@ -792,6 +804,22 @@ func (d *Database) GetStateDeltasForFullStateSync( var succeeded bool defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) + // Look up all memberships for the user. We only care about rooms that a + // user has ever interacted with — joined to, kicked/banned from, left. + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) + if err != nil { + return nil, nil, err + } + + allRoomIDs := make([]string, 0, len(memberships)) + joinedRoomIDs := make([]string, 0, len(memberships)) + for roomID, membership := range memberships { + allRoomIDs = append(allRoomIDs, roomID) + if membership == gomatrixserverlib.Join { + joinedRoomIDs = append(joinedRoomIDs, roomID) + } + } + // Use a reasonable initial capacity deltas := make(map[string]types.StateDelta) @@ -816,7 +844,7 @@ func (d *Database) GetStateDeltasForFullStateSync( } // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { return nil, nil, err } @@ -842,11 +870,6 @@ func (d *Database) GetStateDeltasForFullStateSync( } } - joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } - // Add full states for all joined rooms for _, joinedRoomID := range joinedRoomIDs { s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index c91ca6923..587f9d240 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -66,6 +66,9 @@ const DeleteRoomStateForRoomSQL = "" + const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" +const selectRoomIDsWithAnyMembershipSQL = "" + + "SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" + const selectCurrentStateSQL = "" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" @@ -86,14 +89,15 @@ const selectEventsWithEventIDsSQL = "" + " FROM syncapi_current_room_state WHERE event_id IN ($1)" type currentRoomStateStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - DeleteRoomStateForRoomStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectJoinedUsersStmt *sql.Stmt - selectStateEventStmt *sql.Stmt + db *sql.DB + streamIDStatements *streamIDStatements + upsertRoomStateStmt *sql.Stmt + deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt + selectRoomIDsWithMembershipStmt *sql.Stmt + selectRoomIDsWithAnyMembershipStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt + selectStateEventStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { @@ -117,6 +121,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } + if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { + return nil, err + } if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } @@ -175,6 +182,31 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( return result, nil } +// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. +func (s *currentRoomStateStatements) SelectRoomIDsWithAnyMembership( + ctx context.Context, + txn *sql.Tx, + userID string, +) (map[string]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithAnyMembershipStmt) + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithAnyMembership: rows.close() failed") + + result := map[string]string{} + for rows.Next() { + var roomID string + var membership string + if err := rows.Scan(&roomID, &membership); err != nil { + return nil, err + } + result[roomID] = membership + } + return result, rows.Err() +} + // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1b256f91a..b9115262e 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "sort" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" @@ -87,6 +88,7 @@ const selectStateInRangeSQL = "" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2)" + + " AND room_id IN ($3)" + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters @@ -155,13 +157,17 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { + stmtSQL := strings.Replace(selectStateInRangeSQL, "($3)", sqlutil.QueryVariadicOffset(len(roomIDs), 2), 1) + inputParams := []interface{}{ + r.Low(), r.High(), + } + for _, roomID := range roomIDs { + inputParams = append(inputParams, roomID) + } stmt, params, err := prepareWithFilters( - s.db, txn, selectStateInRangeSQL, - []interface{}{ - r.Low(), r.High(), - }, + s.db, txn, stmtSQL, inputParams, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, nil, stateFilter.Limit, FilterOrderAsc, diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1ebb42651..9d1078f5f 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -51,7 +51,7 @@ type Peeks interface { } type Events interface { - SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error) + SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string) (map[string]map[string]bool, map[string]types.StreamEvent, error) SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error) // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. @@ -99,6 +99,8 @@ type CurrentRoomState interface { SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) + // SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. + SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. SelectJoinedUsers(ctx context.Context) (map[string][]string, error) } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 1486ad3c5..1afcbe750 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -271,22 +271,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return } - // Get the event IDs of the stream events we fetched. There's no point in us - var excludingEventIDs []string - if !wantFullState { - excludingEventIDs = make([]string, 0, len(recentStreamEvents)) - for _, event := range recentStreamEvents { - if event.StateKey() != nil { - excludingEventIDs = append(excludingEventIDs, event.EventID()) - } - } - } - - stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs) - if err != nil { - return - } - // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the // user shouldn't see, we check the recent events and remove any prior to the join event of the user // which is equiv to history_visibility: joined @@ -314,6 +298,25 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( limited = false // so clients know not to try to backpaginate } + // Work our way through the timeline events and pick out the event IDs + // of any state events that appear in the timeline. We'll specifically + // exclude them at the next step, so that we don't get duplicate state + // events in both `recentStreamEvents` and `stateEvents`. + var excludingEventIDs []string + if !wantFullState { + excludingEventIDs = make([]string, 0, len(recentStreamEvents)) + for _, event := range recentStreamEvents { + if event.StateKey() != nil { + excludingEventIDs = append(excludingEventIDs, event.EventID()) + } + } + } + + stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs) + if err != nil { + return + } + // Retrieve the backward topology position, i.e. the position of the // oldest event in the room's topology. var prevBatch *types.TopologyToken diff --git a/sytest-blacklist b/sytest-blacklist index cee2406e5..6a3b88390 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -24,7 +24,6 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes -/context/ with lazy_load_members filter works # we don't support groups Remove group category @@ -32,9 +31,10 @@ Remove group role # Flakey AS-ghosted users can use rooms themselves -/context/ with lazy_load_members filter works AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server +Inviting an AS-hosted user asks the AS server +Accesing an AS-hosted room alias asks the AS server # Flakey, need additional investigation Messages that notify from another user increment notification_count diff --git a/sytest-whitelist b/sytest-whitelist index 6c4745b32..c3a4ad92f 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -515,7 +515,6 @@ AS can create a user with inhibit_login AS can set avatar for ghosted users AS can set displayname for ghosted users Ghost user must register before joining room -Inviting an AS-hosted user asks the AS server Can generate a openid access_token that can be exchanged for information about a user Invalid openid access tokens are rejected Requests to userinfo without access tokens are rejected @@ -661,6 +660,5 @@ Multiple calls to /sync should not cause 500 errors Canonical alias can be set Canonical alias can include alt_aliases Can delete canonical alias -Multiple calls to /sync should not cause 500 errors AS can make room aliases -Accesing an AS-hosted room alias asks the AS server +/context/ with lazy_load_members filter works