diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 72e406ee8..8b854121e 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -75,6 +75,12 @@ type RoomserverInternalAPI interface { response *QueryLatestEventsAndStateResponse, ) error + QueryStateAndAuthChainIDs( + ctx context.Context, + request *QueryStateAndAuthChainIDsRequest, + response *QueryStateAndAuthChainIDsResponse, + ) error + // Query the state after a list of events in a room from the room server. QueryStateAfterEvents( ctx context.Context, diff --git a/roomserver/api/query.go b/roomserver/api/query.go index c70db65c1..6064d6b62 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -250,6 +250,28 @@ type QueryStateAndAuthChainResponse struct { AuthChainEvents []*gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` } +// QueryStateAndAuthChainIDsRequest is a request to QueryStateAndAuthChainIDs +type QueryStateAndAuthChainIDsRequest struct { + // The room ID to query the state in. + RoomID string `json:"room_id"` + // The list of prev events for the event. Used to calculate the state at + // the event. + PrevEventIDs []string `json:"prev_event_ids"` +} + +// QueryStateAndAuthChainIDsResponse is a response to QueryStateAndAuthChainIDs +type QueryStateAndAuthChainIDsResponse struct { + // Does the room exist on this roomserver? + // If the room doesn't exist this will be false and StateEvents will be empty. + RoomExists bool `json:"room_exists"` + // The room version of the room. + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + // The state and auth chain event IDs that were requested. + // The lists will be in an arbitrary order. + StateEvents []string `json:"state_event_ids"` + AuthChainEvents []string `json:"auth_chain_event_ids"` +} + // QueryRoomVersionCapabilitiesRequest asks for the default room version type QueryRoomVersionCapabilitiesRequest struct{} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 4af0e6397..a2bb28481 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -526,6 +526,70 @@ func (r *Queryer) QueryStateAndAuthChain( return err } +// QueryStateAndAuthChain implements api.RoomserverInternalAPI +func (r *Queryer) QueryStateAndAuthChainIDs( + ctx context.Context, + request *api.QueryStateAndAuthChainIDsRequest, + response *api.QueryStateAndAuthChainIDsResponse, +) error { + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return nil + } + response.RoomExists = true + response.RoomVersion = info.RoomVersion + + roomState := state.NewStateResolution(r.DB, *info) + prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) + if err != nil { + return fmt.Errorf("r.DB.StateAtEventIDs: %w", err) + } + + eventNIDs := map[types.EventNID]struct{}{} + for _, prevState := range prevStates { + var entries []types.StateEntry + entries, err = roomState.LoadStateAtSnapshot(ctx, prevState.BeforeStateSnapshotNID) + if err != nil { + continue + } + for _, entry := range entries { + eventNIDs[entry.EventNID] = struct{}{} + } + } + var eventNIDsArray types.EventNIDs + for nid := range eventNIDs { + eventNIDsArray = append(eventNIDsArray, nid) + } + + authEventNIDsArray, err := r.DB.AuthEventNIDs(ctx, eventNIDsArray) + if err != nil { + return fmt.Errorf("r.DB.AuthEventNIDs: %w", err) + } + + stateEventIDs, err := r.DB.EventIDs(ctx, eventNIDsArray) + if err != nil { + return fmt.Errorf("r.DB.EventIDs: %w", err) + } + + authEventIDs, err := r.DB.EventIDs(ctx, authEventNIDsArray) + if err != nil { + return fmt.Errorf("r.DB.EventIDs: %w", err) + } + + for _, eventID := range stateEventIDs { + response.StateEvents = append(response.StateEvents, eventID) + } + + for _, eventID := range authEventIDs { + response.AuthChainEvents = append(response.AuthChainEvents, eventID) + } + + return err +} + func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) { roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index c25820aac..913e72940 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -87,6 +87,8 @@ type Database interface { // Lookup the event IDs for a batch of event numeric IDs. // Returns an error if the retrieval went wrong. EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // AuthEventNIDs returns the auth event NIDs for the given events. + AuthEventNIDs(ctx context.Context, events []types.EventNID) (types.EventNIDs, error) // Look up the latest events in a room in preparation for an update. // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // Returns the latest events in the room and the last eventID sent to the log along with an updater. diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 88c82083c..365e29ff7 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -134,6 +134,9 @@ const selectMaxEventDepthSQL = "" + const selectRoomNIDsForEventNIDsSQL = "" + "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid = ANY($1)" +const bulkSelectEventAuthEventNIDsSQL = "" + + "SELECT auth_event_nids FROM roomserver_events WHERE event_nid = ANY($1)" + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -150,6 +153,7 @@ type eventStatements struct { bulkSelectEventNIDStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt + bulkSelectEventAuthEventNIDsStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -176,6 +180,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL}, + {&s.bulkSelectEventAuthEventNIDsStmt, bulkSelectEventAuthEventNIDsSQL}, }.Prepare(db) } @@ -502,6 +507,28 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( return result, nil } +func (s *eventStatements) SelectEventAuthEventNIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID][]types.EventNID, error) { + rows, err := s.bulkSelectEventAuthEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") + result := make(map[types.EventNID][]types.EventNID) + for rows.Next() { + var eventNID types.EventNID + var authEventNIDs pq.Int64Array + if err = rows.Scan(&authEventNIDs); err != nil { + return nil, err + } + for _, a := range authEventNIDs { + result[eventNID] = append(result[eventNID], types.EventNID(a)) + } + } + return result, nil +} + func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { nids := make([]int64, len(eventNIDs)) for i := range eventNIDs { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 9d9434cbb..2ec473c0b 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -292,6 +292,22 @@ func (d *Database) StateEntries( return lists, nil } +func (d *Database) AuthEventNIDs( + ctx context.Context, events []types.EventNID, +) (types.EventNIDs, error) { + entries, err := d.EventsTable.SelectEventAuthEventNIDs( + ctx, events, + ) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.SelectEventAuthEventNIDs: %w", err) + } + var lists types.EventNIDs + for _, nids := range entries { + lists = append(lists, nids...) + } + return lists[:util.SortAndUnique(lists)], nil +} + func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID) diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index e964770d7..7eb74e708 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -23,6 +23,7 @@ import ( "sort" "strings" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -104,6 +105,9 @@ const selectMaxEventDepthSQL = "" + const selectRoomNIDsForEventNIDsSQL = "" + "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)" +const bulkSelectEventAuthEventNIDsSQL = "" + + "SELECT auth_event_nids FROM roomserver_events WHERE event_nid IN ($1)" + type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt @@ -119,6 +123,7 @@ type eventStatements struct { bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt //selectRoomNIDsForEventNIDsStmt *sql.Stmt + //bulkSelectEventAuthEventNIDsStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -145,6 +150,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, + //{&s.bulkSelectEventAuthEventNIDsStmt, bulkSelectEventAuthEventNIDsSQL}, }.Prepare(db) } @@ -571,6 +577,37 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( return result, nil } +func (s *eventStatements) SelectEventAuthEventNIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID][]types.EventNID, error) { + sqlStr := strings.Replace(bulkSelectEventAuthEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) + sqlPrep, err := s.db.Prepare(sqlStr) + if err != nil { + return nil, err + } + iEventNIDs := make([]interface{}, len(eventNIDs)) + for i, v := range eventNIDs { + iEventNIDs[i] = v + } + rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventAuthEventNIDsStmt: rows.close() failed") + result := make(map[types.EventNID][]types.EventNID) + for rows.Next() { + var eventNID types.EventNID + var authEventNIDs pq.Int64Array + if err = rows.Scan(&authEventNIDs); err != nil { + return nil, err + } + for _, a := range authEventNIDs { + result[eventNID] = append(result[eventNID], types.EventNID(a)) + } + } + return result, nil +} + func eventNIDsAsArray(eventNIDs []types.EventNID) string { b, _ := json.Marshal(eventNIDs) return string(b) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 4a893663f..abefa9a5f 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -61,6 +61,7 @@ type Events interface { BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) + SelectEventAuthEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID][]types.EventNID, error) } type Rooms interface {