From 6769c889030d32bd632714f3ea9acc0a75733a5a Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Thu, 3 Sep 2020 14:23:27 +0100 Subject: [PATCH] Implement some current-state-server storage interface functions --- roomserver/internal/api.go | 2 +- roomserver/internal/query/query.go | 2 +- roomserver/storage/interface.go | 2 - .../storage/postgres/membership_table.go | 24 ++++++++ roomserver/storage/postgres/rooms_table.go | 48 +++++++++++++++ roomserver/storage/shared/storage.go | 59 ++++++++++++++++--- .../storage/sqlite3/membership_table.go | 24 ++++++++ roomserver/storage/sqlite3/rooms_table.go | 49 +++++++++++++++ roomserver/storage/tables/interface.go | 3 + 9 files changed, 200 insertions(+), 13 deletions(-) diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 83daf3615..bdea650ea 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -4,10 +4,10 @@ import ( "context" "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/currentstateserver/acls" fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/perform" diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ff55228d4..f76c93166 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -20,8 +20,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/currentstateserver/acls" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index faae3f4bf..c4119f7ed 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -151,8 +151,6 @@ type Database interface { // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) - // Redact a state event - RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) // GetKnownUsers searches all users that userID knows about. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 13cef638f..0799647e9 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -99,6 +99,9 @@ const updateMembershipSQL = "" + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + " WHERE room_nid = $1 AND target_nid = $2" +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -108,6 +111,7 @@ type membershipStatements struct { selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -126,6 +130,7 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, }.Prepare(db) } @@ -222,3 +227,22 @@ func (s *membershipStatements) UpdateMembership( ) return err } + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 13c8e703d..9d359146a 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -21,6 +21,7 @@ import ( "errors" "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" @@ -74,6 +75,12 @@ const selectRoomVersionForRoomNIDSQL = "" + const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -82,6 +89,8 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt + bulkSelectRoomIDsStmt *sql.Stmt } func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -98,9 +107,27 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, }.Prepare(db) } +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, @@ -197,3 +224,24 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + var array pq.Int64Array + for _, nid := range roomNIDs { + array = append(array, int64(nid)) + } + rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index bb7938116..5c447d66f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -713,16 +713,62 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { return &evs[0] } -// GetStateEvent returns the state event of a given type for a given room with a given state key +// GetStateEvent returns the current state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { - return nil, fmt.Errorf("not implemented yet") + /* + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err != nil { + return nil, err + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) + if err != nil { + return nil, err + } + blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID}) + if err != nil { + return nil, err + } + */ + return nil, nil } // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { - return nil, fmt.Errorf("not implemented yet") + var membershipState tables.MembershipState + switch membership { + case "join": + membershipState = tables.MembershipStateJoin + case "invite": + membershipState = tables.MembershipStateInvite + case "leave": + membershipState = tables.MembershipStateLeaveOrBan + case "ban": + membershipState = tables.MembershipStateLeaveOrBan + default: + return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) + } + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) + if err != nil { + return nil, err + } + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) + if err != nil { + return nil, err + } + if len(roomIDs) != len(roomNIDs) { + return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs)) + } + return roomIDs, nil } // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. @@ -731,11 +777,6 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu return nil, fmt.Errorf("not implemented yet") } -// Redact a state event -func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error { - return nil -} - // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { return nil, fmt.Errorf("not implemented yet") @@ -748,5 +789,5 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { - return nil, nil + return d.RoomsTable.SelectRoomIDs(ctx) } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index b3ee69c00..e850c80bb 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -75,6 +75,9 @@ const updateMembershipSQL = "" + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" + " WHERE room_nid = $4 AND target_nid = $5" +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -84,6 +87,7 @@ type membershipStatements struct { selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt } @@ -105,6 +109,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, }.Prepare(db) } @@ -203,3 +208,22 @@ func (s *membershipStatements) UpdateMembership( ) return err } + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 4c1699d00..daacf86fa 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -21,7 +21,9 @@ import ( "encoding/json" "errors" "fmt" + "strings" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" @@ -64,6 +66,12 @@ const selectRoomVersionForRoomNIDSQL = "" + const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt @@ -73,6 +81,7 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt } func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -91,9 +100,27 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, }.Prepare(db) } +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} + func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string @@ -203,3 +230,25 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index c599dd3fe..126c27b57 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -65,6 +65,8 @@ type Rooms interface { UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + SelectRoomIDs(ctx context.Context) ([]string, error) + BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) } type Transactions interface { @@ -120,6 +122,7 @@ type Membership interface { SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error + SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) } type Published interface {