From 229257a052dd1f5f3cf53d420b74ac40668a2313 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 20 Feb 2020 12:11:46 +0000 Subject: [PATCH] Maybe getMembershipsBeforeEventNID and checkServerAllowedToSeeEvent will work now? --- roomserver/query/query.go | 35 +++++++++++++++++---- roomserver/storage/postgres/events_table.go | 26 +++++++++++++++ roomserver/storage/postgres/storage.go | 21 +++++++++++++ roomserver/storage/sqlite3/events_table.go | 26 +++++++++++++++ roomserver/storage/sqlite3/storage.go | 20 ++++++++++++ roomserver/storage/storage.go | 2 ++ 6 files changed, 124 insertions(+), 6 deletions(-) diff --git a/roomserver/query/query.go b/roomserver/query/query.go index c1c27a306..f0c167765 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -94,6 +94,14 @@ type RoomserverQueryAPIDatabase interface { GetRoomVersionForRoom( ctx context.Context, roomNID types.RoomNID, ) (state.StateResolutionVersion, error) + // Get the room NID for a given event ID. + RoomNIDForEventID( + ctx context.Context, eventID string, + ) (types.RoomNID, error) + // Get the room NID for a given event NID. + RoomNIDForEventNID( + ctx context.Context, eventNID types.EventNID, + ) (types.RoomNID, error) } // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI @@ -115,6 +123,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( if roomNID == 0 { return nil } + response.RoomExists = true roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID) if err != nil { return err @@ -123,7 +132,6 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( if err != nil { return err } - response.RoomExists = true var currentStateSnapshotNID types.StateSnapshotNID response.LatestEvents, currentStateSnapshotNID, response.Depth, err = r.DB.LatestEventIDs(ctx, roomNID) @@ -339,11 +347,19 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( ctx context.Context, eventNID types.EventNID, joinedOnly bool, ) ([]types.Event, error) { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + roomNID, err := r.DB.RoomNIDForEventNID(ctx, eventNID) if err != nil { - return []types.Event{}, err + return nil, err } + roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID) + if err != nil { + return nil, err + } + roomState, err := state.GetStateResolutionAlgorithm(roomVersion, r.DB) + if err != nil { + return nil, err + } + events := []types.Event{} // Lookup the event NID eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID}) @@ -445,8 +461,15 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent( func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent( ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, ) (bool, error) { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + roomNID, err := r.DB.RoomNIDForEventID(ctx, eventID) + if err != nil { + return false, err + } + roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID) + if err != nil { + return false, err + } + roomState, err := state.GetStateResolutionAlgorithm(roomVersion, r.DB) if err != nil { return false, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index d9b269bc8..acd66d83d 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -100,6 +100,12 @@ const updateEventSentToOutputSQL = "" + const selectEventIDSQL = "" + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" +const selectRoomNIDForEventIDSQL = "" + + "SELECT room_nid FROM roomserver_events WHERE event_id = $1" + +const selectRoomNIDForEventNIDSQL = "" + + "SELECT room_nid FROM roomserver_events WHERE event_nid = $1" + const bulkSelectStateAtEventAndReferenceSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + " FROM roomserver_events WHERE event_nid = ANY($1)" @@ -125,6 +131,8 @@ type eventStatements struct { selectEventSentToOutputStmt *sql.Stmt updateEventSentToOutputStmt *sql.Stmt selectEventIDStmt *sql.Stmt + selectRoomNIDForEventIDStmt *sql.Stmt + selectRoomNIDForEventNIDStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt @@ -147,6 +155,8 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, {&s.selectEventIDStmt, selectEventIDSQL}, + {&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL}, + {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, @@ -294,6 +304,22 @@ func (s *eventStatements) selectEventID( return } +func (s *eventStatements) selectRoomNIDForEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) (roomNID types.RoomNID, err error) { + stmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt) + err = stmt.QueryRowContext(ctx, eventID).Scan(&roomNID) + return +} + +func (s *eventStatements) selectRoomNIDForEventNID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (roomNID types.RoomNID, err error) { + stmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) + err = stmt.QueryRowContext(ctx, eventNID).Scan(&roomNID) + return +} + func (s *eventStatements) bulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 6e6468ac8..785c069c2 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -21,6 +21,7 @@ import ( // Import the postgres database driver. _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" @@ -414,6 +415,26 @@ func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, e return roomNID, err } +func (d *Database) RoomNIDForEventID( + ctx context.Context, eventID string, +) (out types.RoomNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.selectRoomNIDForEventID(ctx, txn, eventID) + return err + }) + return +} + +func (d *Database) RoomNIDForEventNID( + ctx context.Context, eventNID types.EventNID, +) (out types.RoomNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventNID) + return err + }) + return +} + // LatestEventIDs implements query.RoomserverQueryAPIDatabase func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 4ed1395da..39db84b7e 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -80,6 +80,12 @@ const updateEventSentToOutputSQL = "" + const selectEventIDSQL = "" + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" +const selectRoomNIDForEventIDSQL = "" + + "SELECT room_nid FROM roomserver_events WHERE event_id = $1" + +const selectRoomNIDForEventNIDSQL = "" + + "SELECT room_nid FROM roomserver_events WHERE event_nid = $1" + const bulkSelectStateAtEventAndReferenceSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + " FROM roomserver_events WHERE event_nid IN ($1)" @@ -107,6 +113,8 @@ type eventStatements struct { selectEventSentToOutputStmt *sql.Stmt updateEventSentToOutputStmt *sql.Stmt selectEventIDStmt *sql.Stmt + selectRoomNIDForEventIDStmt *sql.Stmt + selectRoomNIDForEventNIDStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt @@ -131,6 +139,8 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, {&s.selectEventIDStmt, selectEventIDSQL}, + {&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL}, + {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, @@ -310,6 +320,22 @@ func (s *eventStatements) selectEventID( return } +func (s *eventStatements) selectRoomNIDForEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) (roomNID types.RoomNID, err error) { + stmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt) + err = stmt.QueryRowContext(ctx, eventID).Scan(&roomNID) + return +} + +func (s *eventStatements) selectRoomNIDForEventNID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (roomNID types.RoomNID, err error) { + stmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) + err = stmt.QueryRowContext(ctx, eventNID).Scan(&roomNID) + return +} + func (s *eventStatements) bulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index bb1e4aaae..e05cecb14 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -533,6 +533,26 @@ func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.Ro return } +func (d *Database) RoomNIDForEventID( + ctx context.Context, eventID string, +) (out types.RoomNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.selectRoomNIDForEventID(ctx, txn, eventID) + return err + }) + return +} + +func (d *Database) RoomNIDForEventNID( + ctx context.Context, eventNID types.EventNID, +) (out types.RoomNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventNID) + return err + }) + return +} + // LatestEventIDs implements query.RoomserverQueryAPIDatabase func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index a22f4bfda..7b77ebe99 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -38,6 +38,8 @@ type Database interface { GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) + RoomNIDForEventID(ctx context.Context, eventID string) (types.RoomNID, error) + RoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (types.RoomNID, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error