From 5713c5715c72953272b7b99fe64feb29bf1fbe6f Mon Sep 17 00:00:00 2001 From: Antonio Cheong Date: Mon, 12 Jun 2023 16:51:26 +0800 Subject: [PATCH 1/3] Update sample link (#3107) Leftover work by f956a8c1d9172f6bbfb9f7515feacd477a0e35f5 Signed-off-by: `Antonio Cheong ` [skip ci] --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0b9788768..34604eff9 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ For a usable federating Dendrite deployment, you will also need: Also recommended are: - A PostgreSQL database engine, which will perform better than SQLite with many users and/or larger rooms -- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) +- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/dendrite-sample.conf) The [Federation Tester](https://federationtester.matrix.org) can be used to verify your deployment. From 832ccc32f6a023665e250eee44b5f678e985d50e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 12 Jun 2023 12:45:42 +0200 Subject: [PATCH 2/3] Add initial support for storing user room keys (#3098) --- roomserver/storage/interface.go | 16 ++ roomserver/storage/postgres/storage.go | 9 ++ .../storage/postgres/user_room_keys_table.go | 132 ++++++++++++++++ roomserver/storage/shared/storage.go | 146 ++++++++++++++++++ roomserver/storage/shared/storage_test.go | 116 +++++++++++++- roomserver/storage/sqlite3/storage.go | 8 + .../storage/sqlite3/user_room_keys_table.go | 146 ++++++++++++++++++ roomserver/storage/tables/interface.go | 14 ++ .../tables/user_room_keys_table_test.go | 115 ++++++++++++++ roomserver/types/types.go | 5 + 10 files changed, 700 insertions(+), 7 deletions(-) create mode 100644 roomserver/storage/postgres/user_room_keys_table.go create mode 100644 roomserver/storage/sqlite3/user_room_keys_table.go create mode 100644 roomserver/storage/tables/user_room_keys_table_test.go diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 523cc361a..2d27d7999 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -16,6 +16,7 @@ package storage import ( "context" + "crypto/ed25519" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -27,6 +28,7 @@ import ( ) type Database interface { + UserRoomKeys // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. @@ -194,8 +196,22 @@ type Database interface { ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) } +type UserRoomKeys interface { + // InsertUserRoomPrivatePublicKey inserts the given private key as well as the public key for it. This should be used + // when creating keys locally. + InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) + // InsertUserRoomPublicKey inserts the given public key, this should be used for users NOT local to this server + InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) + // SelectUserRoomPrivateKey selects the private key for the given user and room combination + SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) + // SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID. + // If a senderKey can't be found, it is omitted in the result. + SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) +} + type RoomDatabase interface { EventDatabase + UserRoomKeys // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 19cde5410..453ff45da 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -131,6 +131,9 @@ func (d *Database) create(db *sql.DB) error { if err := CreateRedactionsTable(db); err != nil { return err } + if err := CreateUserRoomKeysTable(db); err != nil { + return err + } return nil } @@ -192,6 +195,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + userRoomKeys, err := PrepareUserRoomKeysTable(db) + if err != nil { + return err + } + d.Database = shared.Database{ DB: db, EventDatabase: shared.EventDatabase{ @@ -215,6 +223,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room MembershipTable: membership, PublishedTable: published, Purge: purge, + UserRoomKeyTable: userRoomKeys, } return nil } diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go new file mode 100644 index 000000000..22f978bf0 --- /dev/null +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -0,0 +1,132 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "crypto/ed25519" + "database/sql" + "errors" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const userRoomKeysSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( + user_nid INTEGER NOT NULL, + room_nid INTEGER NOT NULL, + pseudo_id_key BYTEA NULL, -- may be null for users not local to the server + pseudo_id_pub_key BYTEA NOT NULL, + CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) +); +` + +const insertUserRoomPrivateKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key + RETURNING (pseudo_id_key) +` + +const insertUserRoomPublicKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_pub_key = $3 + RETURNING (pseudo_id_pub_key) +` + +const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + +const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` + +type userRoomKeysStatements struct { + insertUserRoomPrivateKeyStmt *sql.Stmt + insertUserRoomPublicKeyStmt *sql.Stmt + selectUserRoomKeyStmt *sql.Stmt + selectUserNIDsStmt *sql.Stmt +} + +func CreateUserRoomKeysTable(db *sql.DB) error { + _, err := db.Exec(userRoomKeysSchema) + return err +} + +func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { + s := &userRoomKeysStatements{} + return s, sqlutil.StatementList{ + {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, + {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, + {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserNIDsStmt, selectUserNIDsSQL}, + }.Prepare(db) +} + +func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PrivateKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt) + var result ed25519.PrivateKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + +func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) + + roomNIDs := make([]types.RoomNID, 0, len(senderKeys)) + var senders [][]byte + for roomNID := range senderKeys { + roomNIDs = append(roomNIDs, roomNID) + for _, key := range senderKeys[roomNID] { + senders = append(senders, key) + } + } + rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(senders)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + result := make(map[string]types.UserRoomKeyPair, len(senders)+len(roomNIDs)) + var publicKey []byte + userRoomKeyPair := types.UserRoomKeyPair{} + for rows.Next() { + if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { + return nil, err + } + result[string(publicKey)] = userRoomKeyPair + } + return result, rows.Err() +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index f2f842357..cb12b3f57 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -2,14 +2,18 @@ package shared import ( "context" + "crypto/ed25519" "database/sql" "encoding/json" + "errors" "fmt" "sort" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/caching" @@ -41,6 +45,7 @@ type Database struct { MembershipTable tables.Membership PublishedTable tables.Published Purge tables.Purge + UserRoomKeyTable tables.UserRoomKeys GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } @@ -1609,6 +1614,147 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS }) } +// InsertUserRoomPrivatePublicKey inserts a new user room key for the given user and room. +// Returns the newly inserted private key or an existing private key. If there is +// an error talking to the database, returns that error. +func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + + var iErr error + result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivatePublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) + return iErr + }) + return result, err +} + +// InsertUserRoomPublicKey inserts a new user room key for the given user and room. +// Returns the newly inserted public key or an existing public key. If there is +// an error talking to the database, returns that error. +func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + + var iErr error + result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) + return iErr + }) + return result, err +} + +// SelectUserRoomPrivateKey queries the users room private key. +// If no key exists, returns no key and no error. Otherwise returns +// the key and a database error, if any. +func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return nil + } + + key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) + if !errors.Is(sErr, sql.ErrNoRows) { + return sErr + } + return nil + }) + return +} + +// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID +func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { + result = make(map[spec.RoomID]map[string]string, len(publicKeys)) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + + // map all roomIDs to roomNIDs + query := make(map[types.RoomNID][]ed25519.PublicKey) + rooms := make(map[types.RoomNID]spec.RoomID) + for roomID, keys := range publicKeys { + roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String()) + if !ok { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String()) + continue + } + roomNID = roomInfo.RoomNID + } + + query[roomNID] = keys + rooms[roomNID] = roomID + } + + // get the user room key pars + userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, query) + if sErr != nil { + return sErr + } + nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap)) + for _, nid := range userRoomKeyPairMap { + nids = append(nids, nid.EventStateKeyNID) + } + // get the userIDs + nidMap, seErr := d.EventStateKeys(ctx, nids) + if seErr != nil { + return seErr + } + + // build the result map (roomID -> map publicKey -> userID) + for publicKey, userRoomKeyPair := range userRoomKeyPairMap { + userID := nidMap[userRoomKeyPair.EventStateKeyNID] + roomID := rooms[userRoomKeyPair.RoomNID] + resMap, exists := result[roomID] + if !exists { + resMap = map[string]string{} + } + resMap[publicKey] = userID + result[roomID] = resMap + } + + return nil + }) + return result, err +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 941e84802..4fa451bcc 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -2,11 +2,15 @@ package shared_test import ( "context" + "crypto/ed25519" "testing" "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" + ed255192 "golang.org/x/crypto/ed25519" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" @@ -23,41 +27,62 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat connStr, clearDB := test.PrepareDBConnectionString(t, dbType) dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} - db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) + writer := sqlutil.NewExclusiveWriter() + db, err := sqlutil.Open(dbOpts, writer) assert.NoError(t, err) var membershipTable tables.Membership var stateKeyTable tables.EventStateKeys + var userRoomKeys tables.UserRoomKeys + var roomsTable tables.Rooms switch dbType { case test.DBTypePostgres: + err = postgres.CreateRoomsTable(db) + assert.NoError(t, err) err = postgres.CreateEventStateKeysTable(db) assert.NoError(t, err) err = postgres.CreateMembershipTable(db) assert.NoError(t, err) + err = postgres.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + roomsTable, err = postgres.PrepareRoomsTable(db) + assert.NoError(t, err) membershipTable, err = postgres.PrepareMembershipTable(db) assert.NoError(t, err) stateKeyTable, err = postgres.PrepareEventStateKeysTable(db) + assert.NoError(t, err) + userRoomKeys, err = postgres.PrepareUserRoomKeysTable(db) case test.DBTypeSQLite: + err = sqlite3.CreateRoomsTable(db) + assert.NoError(t, err) err = sqlite3.CreateEventStateKeysTable(db) assert.NoError(t, err) err = sqlite3.CreateMembershipTable(db) assert.NoError(t, err) + err = sqlite3.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + roomsTable, err = sqlite3.PrepareRoomsTable(db) + assert.NoError(t, err) membershipTable, err = sqlite3.PrepareMembershipTable(db) assert.NoError(t, err) stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db) + assert.NoError(t, err) + userRoomKeys, err = sqlite3.PrepareUserRoomKeysTable(db) } assert.NoError(t, err) cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) - evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache} + evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache, Writer: writer} return &shared.Database{ - DB: db, - EventDatabase: evDb, - MembershipTable: membershipTable, - Writer: sqlutil.NewExclusiveWriter(), - Cache: cache, + DB: db, + EventDatabase: evDb, + MembershipTable: membershipTable, + UserRoomKeyTable: userRoomKeys, + RoomsTable: roomsTable, + Writer: writer, + Cache: cache, }, func() { clearDB() err = db.Close() @@ -97,3 +122,80 @@ func Test_GetLeftUsers(t *testing.T) { assert.ElementsMatch(t, expectedUserIDs, leftUsers) }) } + +func TestUserRoomKeys(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + userID, err := spec.NewUserID(alice.ID, true) + assert.NoError(t, err) + roomID, err := spec.NewRoomID(room.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + // create a room NID so we can query the room + _, err = db.RoomsTable.InsertRoomNID(ctx, nil, roomID.String(), gomatrixserverlib.RoomVersionV10) + assert.NoError(t, err) + doesNotExist, err := spec.NewRoomID("!doesnotexist:localhost") + assert.NoError(t, err) + _, err = db.RoomsTable.InsertRoomNID(ctx, nil, doesNotExist.String(), gomatrixserverlib.RoomVersionV10) + assert.NoError(t, err) + + _, key, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + + gotKey, err := db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // again, this shouldn't result in an error, but return the existing key + _, key2, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotKey, err = db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) + assert.NoError(t, err) + assert.Equal(t, key, gotKey) + + // Key doesn't exist, we shouldn't get anything back + assert.NoError(t, err) + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) + assert.NoError(t, err) + assert.Nil(t, gotKey) + + queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ + *roomID: {key.Public().(ed25519.PublicKey)}, + } + + userIDs, err := db.SelectUserIDsForPublicKeys(ctx, queryUserIDs) + assert.NoError(t, err) + wantKeys := map[spec.RoomID]map[string]string{ + *roomID: { + string(key.Public().(ed25519.PublicKey)): userID.String(), + }, + } + assert.Equal(t, wantKeys, userIDs) + + // insert key that came in over federation + var gotPublicKey, key4 ed255192.PublicKey + key4, _, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *doesNotExist, key4) + assert.NoError(t, err) + assert.Equal(t, key4, gotPublicKey) + + // test invalid room + reallyDoesNotExist, err := spec.NewRoomID("!reallydoesnotexist:localhost") + assert.NoError(t, err) + _, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4) + assert.Error(t, err) + _, err = db.InsertUserRoomPrivatePublicKey(context.Background(), *userID, *reallyDoesNotExist, key) + assert.Error(t, err) + }) +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 6ab427a84..ef51a5b08 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -138,6 +138,9 @@ func (d *Database) create(db *sql.DB) error { if err := CreateRedactionsTable(db); err != nil { return err } + if err := CreateUserRoomKeysTable(db); err != nil { + return err + } return nil } @@ -199,6 +202,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + userRoomKeys, err := PrepareUserRoomKeysTable(db) + if err != nil { + return err + } d.Database = shared.Database{ DB: db, @@ -224,6 +231,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room PublishedTable: published, GetRoomUpdaterFn: d.GetRoomUpdater, Purge: purge, + UserRoomKeyTable: userRoomKeys, } return nil } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go new file mode 100644 index 000000000..8af57ea0e --- /dev/null +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -0,0 +1,146 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "crypto/ed25519" + "database/sql" + "errors" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const userRoomKeysSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( + user_nid INTEGER NOT NULL, + room_nid INTEGER NOT NULL, + pseudo_id_key TEXT NULL, -- may be null for users not local to the server + pseudo_id_pub_key TEXT NOT NULL, + CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) +); +` + +const insertUserRoomKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) + ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key + RETURNING (pseudo_id_key) +` + +const insertUserRoomPublicKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) + ON CONFLICT DO UPDATE SET pseudo_id_pub_key = $3 + RETURNING (pseudo_id_pub_key) +` + +const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + +const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` + +type userRoomKeysStatements struct { + insertUserRoomPrivateKeyStmt *sql.Stmt + insertUserRoomPublicKeyStmt *sql.Stmt + selectUserRoomKeyStmt *sql.Stmt + //selectUserNIDsStmt *sql.Stmt //prepared at runtime +} + +func CreateUserRoomKeysTable(db *sql.DB) error { + _, err := db.Exec(userRoomKeysSchema) + return err +} + +func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { + s := &userRoomKeysStatements{} + return s, sqlutil.StatementList{ + {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, + {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, + {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime + }.Prepare(db) +} + +func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PrivateKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt) + var result ed25519.PrivateKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + +func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { + + roomNIDs := make([]any, 0, len(senderKeys)) + var senders []any + for roomNID := range senderKeys { + roomNIDs = append(roomNIDs, roomNID) + + for _, key := range senderKeys[roomNID] { + senders = append(senders, []byte(key)) + } + } + + selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1) + selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs + + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, err + } + + params := append(roomNIDs, senders...) + + stmt := sqlutil.TxStmt(txn, selectStmt) + defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + result := make(map[string]types.UserRoomKeyPair, len(params)) + var publicKey []byte + userRoomKeyPair := types.UserRoomKeyPair{} + for rows.Next() { + if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { + return nil, err + } + result[string(publicKey)] = userRoomKeyPair + } + return result, rows.Err() +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 333483b32..cd0e51686 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -2,6 +2,7 @@ package tables import ( "context" + "crypto/ed25519" "database/sql" "errors" @@ -184,6 +185,19 @@ type Purge interface { ) error } +type UserRoomKeys interface { + // InsertUserRoomPrivatePublicKey inserts the given private key as well as the public key for it. This should be used + // when creating keys locally. + InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) + // InsertUserRoomPublicKey inserts the given public key, this should be used for users NOT local to this server + InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) + // SelectUserRoomPrivateKey selects the private key for the given user and room combination + SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) + // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. + // If a senderKey can't be found, it is omitted in the result. + BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) +} + // StrippedEvent represents a stripped event for returning extracted content values. type StrippedEvent struct { RoomID string diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go new file mode 100644 index 000000000..284309481 --- /dev/null +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -0,0 +1,115 @@ +package tables_test + +import ( + "context" + "crypto/ed25519" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" + ed255192 "golang.org/x/crypto/ed25519" +) + +func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, db *sql.DB, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareUserRoomKeysTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareUserRoomKeysTable(db) + } + assert.NoError(t, err) + + return tab, db, close +} + +func TestUserRoomKeysTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := mustCreateUserRoomKeysTable(t, dbType) + defer close() + userNID := types.EventStateKeyNID(1) + roomNID := types.RoomNID(1) + _, key, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + + err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + var gotKey, key2, key3 ed25519.PrivateKey + gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // again, this shouldn't result in an error, but return the existing key + _, key2, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // add another user + _, key3, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + userNID2 := types.EventStateKeyNID(2) + _, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID2, roomNID, key3) + assert.NoError(t, err) + + gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) + assert.NoError(t, err) + assert.Equal(t, key, gotKey) + + // try to update an existing key, this should only be done for users NOT on this homeserver + var gotPubKey ed25519.PublicKey + gotPubKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, roomNID, key2.Public().(ed25519.PublicKey)) + assert.NoError(t, err) + assert.Equal(t, key2.Public(), gotPubKey) + + // Key doesn't exist + gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, gotKey) + + // query user NIDs for senderKeys + var gotKeys map[string]types.UserRoomKeyPair + query := map[types.RoomNID][]ed25519.PublicKey{ + roomNID: {key2.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, + types.RoomNID(2): {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist + } + gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, query) + assert.NoError(t, err) + assert.NotNil(t, gotKeys) + + wantKeys := map[string]types.UserRoomKeyPair{ + string(key2.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID}, + string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2}, + } + assert.Equal(t, wantKeys, gotKeys) + + // insert key that came in over federation + var gotPublicKey, key4 ed255192.PublicKey + key4, _, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotPublicKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, 2, key4) + assert.NoError(t, err) + assert.Equal(t, key4, gotPublicKey) + + return nil + }) + assert.NoError(t, err) + + }) +} diff --git a/roomserver/types/types.go b/roomserver/types/types.go index f57978ad5..45a3e25fc 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -44,6 +44,11 @@ type EventMetadata struct { RoomNID RoomNID } +type UserRoomKeyPair struct { + RoomNID RoomNID + EventStateKeyNID EventStateKeyNID +} + // StateSnapshotNID is a numeric ID for the state at an event. type StateSnapshotNID int64 From 77d9e4e93dd01f6baa82bd6236850c1007346cac Mon Sep 17 00:00:00 2001 From: devonh Date: Mon, 12 Jun 2023 11:19:25 +0000 Subject: [PATCH 3/3] Cleanup remaining statekey usage for senderIDs (#3106) --- clientapi/routing/account_data.go | 10 +- clientapi/routing/aliases.go | 9 +- clientapi/routing/createroom.go | 1 + clientapi/routing/directory.go | 33 ++-- clientapi/routing/leaveroom.go | 10 +- clientapi/routing/membership.go | 147 ++++++++++++------ clientapi/routing/redaction.go | 34 ++-- clientapi/routing/sendtyping.go | 10 +- clientapi/routing/server_notices.go | 13 +- clientapi/routing/state.go | 53 +++++-- clientapi/routing/upgrade_room.go | 10 +- federationapi/routing/eventauth.go | 2 +- federationapi/routing/events.go | 12 +- federationapi/routing/state.go | 2 +- go.mod | 2 +- go.sum | 4 +- roomserver/api/api.go | 21 +-- roomserver/api/output.go | 6 +- roomserver/api/perform.go | 4 +- roomserver/api/query.go | 20 +-- roomserver/auth/auth.go | 14 +- roomserver/auth/auth_test.go | 12 +- roomserver/internal/helpers/helpers.go | 37 +++-- roomserver/internal/helpers/helpers_test.go | 5 +- roomserver/internal/input/input_events.go | 12 +- roomserver/internal/input/input_membership.go | 21 ++- roomserver/internal/perform/perform_admin.go | 6 +- .../internal/perform/perform_backfill.go | 2 +- .../internal/perform/perform_create_room.go | 15 +- roomserver/internal/perform/perform_invite.go | 8 +- roomserver/internal/perform/perform_join.go | 35 ++--- roomserver/internal/perform/perform_leave.go | 77 ++++----- .../internal/perform/perform_upgrade.go | 116 +++++--------- roomserver/internal/query/query.go | 70 +++++---- roomserver/roomserver_test.go | 19 +-- roomserver/storage/interface.go | 2 +- roomserver/storage/shared/storage.go | 7 +- setup/mscs/msc2836/msc2836.go | 11 +- setup/mscs/msc2836/msc2836_test.go | 6 +- syncapi/consumers/roomserver.go | 29 +++- syncapi/internal/history_visibility.go | 14 +- syncapi/internal/keychange.go | 16 +- syncapi/internal/keychange_test.go | 4 + syncapi/notifier/notifier.go | 45 +++--- syncapi/notifier/notifier_test.go | 22 ++- syncapi/routing/context.go | 18 ++- syncapi/routing/getevent.go | 11 +- syncapi/routing/memberships.go | 13 +- syncapi/routing/messages.go | 6 +- syncapi/routing/relations.go | 11 +- syncapi/routing/search.go | 11 +- syncapi/storage/shared/storage_consumer.go | 16 +- syncapi/storage/shared/storage_sync.go | 4 +- syncapi/streams/stream_invite.go | 11 +- syncapi/streams/stream_pdu.go | 12 +- syncapi/syncapi.go | 2 +- syncapi/synctypes/clientevent.go | 35 ++++- syncapi/synctypes/clientevent_test.go | 6 +- syncapi/types/types.go | 4 +- syncapi/types/types_test.go | 8 +- userapi/consumers/roomserver.go | 36 ++++- userapi/util/notify_test.go | 3 +- 62 files changed, 760 insertions(+), 455 deletions(-) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 7eacf9cc9..81afc3b13 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -145,8 +145,16 @@ func SaveReadMarker( userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, ) util.JSONResponse { + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("userID for this device is invalid"), + } + } + // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index f6603be8b..2d6b72d3e 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -55,9 +55,16 @@ func GetAliases( visibility = content.HistoryVisibility } if visibility != spec.WorldReadable { + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } queryReq := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *deviceUserID, } var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 799fc7976..320f236cb 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -224,6 +224,7 @@ func createRoom( PrivateKey: privateKey, EventTime: evTime, } + roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req) if createRes != nil { return *createRes diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 034296f45..f01e24eca 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -314,7 +314,22 @@ func SetVisibility( req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID) + deviceUserID, err := spec.NewUserID(dev.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("userID for this device is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("failed to find senderID for this user"), + } + } + + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } @@ -327,7 +342,7 @@ func SetVisibility( }}, } var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse - err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) + err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) if err != nil || len(queryEventsRes.StateEvents) == 0 { util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") return util.JSONResponse{ @@ -338,20 +353,6 @@ func SetVisibility( // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) - fullUserID, err := spec.NewUserID(dev.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to change visibility"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to change visibility"), - } - } if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index fbf148264..7e8c066eb 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -29,10 +29,18 @@ func LeaveRoomByID( rsAPI roomserverAPI.ClientRoomserverAPI, roomID string, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("device userID is invalid"), + } + } + // Prepare to ask the roomserver to perform the room join. leaveReq := roomserverAPI.PerformLeaveRequest{ RoomID: roomID, - UserID: device.UserID, + Leaver: *userID, } leaveRes := roomserverAPI.PerformLeaveResponse{} diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 78829bec9..03e85edbf 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -57,7 +57,22 @@ func SendBan( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -66,20 +81,6 @@ func SendBan( if errRes != nil { return *errRes } - fullUserID, err := spec.NewUserID(device.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), - } - } allowedToBan := pl.UserLevel(senderID) >= pl.Ban if !allowedToBan { return util.JSONResponse{ @@ -147,7 +148,22 @@ func SendKick( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -156,20 +172,6 @@ func SendKick( if errRes != nil { return *errRes } - fullUserID, err := spec.NewUserID(device.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), - } - } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), - } - } allowedToKick := pl.UserLevel(senderID) >= pl.Kick if !allowedToKick { return util.JSONResponse{ @@ -178,10 +180,17 @@ func SendKick( } } + bodyUserID, err := spec.NewUserID(body.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("body userID is invalid"), + } + } var queryRes roomserverAPI.QueryMembershipForUserResponse err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: body.UserID, + UserID: *bodyUserID, }, &queryRes) if err != nil { return util.ErrorResponse(err) @@ -213,15 +222,30 @@ func SendUnban( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } + bodyUserID, err := spec.NewUserID(body.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("body userID is invalid"), + } + } var queryRes roomserverAPI.QueryMembershipForUserResponse - err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ + err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: body.UserID, + UserID: *bodyUserID, }, &queryRes) if err != nil { return util.ErrorResponse(err) @@ -272,7 +296,15 @@ func SendInvite( } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -340,17 +372,18 @@ func sendInvite( func buildMembershipEventDirect( ctx context.Context, - targetUserID, reason string, userDisplayName, userAvatarURL string, - sender string, senderDomain spec.ServerName, + targetSenderID spec.SenderID, reason string, userDisplayName, userAvatarURL string, + sender spec.SenderID, senderDomain spec.ServerName, membership, roomID string, isDirect bool, keyID gomatrixserverlib.KeyID, privateKey ed25519.PrivateKey, evTime time.Time, rsAPI roomserverAPI.ClientRoomserverAPI, ) (*types.HeaderedEvent, error) { + targetSenderString := string(targetSenderID) proto := gomatrixserverlib.ProtoEvent{ - SenderID: sender, + SenderID: string(sender), RoomID: roomID, Type: "m.room.member", - StateKey: &targetUserID, + StateKey: &targetSenderString, } content := gomatrixserverlib.MemberContent{ @@ -391,8 +424,25 @@ func buildMembershipEvent( return nil, err } - return buildMembershipEventDirect(ctx, targetUserID, reason, profile.DisplayName, profile.AvatarURL, - device.UserID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + if err != nil { + return nil, err + } + + targetID, err := spec.NewUserID(targetUserID, true) + if err != nil { + return nil, err + } + targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID) + if err != nil { + return nil, err + } + return buildMembershipEventDirect(ctx, targetSenderID, reason, profile.DisplayName, profile.AvatarURL, + senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) } // loadProfile lookups the profile of a given user from the database and returns @@ -490,7 +540,7 @@ func checkAndProcessThreepid( return } -func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse { +func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID spec.UserID, roomID string) *util.JSONResponse { var membershipRes roomserverAPI.QueryMembershipForUserResponse err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, @@ -518,12 +568,21 @@ func SendForget( ) util.JSONResponse { ctx := req.Context() logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) + + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + var membershipRes roomserverAPI.QueryMembershipForUserResponse membershipReq := roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *deviceUserID, } - err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) if err != nil { logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") return util.JSONResponse{ diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 22474fc08..da48e84de 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -47,7 +47,22 @@ func SendRedaction( txnID *string, txnCache *transactions.Cache, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, userIDErr := spec.NewUserID(device.UserID, true) + if userIDErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + if queryErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } @@ -73,25 +88,10 @@ func SendRedaction( } } - fullUserID, userIDErr := spec.NewUserID(device.UserID, true) - if userIDErr != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to redact"), - } - } - senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID) - if queryErr != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("userID doesn't have power level to redact"), - } - } - // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey + allowedToRedact := ev.SenderID() == senderID if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index c5b29297a..979bced3b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -43,8 +43,16 @@ func SendTyping( } } + deviceUserID, err := spec.NewUserID(userID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 06714ed1f..7006ced46 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -52,6 +52,7 @@ type sendServerNoticeRequest struct { StateKey string `json:"state_key,omitempty"` } +// nolint:gocyclo // SendServerNotice sends a message to a specific user. It can only be invoked by an admin. func SendServerNotice( req *http.Request, @@ -187,9 +188,17 @@ func SendServerNotice( } } else { // we've found a room in common, check the membership + deviceUserID, err := spec.NewUserID(r.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + roomID = commonRooms[0] membershipRes := api.QueryMembershipForUserResponse{} - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") return util.JSONResponse{ @@ -234,7 +243,7 @@ func SendServerNotice( ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{ - &types.HeaderedEvent{PDU: e}, + {PDU: e}, }, device.UserDomain(), cfgClient.Matrix.ServerName, diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 13f308998..e3a209b6e 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -99,9 +99,17 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a if !worldReadable { // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UserID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Device UserID is invalid"), + } + } + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") @@ -140,14 +148,11 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // use the result of the previous QueryLatestEventsAndState response // to find the state event, if provided. for _, ev := range stateRes.StateEvents { - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) - if err == nil && userID != nil { - sender = *userID - } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), + synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, ev), ) } } else { @@ -172,9 +177,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a if err == nil && userID != nil { sender = *userID } + + sk := ev.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk), ) } } @@ -259,11 +273,19 @@ func OnIncomingStateTypeRequest( // membershipRes will only be populated if the room is not world-readable. var membershipRes api.QueryMembershipForUserResponse if !worldReadable { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UserID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Device UserID is invalid"), + } + } // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") @@ -344,13 +366,10 @@ func OnIncomingStateTypeRequest( } } - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), + ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, event), } var res interface{} diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go index a0b280789..03c0230e6 100644 --- a/clientapi/routing/upgrade_room.go +++ b/clientapi/routing/upgrade_room.go @@ -59,7 +59,15 @@ func UpgradeRoom( } } - newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, device.UserID, gomatrixserverlib.RoomVersion(r.NewVersion)) + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("device UserID is invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, *userID, gomatrixserverlib.RoomVersion(r.NewVersion)) switch e := err.(type) { case nil: case roomserverAPI.ErrNotAllowed: diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index ca279ac22..c26aa2f15 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -45,7 +45,7 @@ func GetEventAuth( if event.RoomID() != roomID { return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } - resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) if resErr != nil { return *resErr } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index 196a54db1..d3f0e81c3 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -35,10 +35,6 @@ func GetEvent( eventID string, origin spec.ServerName, ) util.JSONResponse { - err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) - if err != nil { - return *err - } // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string, // which results in `QueryEventsByID` to first get the event and use that to determine the roomID. event, err := fetchEvent(ctx, rsAPI, "", eventID) @@ -46,6 +42,11 @@ func GetEvent( return *err } + err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) + if err != nil { + return *err + } + return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{ Origin: origin, OriginServerTS: spec.AsTimestamp(time.Now()), @@ -62,8 +63,9 @@ func allowedToSeeEvent( origin spec.ServerName, rsAPI api.FederationRoomserverAPI, eventID string, + roomID string, ) *util.JSONResponse { - allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID) + allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID, roomID) if err != nil { resErr := util.ErrorResponse(err) return &resErr diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index fa0e9351e..11ad1ebfc 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -116,7 +116,7 @@ func getState( if event.RoomID() != roomID { return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } - resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) if resErr != nil { return nil, nil, resErr } diff --git a/go.mod b/go.mod index 3621428c3..2fbae3148 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d + github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 1ee0261f6..ef8c298ab 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 8c2cbd6b2..bafde91c9 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -34,11 +34,11 @@ func (e ErrNotAllowed) Error() string { type RestrictedJoinAPI interface { CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) - InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) - RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) + InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) + RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error - UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) + UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, senderID spec.SenderID) (bool, error) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) } @@ -191,7 +191,7 @@ type ClientRoomserverAPI interface { PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) // PerformRoomUpgrade upgrades a room to a newer version - PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) + PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminPurgeRoom(ctx context.Context, roomID string) error @@ -228,6 +228,7 @@ type FederationRoomserverAPI interface { // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error + QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error @@ -238,15 +239,13 @@ type FederationRoomserverAPI interface { // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate // the state and auth chain to return. QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error - // Query if we think we're still in a room. - QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error // Query missing events for a room from roomserver QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) + QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error - QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) + QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error @@ -254,12 +253,6 @@ type FederationRoomserverAPI interface { // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error - CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) - InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) - QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) - UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) - LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) - IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) StateQuerier() gomatrixserverlib.StateQuerier } diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 16b504957..852b64206 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -215,8 +215,10 @@ type OutputNewInviteEvent struct { type OutputRetireInviteEvent struct { // The ID of the "m.room.member" invite event. EventID string - // The target user ID of the "m.room.member" invite event that was retired. - TargetUserID string + // The room ID of the "m.room.member" invite event. + RoomID string + // The target sender ID of the "m.room.member" invite event that was retired. + TargetSenderID spec.SenderID // Optional event ID of the event that replaced the invite. // This can be empty if the invite was rejected locally and we were unable // to reach the server that originally sent the invite. diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 6cbaf5b19..b466b7ba8 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -41,8 +41,8 @@ type PerformJoinRequest struct { } type PerformLeaveRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` + RoomID string + Leaver spec.UserID } type PerformLeaveResponse struct { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index d79dcebbb..684a5b0e3 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -113,9 +113,9 @@ type QueryEventsByIDResponse struct { // QueryMembershipForUserRequest is a request to QueryMembership type QueryMembershipForUserRequest struct { // ID of the room to fetch membership from - RoomID string `json:"room_id"` + RoomID string // ID of the user for whom membership is requested - UserID string `json:"user_id"` + UserID spec.UserID } // QueryMembershipForUserResponse is a response to QueryMembership @@ -145,7 +145,7 @@ type QueryMembershipsForRoomRequest struct { // Optional - ID of the user sending the request, for checking if the // user is allowed to see the memberships. If not specified then all // room memberships will be returned. - Sender string `json:"sender"` + SenderID spec.SenderID `json:"sender"` } // QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom @@ -448,11 +448,11 @@ func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.Ro return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) } -func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { - return rq.Roomserver.InvitePending(ctx, roomID, userID) +func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) { + return rq.Roomserver.InvitePending(ctx, roomID, senderID) } -func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { +func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID) if err != nil || roomInfo == nil || roomInfo.IsStub() { return nil, err @@ -468,7 +468,7 @@ func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID sp return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) } - userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") return nil, fmt.Errorf("InternalServerError: %w", err) @@ -492,12 +492,8 @@ type MembershipQuerier struct { } func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { - req := QueryMembershipForUserRequest{ - RoomID: roomID.String(), - UserID: string(senderID), - } res := QueryMembershipForUserResponse{} - err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) + err := mq.Roomserver.QueryMembershipForSenderID(ctx, roomID, senderID, &res) membership := "" if err == nil { diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index b6168d38b..ba10a4332 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -13,6 +13,9 @@ package auth import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) @@ -22,6 +25,7 @@ import ( // IsServerAllowed returns true if the server is allowed to see events in the room // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 func IsServerAllowed( + ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, serverCurrentlyInRoom bool, authEvents []gomatrixserverlib.PDU, @@ -37,7 +41,7 @@ func IsServerAllowed( return true } // 2. If the user's membership was join, allow. - joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Join) + joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join) if joinedUserExists { return true } @@ -46,7 +50,7 @@ func IsServerAllowed( return true } // 4. If the user's membership was invite, and the history_visibility was set to invited, allow. - invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Invite) + invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite) if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { return true } @@ -70,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver return visibility } -func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { +func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { for _, ev := range authEvents { if ev.Type() != spec.MRoomMember { continue @@ -85,12 +89,12 @@ func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []go continue } - _, domain, err := gomatrixserverlib.SplitID('@', *stateKey) + userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) if err != nil { continue } - if domain == serverName { + if userID.Domain() == serverName { return true } } diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go index e3eea5d8b..192d9e5da 100644 --- a/roomserver/auth/auth_test.go +++ b/roomserver/auth/auth_test.go @@ -1,13 +1,23 @@ package auth import ( + "context" "testing" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) +type FakeStorageDB struct { + storage.RoomDatabase +} + +func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func TestIsServerAllowed(t *testing.T) { alice := test.NewUser(t) @@ -77,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) { authEvents = append(authEvents, ev.PDU) } - if got := IsServerAllowed(tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { + if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want) } }) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 95397cd5e..263cb9f85 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "sort" - "strings" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -55,9 +54,10 @@ func UpdateToInviteMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: spec.Join, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } @@ -94,13 +94,13 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam for i := range events { gmslEvents[i] = events[i].PDU } - return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, spec.Join), nil + return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil } func IsInvitePending( ctx context.Context, db storage.Database, - roomID, userID string, -) (bool, string, string, gomatrixserverlib.PDU, error) { + roomID string, senderID spec.SenderID, +) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) { // Look up the room NID for the supplied room ID. info, err := db.RoomInfo(ctx, roomID) if err != nil { @@ -111,13 +111,13 @@ func IsInvitePending( } // Look up the state key NID for the supplied user ID. - targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) + targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)}) if err != nil { return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) } - targetUserNID, targetUserFound := targetUserNIDs[userID] + targetUserNID, targetUserFound := targetUserNIDs[string(senderID)] if !targetUserFound { - return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) + return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs) } // Let's see if we have an event active for the user in the room. If @@ -156,7 +156,7 @@ func IsInvitePending( event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false) - return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err + return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err } // GetMembershipsAtState filters the state events to @@ -264,7 +264,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, ) (bool, error) { stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) switch err { @@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent( case tables.OptimisationNotSupportedError: // The database engine didn't support this optimisation, so fall back to using // the old and slow method - stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName) + stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName) if err != nil { return false, err } @@ -288,11 +288,11 @@ func CheckServerAllowedToSeeEvent( return false, err } } - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil + return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil } func slowGetHistoryVisibilityState( - ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, ) ([]gomatrixserverlib.PDU, error) { roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) @@ -319,8 +319,13 @@ func slowGetHistoryVisibilityState( // then we'll filter it out. This does preserve state keys that // are "" since these will contain history visibility etc. for nid, key := range stateKeys { - if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { - delete(stateKeys, nid) + if key != "" { + userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key)) + if err == nil && userID != nil { + if userID.Domain() != serverName { + delete(stateKeys, nid) + } + } } } @@ -410,7 +415,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index f1896277e..1cef83df7 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/roomserver/types" @@ -58,12 +59,12 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { } // Alice should have no pending invites and should have a NID - pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) + pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, spec.SenderID(alice.ID)) assert.NoError(t, err, "failed to get pending invites") assert.False(t, pendingInvite, "unexpected pending invite") // Bob should have no pending invites and receive a new NID - pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) + pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, spec.SenderID(bob.ID)) assert.NoError(t, err, "failed to get pending invites") assert.False(t, pendingInvite, "unexpected pending invite") }) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 1f273da01..7bb401632 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -842,17 +842,15 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r continue } - // TODO: pseudoIDs: get userID for room using state key (which is now senderID) - localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey())) if err != nil { continue } - // TODO: pseudoIDs: query account by state key (which is now senderID) accountRes := &userAPI.QueryAccountByLocalpartResponse{} if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ - Localpart: localpart, - ServerName: senderDomain, + Localpart: memberUserID.Local(), + ServerName: memberUserID.Domain(), }, accountRes); err != nil { return err } @@ -896,8 +894,8 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r inputEvents = append(inputEvents, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: senderDomain, - SendAsServer: string(senderDomain), + Origin: memberUserID.Domain(), + SendAsServer: string(memberUserID.Domain()), }) prevEvents = []string{event.EventID()} } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 98d7d13b1..09c65dfe9 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -18,7 +18,6 @@ import ( "context" "fmt" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" @@ -72,7 +71,7 @@ func (r *Inputer) updateMemberships( if change.addedEventNID != 0 { ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID) } - if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { + if updates, err = r.updateMembership(ctx, updater, targetUserNID, re, ae, updates); err != nil { return nil, err } } @@ -80,6 +79,7 @@ func (r *Inputer) updateMemberships( } func (r *Inputer) updateMembership( + ctx context.Context, updater *shared.RoomUpdater, targetUserNID types.EventStateKeyNID, remove, add *types.Event, @@ -97,7 +97,7 @@ func (r *Inputer) updateMembership( var targetLocal bool if add != nil { - targetLocal = r.isLocalTarget(add) + targetLocal = r.isLocalTarget(ctx, add) } mu, err := updater.MembershipUpdater(targetUserNID, targetLocal) @@ -136,11 +136,14 @@ func (r *Inputer) updateMembership( } } -func (r *Inputer) isLocalTarget(event *types.Event) bool { +func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { - _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) - isTargetLocalUser = domain == r.ServerName + userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey)) + if err != nil || userID == nil { + return isTargetLocalUser + } + isTargetLocalUser = userID.Domain() == r.ServerName } return isTargetLocalUser } @@ -161,9 +164,10 @@ func updateToJoinMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: spec.Join, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } @@ -187,9 +191,10 @@ func updateToLeaveMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: newMembership, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index eeb1ac406..ec13bff87 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -149,11 +149,11 @@ func (r *Admin) PerformAdminEvacuateUser( ctx context.Context, userID string, ) (affected []string, err error) { - _, domain, err := gomatrixserverlib.SplitID('@', userID) + fullUserID, err := spec.NewUserID(userID, true) if err != nil { return nil, err } - if !r.Cfg.Matrix.IsLocalServerName(domain) { + if !r.Cfg.Matrix.IsLocalServerName(fullUserID.Domain()) { return nil, fmt.Errorf("can only evacuate local users using this endpoint") } @@ -172,7 +172,7 @@ func (r *Admin) PerformAdminEvacuateUser( for _, roomID := range allRooms { leaveReq := &api.PerformLeaveRequest{ RoomID: roomID, - UserID: userID, + Leaver: *fullUserID, } leaveRes := &api.PerformLeaveResponse{} outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 388150936..8e87359a3 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -582,7 +582,7 @@ func joinEventsFromHistoryVisibility( } // Can we see events in the room? - canSeeEvents := auth.IsServerAllowed(thisServer, true, events) + canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events) visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index a3ba20f70..475418aa3 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -63,9 +63,17 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } } - createContent["creator"] = userID.String() + senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + createContent["creator"] = senderID createContent["room_version"] = createRequest.RoomVersion - powerLevelContent := eventutil.InitialPowerLevelsContent(userID.String()) + powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID)) joinRuleContent := gomatrixserverlib.JoinRuleContent{ JoinRule: spec.Invite, } @@ -121,7 +129,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } membershipEvent := gomatrixserverlib.FledglingEvent{ Type: spec.MRoomMember, - StateKey: userID.String(), + StateKey: string(senderID), Content: gomatrixserverlib.MemberContent{ Membership: spec.Join, DisplayName: createRequest.UserDisplayName, @@ -270,7 +278,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo var builtEvents []*types.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) - senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed") return "", &util.JSONResponse{ diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 56ee16065..1440daad4 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -134,12 +134,12 @@ func (r *Inviter) PerformInvite( return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} } - if event.StateKey() == nil { + if event.StateKey() == nil || *event.StateKey() == "" { return fmt.Errorf("invite must be a state event") } - invitedUser, err := spec.NewUserID(*event.StateKey(), true) - if err != nil { - return spec.InvalidParam("The user ID is invalid") + invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if err != nil || invitedUser == nil { + return spec.InvalidParam("Could not find the matching senderID for this user") } isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index d41cc214b..83c3b7c3e 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -162,7 +162,7 @@ func (r *Joiner) performJoinRoomByID( } // Get the domain part of the room ID. - _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) + roomID, err := spec.NewRoomID(req.RoomIDOrAlias) if err != nil { return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)} } @@ -170,8 +170,8 @@ func (r *Joiner) performJoinRoomByID( // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. - if !r.Cfg.Matrix.IsLocalServerName(domain) { - req.ServerNames = append(req.ServerNames, domain) + if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) { + req.ServerNames = append(req.ServerNames, roomID.Domain()) } // Prepare the template for the join event. @@ -203,7 +203,7 @@ func (r *Joiner) performJoinRoomByID( req.Content = map[string]interface{}{} } req.Content["membership"] = spec.Join - if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil { + if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil { return "", "", aerr } else if authorisedVia != "" { req.Content["join_authorised_via_users_server"] = authorisedVia @@ -226,17 +226,17 @@ func (r *Joiner) performJoinRoomByID( // Force a federated join if we're dealing with a pending invite // and we aren't in the room. - isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) + isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) if err == nil && !serverInRoom && isInvitePending { - _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) - if ierr != nil { - return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender) + if queryErr != nil { + return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) } // If we were invited by someone from another server then we can // assume they are in the room so we can join via them. - if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { - req.ServerNames = append(req.ServerNames, inviterDomain) + if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { + req.ServerNames = append(req.ServerNames, inviter.Domain()) forceFederatedJoin = true memberEvent := gjson.Parse(string(inviteEvent.JSON())) // only set unsigned if we've got a content.membership, which we _should_ @@ -298,12 +298,8 @@ func (r *Joiner) performJoinRoomByID( // a member of the room. This is best-effort (as in we won't // fail if we can't find the existing membership) because there // is really no harm in just sending another membership event. - membershipReq := &api.QueryMembershipForUserRequest{ - RoomID: req.RoomIDOrAlias, - UserID: userID.String(), - } membershipRes := &api.QueryMembershipForUserResponse{} - _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) + _ = r.Queryer.QueryMembershipForSenderID(ctx, *roomID, senderID, membershipRes) // If we haven't already joined the room then send an event // into the room changing our membership status. @@ -328,7 +324,7 @@ func (r *Joiner) performJoinRoomByID( // The room doesn't exist locally. If the room ID looks like it should // be ours then this probably means that we've nuked our database at // some point. - if r.Cfg.Matrix.IsLocalServerName(domain) { + if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) { // If there are no more server names to try then give up here. // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. @@ -376,15 +372,12 @@ func (r *Joiner) performFederatedJoinRoomByID( func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( ctx context.Context, joinReq *rsAPI.PerformJoinRequest, + senderID spec.SenderID, ) (string, error) { roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias) if err != nil { return "", err } - userID, err := spec.NewUserID(joinReq.UserID, true) - if err != nil { - return "", err - } - return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID) + return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, senderID) } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 094537f8b..1b23cc1ff 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -53,16 +53,12 @@ func (r *Leaver) PerformLeave( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, ) ([]api.OutputEvent, error) { - _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID) - } - if !r.Cfg.Matrix.IsLocalServerName(domain) { - return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) + if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) { + return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String()) } logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomID, - "user_id": req.UserID, + "user_id": req.Leaver.String(), }) logger.Info("User requested to leave join") if strings.HasPrefix(req.RoomID, "!") { @@ -82,21 +78,26 @@ func (r *Leaver) performLeaveRoomByID( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam ) ([]api.OutputEvent, error) { + leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver) + if err != nil { + return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) + } + // If there's an invite outstanding for the room then respond to // that. - isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) + isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver) if err == nil && isInvitePending { - _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", senderUser) + sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser) + if serr != nil || sender == nil { + return nil, fmt.Errorf("sender %q has no matching userID", senderUser) } - if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { - return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) + if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { + return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, leaver) } // check that this is not a "server notice room" accData := &userapi.QueryAccountDataResponse{} if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ - UserID: req.UserID, + UserID: req.Leaver.String(), RoomID: req.RoomID, DataType: "m.tag", }, accData); err != nil { @@ -127,7 +128,7 @@ func (r *Leaver) performLeaveRoomByID( StateToFetch: []gomatrixserverlib.StateKeyTuple{ { EventType: spec.MRoomMember, - StateKey: req.UserID, + StateKey: string(leaver), }, }, } @@ -141,26 +142,18 @@ func (r *Leaver) performLeaveRoomByID( // Now let's see if the user is in the room. if len(latestRes.StateEvents) == 0 { - return nil, fmt.Errorf("user %q is not a member of room %q", req.UserID, req.RoomID) + return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID) } membership, err := latestRes.StateEvents[0].Membership() if err != nil { return nil, fmt.Errorf("error getting membership: %w", err) } if membership != spec.Join && membership != spec.Invite { - return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership) + return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership) } // Prepare the template for the leave event. - fullUserID, err := spec.NewUserID(req.UserID, true) - if err != nil { - return nil, err - } - senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID) - if err != nil { - return nil, err - } - senderIDString := string(senderID) + senderIDString := string(leaver) proto := gomatrixserverlib.ProtoEvent{ Type: spec.MRoomMember, SenderID: senderIDString, @@ -175,16 +168,13 @@ func (r *Leaver) performLeaveRoomByID( return nil, fmt.Errorf("eb.SetUnsigned: %w", err) } - // Get the sender domain. - senderDomain := fullUserID.Domain() - // We know that the user is in the room at this point so let's build // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. var buildRes rsAPI.QueryLatestEventsAndStateResponse - identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.Domain()) if err != nil { return nil, fmt.Errorf("SigningIdentityFor: %w", err) } @@ -201,8 +191,8 @@ func (r *Leaver) performLeaveRoomByID( { Kind: api.KindNew, Event: event, - Origin: senderDomain, - SendAsServer: string(senderDomain), + Origin: req.Leaver.Domain(), + SendAsServer: string(req.Leaver.Domain()), }, }, } @@ -219,21 +209,17 @@ func (r *Leaver) performFederatedRejectInvite( ctx context.Context, req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam - senderUser, eventID string, + inviteSender spec.UserID, eventID string, + leaver spec.SenderID, ) ([]api.OutputEvent, error) { - _, domain, err := gomatrixserverlib.SplitID('@', senderUser) - if err != nil { - return nil, fmt.Errorf("user ID %q invalid: %w", senderUser, err) - } - // Ask the federation sender to perform a federated leave for us. leaveReq := fsAPI.PerformLeaveRequest{ RoomID: req.RoomID, - UserID: req.UserID, - ServerNames: []spec.ServerName{domain}, + UserID: req.Leaver.String(), + ServerNames: []spec.ServerName{inviteSender.Domain()}, } leaveRes := fsAPI.PerformLeaveResponse{} - if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { + if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { // failures in PerformLeave should NEVER stop us from telling other components like the // sync API that the invite was withdrawn. Otherwise we can end up with stuck invites. util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event") @@ -244,7 +230,7 @@ func (r *Leaver) performFederatedRejectInvite( util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event") } - updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, req.UserID, true, info.RoomVersion) + updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(leaver), true, info.RoomVersion) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") } @@ -267,9 +253,10 @@ func (r *Leaver) performFederatedRejectInvite( { Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ - EventID: eventID, - Membership: "leave", - TargetUserID: req.UserID, + EventID: eventID, + RoomID: req.RoomID, + Membership: "leave", + TargetSenderID: leaver, }, }, }, nil diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 5710352bb..1aaa42c94 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -38,19 +38,15 @@ type Upgrader struct { // PerformRoomUpgrade upgrades a room from one version to another func (r *Upgrader) PerformRoomUpgrade( ctx context.Context, - roomID, userID string, roomVersion gomatrixserverlib.RoomVersion, + roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion, ) (newRoomID string, err error) { return r.performRoomUpgrade(ctx, roomID, userID, roomVersion) } func (r *Upgrader) performRoomUpgrade( ctx context.Context, - roomID, userID string, roomVersion gomatrixserverlib.RoomVersion, + roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion, ) (string, error) { - _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) - if err != nil { - return "", api.ErrNotAllowed{Err: fmt.Errorf("error validating the user ID")} - } evTime := time.Now() // Return an immediate error if the room does not exist @@ -58,14 +54,20 @@ func (r *Upgrader) performRoomUpgrade( return "", err } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") + return "", err + } + // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) - if !r.userIsAuthorized(ctx, userID, roomID) { + if !r.userIsAuthorized(ctx, senderID, roomID) { return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")} } // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) + newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain()) // Get the existing room state for the old room. oldRoomReq := &api.QueryLatestEventsAndStateRequest{ @@ -77,25 +79,25 @@ func (r *Upgrader) performRoomUpgrade( } // Make the tombstone event - tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, userID, roomID, newRoomID) + tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, senderID, userID.Domain(), roomID, newRoomID) if pErr != nil { return "", pErr } // Generate the initial events we need to send into the new room. This includes copied state events and bans // as well as the power level events needed to set up the room - eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, roomVersion, tombstoneEvent) + eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, senderID, roomID, roomVersion, tombstoneEvent) if pErr != nil { return "", pErr } // Send the setup events to the new room - if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, roomVersion, eventsToMake); pErr != nil { + if pErr = r.sendInitialEvents(ctx, evTime, senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil { return "", pErr } // 5. Send the tombstone event to the old room - if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil { + if pErr = r.sendHeaderedEvent(ctx, userID.Domain(), tombstoneEvent, string(userID.Domain())); pErr != nil { return "", pErr } @@ -105,17 +107,17 @@ func (r *Upgrader) performRoomUpgrade( } // If the old room had a canonical alias event, it should be deleted in the old room - if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil { + if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, senderID, userID.Domain(), roomID); pErr != nil { return "", pErr } // 4. Move local aliases to the new room - if pErr = moveLocalAliases(ctx, roomID, newRoomID, userID, r.URSAPI); pErr != nil { + if pErr = moveLocalAliases(ctx, roomID, newRoomID, senderID, userID, r.URSAPI); pErr != nil { return "", pErr } // 6. Restrict power levels in the old room - if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil { + if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, senderID, userID.Domain(), roomID); pErr != nil { return "", pErr } @@ -130,7 +132,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma return oldPowerLevelsEvent.PowerLevels() } -func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error { +func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error { restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) if pErr != nil { return pErr @@ -147,7 +149,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel - restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ + restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{ Type: spec.MRoomPowerLevels, StateKey: "", Content: restrictedPowerLevelContent, @@ -165,7 +167,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T } func moveLocalAliases(ctx context.Context, - roomID, newRoomID, userID string, + roomID, newRoomID string, senderID spec.SenderID, userID spec.UserID, URSAPI api.RoomserverInternalAPI, ) (err error) { @@ -175,14 +177,6 @@ func moveLocalAliases(ctx context.Context, return fmt.Errorf("Failed to get old room aliases: %w", err) } - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return fmt.Errorf("Failed to get userID: %w", err) - } - senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return fmt.Errorf("Failed to get senderID: %w", err) - } for _, alias := range aliasRes.Aliases { removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias} removeAliasRes := api.RemoveRoomAliasResponse{} @@ -190,7 +184,7 @@ func moveLocalAliases(ctx context.Context, return fmt.Errorf("Failed to remove old room alias: %w", err) } - setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} + setAliasReq := api.SetRoomAliasRequest{UserID: userID.String(), Alias: alias, RoomID: newRoomID} setAliasRes := api.SetRoomAliasResponse{} if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { return fmt.Errorf("Failed to set new room alias: %w", err) @@ -199,7 +193,7 @@ func moveLocalAliases(ctx context.Context, return nil } -func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error { +func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error { for _, event := range oldRoom.StateEvents { if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") { continue @@ -217,7 +211,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api } } - emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{ + emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{ Type: spec.MRoomCanonicalAlias, Content: map[string]interface{}{}, }) @@ -280,7 +274,7 @@ func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error return nil } -func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, +func (r *Upgrader) userIsAuthorized(ctx context.Context, senderID spec.SenderID, roomID string, ) bool { plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, @@ -295,26 +289,18 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, } // Check for power level required to send tombstone event (marks the current room as obsolete), // if not found, use the StateDefault power level - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return false - } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return false - } return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true) } // nolint:gocyclo -func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) { +func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, senderID spec.SenderID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) { state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents)) for _, event := range oldRoom.StateEvents { if event.StateKey() == nil { // This shouldn't ever happen, but better to be safe than sorry. continue } - if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) { + if event.Type() == spec.MRoomMember && !event.StateKeyEquals(string(senderID)) { // With the exception of bans which we do want to copy, we // should ignore membership events that aren't our own, as event auth will // prevent us from being able to create membership events on behalf of other @@ -330,6 +316,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query } } // skip events that rely on a specific user being present + // TODO: What to do here for pseudoIDs? It's checking non-member events for state keys with userIDs. sKey := *event.StateKey() if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" { continue @@ -340,10 +327,10 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // The following events are ones that we are going to override manually // in the following section. override := map[gomatrixserverlib.StateKeyTuple]struct{}{ - {EventType: spec.MRoomCreate, StateKey: ""}: {}, - {EventType: spec.MRoomMember, StateKey: userID}: {}, - {EventType: spec.MRoomPowerLevels, StateKey: ""}: {}, - {EventType: spec.MRoomJoinRules, StateKey: ""}: {}, + {EventType: spec.MRoomCreate, StateKey: ""}: {}, + {EventType: spec.MRoomMember, StateKey: string(senderID)}: {}, + {EventType: spec.MRoomPowerLevels, StateKey: ""}: {}, + {EventType: spec.MRoomJoinRules, StateKey: ""}: {}, } // The overridden events are essential events that must be present in the @@ -355,7 +342,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query } oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}] - oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: userID}] + oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: string(senderID)}] oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}] oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}] @@ -364,7 +351,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // in the create event (such as for the room types MSC). newCreateContent := map[string]interface{}{} _ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent) - newCreateContent["creator"] = userID + newCreateContent["creator"] = string(senderID) newCreateContent["room_version"] = newVersion newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{ EventID: tombstoneEvent.EventID(), @@ -385,7 +372,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query newMembershipContent["membership"] = spec.Join newMembershipEvent := gomatrixserverlib.FledglingEvent{ Type: spec.MRoomMember, - StateKey: userID, + StateKey: string(senderID), Content: newMembershipContent, } @@ -400,14 +387,6 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query return nil, fmt.Errorf("Power level event content was invalid") } - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return nil, err - } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return nil, err - } tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID) // Now do the join rules event, same as the create and membership @@ -470,21 +449,13 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query return eventsToMake, nil } -func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { +func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { var err error var builtEvents []*types.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) for i, e := range eventsToMake { depth := i + 1 // depth starts at 1 - fullUserID, userIDErr := spec.NewUserID(userID, true) - if userIDErr != nil { - return userIDErr - } - senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID) - if queryErr != nil { - return queryErr - } proto := gomatrixserverlib.ProtoEvent{ SenderID: string(senderID), RoomID: newRoomID, @@ -549,7 +520,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user func (r *Upgrader) makeTombstoneEvent( ctx context.Context, evTime time.Time, - userID, roomID, newRoomID string, + senderID spec.SenderID, senderDomain spec.ServerName, roomID, newRoomID string, ) (*types.HeaderedEvent, error) { content := map[string]interface{}{ "body": "This room has been replaced", @@ -559,30 +530,21 @@ func (r *Upgrader) makeTombstoneEvent( Type: "m.room.tombstone", Content: content, } - return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event) + return r.makeHeaderedEvent(ctx, evTime, senderID, senderDomain, roomID, event) } -func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { - fullUserID, err := spec.NewUserID(userID, true) - if err != nil { - return nil, err - } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) - if err != nil { - return nil, err - } +func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, senderID spec.SenderID, senderDomain spec.ServerName, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { proto := gomatrixserverlib.ProtoEvent{ SenderID: string(senderID), RoomID: roomID, Type: event.Type, StateKey: &event.StateKey, } - err = proto.SetContent(event.Content) + err := proto.SetContent(event.Content) if err != nil { return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err) } // Get the sender domain. - senderDomain := fullUserID.Domain() identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) if err != nil { return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ae2b7cf57..caea6b526 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -48,7 +48,7 @@ type Queryer struct { Cfg *config.Dendrite } -func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { +func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { roomInfo, err := r.QueryRoomInfo(ctx, roomID) if err != nil || roomInfo == nil || roomInfo.IsStub() { return nil, err @@ -64,7 +64,7 @@ func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) } - userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID) if err != nil { util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") return nil, fmt.Errorf("InternalServerError: %w", err) @@ -220,13 +220,14 @@ func (r *Queryer) QueryEventsByID( return nil } -// QueryMembershipForUser implements api.RoomserverInternalAPI -func (r *Queryer) QueryMembershipForUser( +// QueryMembershipForSenderID implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForSenderID( ctx context.Context, - request *api.QueryMembershipForUserRequest, + roomID spec.RoomID, + senderID spec.SenderID, response *api.QueryMembershipForUserResponse, ) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) + info, err := r.DB.RoomInfo(ctx, roomID.String()) if err != nil { return err } @@ -236,7 +237,7 @@ func (r *Queryer) QueryMembershipForUser( } response.RoomExists = true - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID) if err != nil { return err } @@ -264,6 +265,24 @@ func (r *Queryer) QueryMembershipForUser( return err } +// QueryMembershipForUser implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID) + if err != nil { + return err + } + + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) +} + // QueryMembershipAtEvent returns the known memberships at a given event. // If the state before an event is not known, an empty list will be returned // for that event instead. @@ -373,7 +392,7 @@ func (r *Queryer) QueryMembershipsForRoom( // If no sender is specified then we will just return the entire // set of memberships for the room, regardless of whether a specific // user is allowed to see them or not. - if request.Sender == "" { + if request.SenderID == "" { var events []types.Event var eventNIDs []types.EventNID eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly) @@ -388,18 +407,15 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - sender := spec.UserID{} - userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if queryErr == nil && userID != nil { - sender = *userID - } - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) + clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil } - membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID) if err != nil { return err } @@ -442,12 +458,9 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - sender := spec.UserID{} - userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) + clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -489,6 +502,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( ctx context.Context, serverName spec.ServerName, eventID string, + roomID string, ) (allowed bool, err error) { events, err := r.DB.EventNIDs(ctx, []string{eventID}) if err != nil { @@ -518,7 +532,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( } return helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, eventID, serverName, isInRoom, + ctx, r.DB, info, roomID, eventID, serverName, isInRoom, ) } @@ -909,8 +923,8 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq return nil } -func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { - pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String()) +func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) { + pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID) return pending, err } @@ -926,8 +940,8 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve return res, err } -func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) { - _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String()) +func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) { + _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID) return isIn, err } @@ -957,7 +971,7 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse } // nolint:gocyclo -func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { +func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { // Look up if we know anything about the room. If it doesn't exist // or is a stub entry then we can't do anything. roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) @@ -972,7 +986,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return "", err } - return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) + return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) } func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 5e6ba7d4e..90c94bbce 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -722,7 +722,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { roomID, _ := spec.NewRoomID(testRoom.ID) userID, _ := spec.NewUserID(bob.ID, true) - got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID) + got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, spec.SenderID(userID.String())) if tc.wantError && err == nil { t.Fatal("expected error, got none") } @@ -821,17 +821,6 @@ func TestUpgrade(t *testing.T) { validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) wantNewRoom bool }{ - { - name: "invalid userID", - upgradeUser: "!notvalid:test", - roomFunc: func(rsAPI api.RoomserverInternalAPI) string { - room := test.NewRoom(t, alice) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { - t.Errorf("failed to send events: %v", err) - } - return room.ID - }, - }, { name: "invalid roomID", upgradeUser: alice.ID, @@ -1049,7 +1038,11 @@ func TestUpgrade(t *testing.T) { } roomID := tc.roomFunc(rsAPI) - newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, tc.upgradeUser, version.DefaultRoomVersion()) + userID, err := spec.NewUserID(tc.upgradeUser, true) + if err != nil { + t.Fatalf("upgrade userID is invalid") + } + newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, *userID, version.DefaultRoomVersion()) if err != nil && tc.wantNewRoom { t.Fatal(err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 2d27d7999..ef4463781 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -131,7 +131,7 @@ type Database interface { // in this room, along a boolean set to true if the user is still in this room, // false if not. // Returns an error if there was a problem talking to the database. - GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) + GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) // Lookup the membership event numeric IDs for all user that are or have // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cb12b3f57..85a1ba7a1 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -490,10 +490,10 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { }) } -func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { +func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { var requestSenderUserNID types.EventStateKeyNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) + requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID)) return err }) if err != nil { @@ -936,6 +936,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) ( return roomVersion, err } +// nolint:gocyclo // MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec: // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events @@ -1014,7 +1015,7 @@ func (d *EventDatabase) MaybeRedactEvent( switch { case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1Domain == sender2Domain: + case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain: // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 47eb544ea..d3f1c9dd2 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -154,7 +154,7 @@ type reqCtx struct { rsAPI roomserver.RoomserverInternalAPI db Database req *EventRelationshipRequest - userID string + userID spec.UserID roomVersion gomatrixserverlib.RoomVersion // federated request args @@ -173,10 +173,17 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: 400, + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } rc := reqCtx{ ctx: req.Context(), req: relation, - userID: device.UserID, + userID: *userID, rsAPI: rsAPI, fsAPI: fsAPI, isFederatedRequest: false, diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 551d7ad45..e32d6a9f2 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -529,6 +529,10 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str return spec.NewUserID(string(senderID), true) } +func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { + return spec.SenderID(userID.String()), nil +} + func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { for _, eventID := range req.EventIDs { ev := r.events[eventID] @@ -540,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver } func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { - rooms := r.userToJoinedRooms[req.UserID] + rooms := r.userToJoinedRooms[req.UserID.String()] for _, roomID := range rooms { if roomID == req.RoomID { res.IsInRoom = true diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 8a2a0b1f6..c5f2db9c8 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -373,7 +373,15 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst // TODO: check that it's a join and not a profile change (means unmarshalling prev_content) if membership == spec.Join { // check it's a local join - if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil { + if ev.StateKey() == nil { + return sp, fmt.Errorf("unexpected nil state_key") + } + + userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + if err != nil || userID == nil { + return sp, fmt.Errorf("failed getting userID for sender: %w", err) + } + if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) { return sp, nil } @@ -395,9 +403,15 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( if msg.Event.StateKey() == nil { return } - if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil { + + userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey())) + if err != nil || userID == nil { return } + if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) { + return + } + pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) if err != nil { sentry.CaptureException(err) @@ -440,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( // Notify any active sync requests that the invite has been retired. s.inviteStream.Advance(pduPos) - s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) + userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID) + if err != nil || userID == nil { + log.WithFields(log.Fields{ + "event_id": msg.EventID, + "sender_id": msg.TargetSenderID, + log.ErrorKey: err, + }).Errorf("failed to find userID for sender") + return + } + s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String()) } func (s *OutputRoomEventConsumer) onNewPeek( diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 7449b4647..ab1a7f83d 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -134,9 +134,17 @@ func ApplyHistoryVisibilityFilter( } } // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules") - if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) { - eventsFiltered = append(eventsFiltered, ev) - continue + + user, err := spec.NewUserID(userID, true) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user) + if err == nil { + if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) { + eventsFiltered = append(eventsFiltered, ev) + continue + } } // Always allow history evVis events on boundaries. This is done // by setting the effective evVis to the least restrictive diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index ad5935cdc..f4b6ace59 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -169,12 +169,16 @@ func TrackChangedUsers( if err != nil { return nil, nil, err } - for _, state := range stateRes.Rooms { + for roomID, state := range stateRes.Rooms { for tuple, membership := range state { if membership != spec.Join { continue } - queryRes.UserIDsToCount[tuple.StateKey]-- + user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) + if queryErr != nil || user == nil { + continue + } + queryRes.UserIDsToCount[user.String()]-- } } @@ -211,14 +215,18 @@ func TrackChangedUsers( if err != nil { return nil, left, err } - for _, state := range stateRes.Rooms { + for roomID, state := range stateRes.Rooms { for tuple, membership := range state { if membership != spec.Join { continue } // new user who we weren't previously sharing rooms with if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { - changed = append(changed, tuple.StateKey) // changed is returned + user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) + if err != nil || user == nil { + continue + } + changed = append(changed, user.String()) // changed is returned } } } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 23c2ecbaa..efa641475 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -64,6 +64,10 @@ type mockRoomserverAPI struct { roomIDToJoinedMembers map[string][]string } +func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + // QueryRoomsForUser retrieves a list of room IDs matching the given query. func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { return nil diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index f76456859..4ee7c8605 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" @@ -36,7 +37,8 @@ import ( // the event, but the token has already advanced by the time they fetch it, resulting // in missed events. type Notifier struct { - lock *sync.RWMutex + lock *sync.RWMutex + rsAPI api.SyncRoomserverAPI // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine roomIDToJoinedUsers map[string]*userIDSet // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine @@ -55,8 +57,9 @@ type Notifier struct { // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier() *Notifier { +func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier { return &Notifier{ + rsAPI: rsAPI, roomIDToJoinedUsers: make(map[string]*userIDSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet), userDeviceStreams: make(map[string]map[string]*UserDeviceStream), @@ -104,26 +107,32 @@ func (n *Notifier) OnNewEvent( peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { - targetUserID := *ev.StateKey() - membership, err := ev.Membership() + targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey())) if err != nil { log.WithError(err).WithField("event_id", ev.EventID()).Errorf( - "Notifier.OnNewEvent: Failed to unmarshal member event", + "Notifier.OnNewEvent: Failed to find the userID for this event", ) } else { - // Keep the joined user map up-to-date - switch membership { - case spec.Invite: - usersToNotify = append(usersToNotify, targetUserID) - case spec.Join: - // Manually append the new user's ID so they get notified - // along all members in the room - usersToNotify = append(usersToNotify, targetUserID) - n._addJoinedUser(ev.RoomID(), targetUserID) - case spec.Leave: - fallthrough - case spec.Ban: - n._removeJoinedUser(ev.RoomID(), targetUserID) + membership, err := ev.Membership() + if err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: Failed to unmarshal member event", + ) + } else { + // Keep the joined user map up-to-date + switch membership { + case spec.Invite: + usersToNotify = append(usersToNotify, targetUserID.String()) + case spec.Join: + // Manually append the new user's ID so they get notified + // along all members in the room + usersToNotify = append(usersToNotify, targetUserID.String()) + n._addJoinedUser(ev.RoomID(), targetUserID.String()) + case spec.Leave: + fallthrough + case spec.Ban: + n._removeJoinedUser(ev.RoomID(), targetUserID.String()) + } } } } diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index 36577a0ee..7076f7134 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -22,9 +22,11 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { } } +type TestRoomServer struct{ api.SyncRoomserverAPI } + +func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) if err != nil { @@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) { // Test that new events to a joined room unblocks the request. func TestNewEventAndJoinedToRoom(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { } func TestCorrectStream(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) stream := lockedFetchUserStream(n, bob, bobDev) if stream.UserID != bob { @@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) { } func TestCorrectStreamWakeup(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) awoken := make(chan string) @@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) { // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) { // Test that all blocked requests get woken up on a new event. func TestMultipleRequestWakeup(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // listen as bob. Make bob leave room. Make alice send event to room. // Make sure alice gets woken up only and not bob as well. - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 7fb88faaa..55fd3c5a2 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -85,9 +85,16 @@ func Context( *filter.Rooms = append(*filter.Rooms, roomID) } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } ctx := req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} - membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} + membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID} if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { logrus.WithError(err).Error("unable to query membership") return util.JSONResponse{ @@ -217,12 +224,9 @@ func Context( } } - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender) + ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, requestedEvent) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 63df7e837..de790e5cd 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -106,8 +106,17 @@ func GetEvent( if err == nil && senderUserID != nil { sender = *senderUserID } + + sk := events[0].StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 813167a5e..cf6769ba4 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -59,14 +59,21 @@ func GetMemberships( syncDB storage.Database, rsAPI api.SyncRoomserverAPI, joinedOnly bool, membership, notMembership *string, at string, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } queryReq := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, } var queryRes api.QueryMembershipForUserResponse - if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") + if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil { + util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 781fd53e7..6784a27bd 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -296,9 +296,13 @@ func OnIncomingMessagesRequest( } func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return resp, err + } req := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: userID, + UserID: *fullUserID, } if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { return api.QueryMembershipForUserResponse{}, err diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index f21c684c8..6efa065a9 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -119,9 +119,18 @@ func Relations( if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk), ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index add50b181..7d9182f47 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } results = append(results, Result{ Context: SearchContextResponse{ Start: startToken.String(), @@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 5bd3b1f01..799e3d166 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -507,8 +507,20 @@ func (d *Database) CleanSendToDeviceUpdates( // getMembershipFromEvent returns the value of content.membership iff the event is a state event // with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. -func getMembershipFromEvent(ev gomatrixserverlib.PDU, userID string) (string, string) { - if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { +func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) { + if ev.StateKey() == nil || *ev.StateKey() == "" { + return "", "" + } + fullUser, err := spec.NewUserID(userID, true) + if err != nil { + return "", "" + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser) + if err != nil { + return "", "" + } + + if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) { return "", "" } membership, err := ev.Membership() diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index df9613850..8e79b71df 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -430,7 +430,7 @@ func (d *DatabaseTransaction) GetStateDeltas( for _, ev := range stateStreamEvents { // Look for our membership in the state events and skip over any // membership events that are not related to us. - membership, prevMembership := getMembershipFromEvent(ev.PDU, userID) + membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI) if membership == "" { continue } @@ -556,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( for roomID, stateStreamEvents := range state { for _, ev := range stateStreamEvents { - if membership, _ := getMembershipFromEvent(ev.PDU, userID); membership != "" { + if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" { if membership != spec.Join { // We've already added full state for all joined rooms above. deltas[roomID] = types.StateDelta{ Membership: membership, diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index a8b0a7b66..3a5badd92 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync( user = *sender } + sk := inviteEvent.StateKey() + if sk != nil && *sk != "" { + skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + // skip ignored user events if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, user) + ir := types.NewInviteResponse(inviteEvent, user, sk) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index d214980bd..f728d4aea 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -605,13 +605,17 @@ func (p *PDUStreamProvider) lazyLoadMembers( // If this is a gapped incremental sync, we still want this membership isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list - stateKey := *event.StateKey() - if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID { + userID := "" + stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey())) + if err == nil && stateKeyUserID != nil { + userID = stateKeyUserID.String() + } + if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID { newStateEvents = append(newStateEvents, event) if !stateFilter.IncludeRedundantMembers { - p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID()) } - delete(timelineUsers, stateKey) + delete(timelineUsers, userID) } } else { newStateEvents = append(newStateEvents, event) diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index ecbe05dd8..64a4af757 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -60,7 +60,7 @@ func AddPublicRoutes( } eduCache := caching.NewTypingCache() - notifier := notifier.NewNotifier() + notifier := notifier.NewNotifier(rsAPI) streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 66fb1d01f..358a0c971 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, if err == nil && userID != nil { sender = *userID } - evs = append(evs, ToClientEvent(se, format, sender)) + + sk := se.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + evs = append(evs, ToClientEvent(se, format, sender, sk)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent { ce := ClientEvent{ Content: spec.RawJSON(se.Content()), Sender: sender.String(), Type: se.Type(), - StateKey: se.StateKey(), + StateKey: stateKey, Unsigned: spec.RawJSON(se.Unsigned()), OriginServerTS: se.OriginServerTS(), EventID: se.EventID(), @@ -77,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp } return ce } + +// ToClientEvent converts a single server event to a client event. +// It provides default logic for event.SenderID & event.StateKey -> userID conversions. +func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { + sender := spec.UserID{} + userID, err := userIDQuery(event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + return ToClientEvent(event, FormatAll, sender, sk) +} diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 341795081..63c65b2af 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if err != nil { t.Fatalf("failed to create userID: %s", err) } - ce := ToClientEvent(ev, FormatAll, *userID) + sk := "" + ce := ToClientEvent(ev, FormatAll, *userID, &sk) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create userID: %s", err) } - ce := ToClientEvent(ev, FormatSync, *userID) + sk := "" + ce := ToClientEvent(ev, FormatSync, *userID, &sk) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index a3dc7f54b..cb3c362d5 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -539,7 +539,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index a79ce5417..c1b7f70bd 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) { if err != nil { t.Fatal(err) } + skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true) + if err != nil { + t.Fatal(err) + } + skString := skUserID.String() + sk := &skString - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index df507eb26..b2dc477aa 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -306,7 +306,16 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst if queryErr == nil && userID != nil { sender = *userID } - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender) + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if queryErr == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { @@ -539,12 +548,21 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype if err == nil && userID != nil { sender = *userID } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + if queryErr == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it @@ -792,10 +810,20 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes Type: event.Type(), }, } - if mem, err := event.Membership(); err == nil { + if mem, memberErr := event.Membership(); memberErr == nil { req.Notification.Membership = mem } - if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { + userID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName), true) + if err != nil { + logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) + return nil, err + } + localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID) + if err != nil { + logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) + return nil, err + } + if event.StateKey() != nil && *event.StateKey() == string(localSender) { req.Notification.UserIsTarget = true } } diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 27dd373c2..3017069bc 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -104,8 +104,9 @@ func TestNotifyUserCountsAsync(t *testing.T) { if err != nil { t.Error(err) } + sk := "" if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk), }); err != nil { t.Error(err) }