Maybe getMembershipsBeforeEventNID and checkServerAllowedToSeeEvent will work now?

This commit is contained in:
Neil Alexander 2020-02-20 12:11:46 +00:00
parent f8e4d5bcb0
commit 229257a052
6 changed files with 124 additions and 6 deletions

View file

@ -94,6 +94,14 @@ type RoomserverQueryAPIDatabase interface {
GetRoomVersionForRoom( GetRoomVersionForRoom(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,
) (state.StateResolutionVersion, error) ) (state.StateResolutionVersion, error)
// Get the room NID for a given event ID.
RoomNIDForEventID(
ctx context.Context, eventID string,
) (types.RoomNID, error)
// Get the room NID for a given event NID.
RoomNIDForEventNID(
ctx context.Context, eventNID types.EventNID,
) (types.RoomNID, error)
} }
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
@ -115,6 +123,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
if roomNID == 0 { if roomNID == 0 {
return nil return nil
} }
response.RoomExists = true
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID) roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if err != nil { if err != nil {
return err return err
@ -123,7 +132,6 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
if err != nil { if err != nil {
return err return err
} }
response.RoomExists = true
var currentStateSnapshotNID types.StateSnapshotNID var currentStateSnapshotNID types.StateSnapshotNID
response.LatestEvents, currentStateSnapshotNID, response.Depth, err = response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
r.DB.LatestEventIDs(ctx, roomNID) r.DB.LatestEventIDs(ctx, roomNID)
@ -339,11 +347,19 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
ctx context.Context, eventNID types.EventNID, joinedOnly bool, ctx context.Context, eventNID types.EventNID, joinedOnly bool,
) ([]types.Event, error) { ) ([]types.Event, error) {
// TODO: get the correct room version roomNID, err := r.DB.RoomNIDForEventNID(ctx, eventNID)
roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB)
if err != nil { if err != nil {
return []types.Event{}, err return nil, err
} }
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if err != nil {
return nil, err
}
roomState, err := state.GetStateResolutionAlgorithm(roomVersion, r.DB)
if err != nil {
return nil, err
}
events := []types.Event{} events := []types.Event{}
// Lookup the event NID // Lookup the event NID
eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID}) eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID})
@ -445,8 +461,15 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent( func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent(
ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName,
) (bool, error) { ) (bool, error) {
// TODO: get the correct room version roomNID, err := r.DB.RoomNIDForEventID(ctx, eventID)
roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) if err != nil {
return false, err
}
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if err != nil {
return false, err
}
roomState, err := state.GetStateResolutionAlgorithm(roomVersion, r.DB)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -100,6 +100,12 @@ const updateEventSentToOutputSQL = "" +
const selectEventIDSQL = "" + const selectEventIDSQL = "" +
"SELECT event_id FROM roomserver_events WHERE event_nid = $1" "SELECT event_id FROM roomserver_events WHERE event_nid = $1"
const selectRoomNIDForEventIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_id = $1"
const selectRoomNIDForEventNIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1"
const bulkSelectStateAtEventAndReferenceSQL = "" + const bulkSelectStateAtEventAndReferenceSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
" FROM roomserver_events WHERE event_nid = ANY($1)" " FROM roomserver_events WHERE event_nid = ANY($1)"
@ -125,6 +131,8 @@ type eventStatements struct {
selectEventSentToOutputStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt
updateEventSentToOutputStmt *sql.Stmt updateEventSentToOutputStmt *sql.Stmt
selectEventIDStmt *sql.Stmt selectEventIDStmt *sql.Stmt
selectRoomNIDForEventIDStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
@ -147,6 +155,8 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
{&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
{&s.selectEventIDStmt, selectEventIDSQL}, {&s.selectEventIDStmt, selectEventIDSQL},
{&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
@ -294,6 +304,22 @@ func (s *eventStatements) selectEventID(
return return
} }
func (s *eventStatements) selectRoomNIDForEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (roomNID types.RoomNID, err error) {
stmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt)
err = stmt.QueryRowContext(ctx, eventID).Scan(&roomNID)
return
}
func (s *eventStatements) selectRoomNIDForEventNID(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (roomNID types.RoomNID, err error) {
stmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt)
err = stmt.QueryRowContext(ctx, eventNID).Scan(&roomNID)
return
}
func (s *eventStatements) bulkSelectStateAtEventAndReference( func (s *eventStatements) bulkSelectStateAtEventAndReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.StateAtEventAndReference, error) { ) ([]types.StateAtEventAndReference, error) {

View file

@ -21,6 +21,7 @@ import (
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -414,6 +415,26 @@ func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, e
return roomNID, err return roomNID, err
} }
func (d *Database) RoomNIDForEventID(
ctx context.Context, eventID string,
) (out types.RoomNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.selectRoomNIDForEventID(ctx, txn, eventID)
return err
})
return
}
func (d *Database) RoomNIDForEventNID(
ctx context.Context, eventNID types.EventNID,
) (out types.RoomNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventNID)
return err
})
return
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase // LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,

View file

@ -80,6 +80,12 @@ const updateEventSentToOutputSQL = "" +
const selectEventIDSQL = "" + const selectEventIDSQL = "" +
"SELECT event_id FROM roomserver_events WHERE event_nid = $1" "SELECT event_id FROM roomserver_events WHERE event_nid = $1"
const selectRoomNIDForEventIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_id = $1"
const selectRoomNIDForEventNIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1"
const bulkSelectStateAtEventAndReferenceSQL = "" + const bulkSelectStateAtEventAndReferenceSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
" FROM roomserver_events WHERE event_nid IN ($1)" " FROM roomserver_events WHERE event_nid IN ($1)"
@ -107,6 +113,8 @@ type eventStatements struct {
selectEventSentToOutputStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt
updateEventSentToOutputStmt *sql.Stmt updateEventSentToOutputStmt *sql.Stmt
selectEventIDStmt *sql.Stmt selectEventIDStmt *sql.Stmt
selectRoomNIDForEventIDStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
@ -131,6 +139,8 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
{&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
{&s.selectEventIDStmt, selectEventIDSQL}, {&s.selectEventIDStmt, selectEventIDSQL},
{&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
@ -310,6 +320,22 @@ func (s *eventStatements) selectEventID(
return return
} }
func (s *eventStatements) selectRoomNIDForEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (roomNID types.RoomNID, err error) {
stmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt)
err = stmt.QueryRowContext(ctx, eventID).Scan(&roomNID)
return
}
func (s *eventStatements) selectRoomNIDForEventNID(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (roomNID types.RoomNID, err error) {
stmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt)
err = stmt.QueryRowContext(ctx, eventNID).Scan(&roomNID)
return
}
func (s *eventStatements) bulkSelectStateAtEventAndReference( func (s *eventStatements) bulkSelectStateAtEventAndReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.StateAtEventAndReference, error) { ) ([]types.StateAtEventAndReference, error) {

View file

@ -533,6 +533,26 @@ func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.Ro
return return
} }
func (d *Database) RoomNIDForEventID(
ctx context.Context, eventID string,
) (out types.RoomNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.selectRoomNIDForEventID(ctx, txn, eventID)
return err
})
return
}
func (d *Database) RoomNIDForEventNID(
ctx context.Context, eventNID types.EventNID,
) (out types.RoomNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventNID)
return err
})
return
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase // LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,

View file

@ -38,6 +38,8 @@ type Database interface {
GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error)
GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error) GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error)
RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error)
RoomNIDForEventID(ctx context.Context, eventID string) (types.RoomNID, error)
RoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (types.RoomNID, error)
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error)
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error