From dca4af75f10935c08f644ed7ffbde9d0c6701d98 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Sun, 3 Mar 2024 18:51:51 +0100 Subject: [PATCH] Add SelectRoomsWithEventTypeNID and RoomsWithACLs --- roomserver/storage/interface.go | 3 ++ roomserver/storage/postgres/events_table.go | 26 ++++++++++++++ roomserver/storage/shared/storage.go | 20 +++++++++++ roomserver/storage/sqlite3/events_table.go | 26 ++++++++++++++ .../storage/tables/events_table_test.go | 36 +++++++++++++++++++ roomserver/storage/tables/interface.go | 2 ++ 6 files changed, 113 insertions(+) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 0638252b2..67479587a 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -193,6 +193,9 @@ type Database interface { MaybeRedactEvent( ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) + + // RoomsWithACLs returns all room IDs for rooms with ACLs + RoomsWithACLs(ctx context.Context) ([]string, error) } type UserRoomKeys interface { diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 1c9cd1599..e5e956259 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -147,6 +147,8 @@ const selectRoomNIDsForEventNIDsSQL = "" + const selectEventRejectedSQL = "" + "SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2" +const selectRoomsWithACLsSQL = `select distinct room_nid from roomserver_events where event_type_nid = $1` + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -166,6 +168,7 @@ type eventStatements struct { selectMaxEventDepthStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt selectEventRejectedStmt *sql.Stmt + selectRoomsWithACLsStmt *sql.Stmt } func CreateEventsTable(db *sql.DB) error { @@ -206,6 +209,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL}, {&s.selectEventRejectedStmt, selectEventRejectedSQL}, + {&s.selectRoomsWithACLsStmt, selectRoomsWithACLsSQL}, }.Prepare(db) } @@ -582,3 +586,25 @@ func (s *eventStatements) SelectEventRejected( err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected) return } + +func (s *eventStatements) SelectRoomsWithEventTypeNID( + ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID, +) ([]types.RoomNID, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithACLsStmt) + rows, err := stmt.QueryContext(ctx, eventTypeNID) + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithEventTypeNID: rows.close() failed") + if err != nil { + return nil, err + } + + var roomNIDs []types.RoomNID + var roomNID types.RoomNID + for rows.Next() { + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + + return roomNIDs, rows.Err() +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 682cead6c..217300bb0 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1630,6 +1630,26 @@ func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) } +func (d *Database) RoomsWithACLs(ctx context.Context) ([]string, error) { + + eventTypeNID, err := d.GetOrCreateEventTypeNID(ctx, "m.room.server_acl") + if err != nil { + return nil, err + } + + roomNIDs, err := d.EventsTable.SelectRoomsWithEventTypeNID(ctx, nil, eventTypeNID) + if err != nil { + return nil, err + } + + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) + if err != nil { + return nil, err + } + + return roomIDs, nil +} + // ForgetRoom sets a users room to forgotten func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID}) diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 2c269bced..228f18983 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -120,6 +120,8 @@ const selectRoomNIDsForEventNIDsSQL = "" + const selectEventRejectedSQL = "" + "SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2" +const selectRoomsWithACLsSQL = `select distinct room_nid from roomserver_events where event_type_nid = $1` + type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt @@ -135,6 +137,7 @@ type eventStatements struct { bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt selectEventRejectedStmt *sql.Stmt + selectRoomsWithACLsStmt *sql.Stmt //bulkSelectEventNIDStmt *sql.Stmt //bulkSelectUnsentEventNIDStmt *sql.Stmt //selectRoomNIDsForEventNIDsStmt *sql.Stmt @@ -192,6 +195,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, {&s.selectEventRejectedStmt, selectEventRejectedSQL}, + {&s.selectRoomsWithACLsStmt, selectRoomsWithACLsSQL}, }.Prepare(db) } @@ -682,3 +686,25 @@ func (s *eventStatements) SelectEventRejected( err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected) return } + +func (s *eventStatements) SelectRoomsWithEventTypeNID( + ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID, +) ([]types.RoomNID, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithACLsStmt) + rows, err := stmt.QueryContext(ctx, eventTypeNID) + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithEventTypeNID: rows.close() failed") + if err != nil { + return nil, err + } + + var roomNIDs []types.RoomNID + var roomNID types.RoomNID + for rows.Next() { + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + + return roomNIDs, rows.Err() +} diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go index 5ed805648..52aeacc2f 100644 --- a/roomserver/storage/tables/events_table_test.go +++ b/roomserver/storage/tables/events_table_test.go @@ -2,6 +2,7 @@ package tables_test import ( "context" + "fmt" "testing" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -147,3 +148,38 @@ func Test_EventsTable(t *testing.T) { assert.Equal(t, int64(len(room.Events())+1), maxDepth) }) } + +func TestRoomsWithACL(t *testing.T) { + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + eventStateKeys, closeEventStateKeys := mustCreateEventTypesTable(t, dbType) + defer closeEventStateKeys() + + eventsTable, closeEventsTable := mustCreateEventsTable(t, dbType) + defer closeEventsTable() + + ctx := context.Background() + + // insert the m.room.server_acl event type + eventTypeNID, err := eventStateKeys.InsertEventTypeNID(ctx, nil, "m.room.server_acl") + assert.Nil(t, err) + + // Create ACL'd rooms + var wantRoomNIDs []types.RoomNID + for i := 0; i < 10; i++ { + _, _, err = eventsTable.InsertEvent(ctx, nil, types.RoomNID(i), eventTypeNID, types.EmptyStateKeyNID, fmt.Sprintf("$1337+%d", i), nil, 0, false) + assert.Nil(t, err) + wantRoomNIDs = append(wantRoomNIDs, types.RoomNID(i)) + } + + // Create non-ACL'd rooms (eventTypeNID+1) + for i := 10; i < 20; i++ { + _, _, err = eventsTable.InsertEvent(ctx, nil, types.RoomNID(i), eventTypeNID+1, types.EmptyStateKeyNID, fmt.Sprintf("$1337+%d", i), nil, 0, false) + assert.Nil(t, err) + } + + gotRoomNIDs, err := eventsTable.SelectRoomsWithEventTypeNID(ctx, nil, eventTypeNID) + assert.Nil(t, err) + assert.Equal(t, wantRoomNIDs, gotRoomNIDs) + }) +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index b3cb31880..0013ddbd2 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -69,6 +69,8 @@ type Events interface { SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error) + + SelectRoomsWithEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID) ([]types.RoomNID, error) } type Rooms interface {