diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 672aaff26..95185d9a8 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -60,7 +60,14 @@ const selectStateEventSQL = "" + const selectEventsWithEventIDsSQL = "" + "SELECT headered_event_json FROM currentstate_current_room_state WHERE event_id IN ($1)" +const selectBulkStateContentSQL = "" + + "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2) AND state_key IN ($3)" + +const selectBulkStateContentWildSQL = "" + + "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)" + type currentRoomStateStatements struct { + db *sql.DB upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt @@ -68,7 +75,9 @@ type currentRoomStateStatements struct { } func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { - s := ¤tRoomStateStatements{} + s := ¤tRoomStateStatements{ + db: db, + } _, err := db.Exec(currentRoomStateSchema) if err != nil { return nil, err @@ -196,5 +205,72 @@ func (s *currentRoomStateStatements) SelectStateEvent( func (s *currentRoomStateStatements) SelectBulkStateContent( ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool, ) ([]tables.StrippedEvent, error) { - return nil, nil + hasWildcards := false + eventTypeSet := make(map[string]bool) + stateKeySet := make(map[string]bool) + var eventTypes []string + var stateKeys []string + for _, tuple := range tuples { + if !eventTypeSet[tuple.EventType] { + eventTypeSet[tuple.EventType] = true + eventTypes = append(eventTypes, tuple.EventType) + } + if !stateKeySet[tuple.StateKey] { + stateKeySet[tuple.StateKey] = true + stateKeys = append(stateKeys, tuple.StateKey) + } + if tuple.StateKey == "*" { + hasWildcards = true + } + } + + iRoomIDs := make([]interface{}, len(roomIDs)) + for i, v := range roomIDs { + iRoomIDs[i] = v + } + iEventTypes := make([]interface{}, len(eventTypes)) + for i, v := range eventTypes { + iEventTypes[i] = v + } + iStateKeys := make([]interface{}, len(stateKeys)) + for i, v := range stateKeys { + iStateKeys[i] = v + } + + var query string + var args []interface{} + if hasWildcards && allowWildcards { + query = strings.Replace(selectBulkStateContentWildSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(iEventTypes), len(iRoomIDs)), 1) + args = append(iRoomIDs, iEventTypes...) + } else { + query = strings.Replace(selectBulkStateContentSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(iEventTypes), len(iRoomIDs)), 1) + query = strings.Replace(query, "($3)", sqlutil.QueryVariadicOffset(len(iStateKeys), len(iEventTypes)+len(iRoomIDs)), 1) + args = append(iRoomIDs, iEventTypes...) + args = append(args, iStateKeys...) + } + rows, err := s.db.QueryContext(ctx, query, args...) + + if err != nil { + return nil, err + } + strippedEvents := []tables.StrippedEvent{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectBulkStateContent: rows.close() failed") + for rows.Next() { + var roomID string + var eventType string + var stateKey string + var contentVal string + if err = rows.Scan(&roomID, &eventType, &stateKey, &contentVal); err != nil { + return nil, err + } + strippedEvents = append(strippedEvents, tables.StrippedEvent{ + RoomID: roomID, + ContentValue: contentVal, + EventType: eventType, + StateKey: stateKey, + }) + } + return strippedEvents, rows.Err() }