Maybe getMembershipsBeforeEventNID and checkServerAllowedToSeeEvent will work now?

This commit is contained in:
Neil Alexander 2020-02-20 12:11:46 +00:00
parent f8e4d5bcb0
commit 229257a052
6 changed files with 124 additions and 6 deletions

View file

@ -94,6 +94,14 @@ type RoomserverQueryAPIDatabase interface {
GetRoomVersionForRoom(
ctx context.Context, roomNID types.RoomNID,
) (state.StateResolutionVersion, error)
// Get the room NID for a given event ID.
RoomNIDForEventID(
ctx context.Context, eventID string,
) (types.RoomNID, error)
// Get the room NID for a given event NID.
RoomNIDForEventNID(
ctx context.Context, eventNID types.EventNID,
) (types.RoomNID, error)
}
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
@ -115,6 +123,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
if roomNID == 0 {
return nil
}
response.RoomExists = true
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if err != nil {
return err
@ -123,7 +132,6 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
if err != nil {
return err
}
response.RoomExists = true
var currentStateSnapshotNID types.StateSnapshotNID
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
r.DB.LatestEventIDs(ctx, roomNID)
@ -339,11 +347,19 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
ctx context.Context, eventNID types.EventNID, joinedOnly bool,
) ([]types.Event, error) {
// TODO: get the correct room version
roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB)
roomNID, err := r.DB.RoomNIDForEventNID(ctx, eventNID)
if err != nil {
return []types.Event{}, err
return nil, err
}
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if err != nil {
return nil, err
}
roomState, err := state.GetStateResolutionAlgorithm(roomVersion, r.DB)
if err != nil {
return nil, err
}
events := []types.Event{}
// Lookup the event NID
eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID})
@ -445,8 +461,15 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent(
ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName,
) (bool, error) {
// TODO: get the correct room version
roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB)
roomNID, err := r.DB.RoomNIDForEventID(ctx, eventID)
if err != nil {
return false, err
}
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if err != nil {
return false, err
}
roomState, err := state.GetStateResolutionAlgorithm(roomVersion, r.DB)
if err != nil {
return false, err
}

View file

@ -100,6 +100,12 @@ const updateEventSentToOutputSQL = "" +
const selectEventIDSQL = "" +
"SELECT event_id FROM roomserver_events WHERE event_nid = $1"
const selectRoomNIDForEventIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_id = $1"
const selectRoomNIDForEventNIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1"
const bulkSelectStateAtEventAndReferenceSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
" FROM roomserver_events WHERE event_nid = ANY($1)"
@ -125,6 +131,8 @@ type eventStatements struct {
selectEventSentToOutputStmt *sql.Stmt
updateEventSentToOutputStmt *sql.Stmt
selectEventIDStmt *sql.Stmt
selectRoomNIDForEventIDStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
@ -147,6 +155,8 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
{&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
{&s.selectEventIDStmt, selectEventIDSQL},
{&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
@ -294,6 +304,22 @@ func (s *eventStatements) selectEventID(
return
}
func (s *eventStatements) selectRoomNIDForEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (roomNID types.RoomNID, err error) {
stmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt)
err = stmt.QueryRowContext(ctx, eventID).Scan(&roomNID)
return
}
func (s *eventStatements) selectRoomNIDForEventNID(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (roomNID types.RoomNID, err error) {
stmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt)
err = stmt.QueryRowContext(ctx, eventNID).Scan(&roomNID)
return
}
func (s *eventStatements) bulkSelectStateAtEventAndReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.StateAtEventAndReference, error) {

View file

@ -21,6 +21,7 @@ import (
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/types"
@ -414,6 +415,26 @@ func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, e
return roomNID, err
}
func (d *Database) RoomNIDForEventID(
ctx context.Context, eventID string,
) (out types.RoomNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.selectRoomNIDForEventID(ctx, txn, eventID)
return err
})
return
}
func (d *Database) RoomNIDForEventNID(
ctx context.Context, eventNID types.EventNID,
) (out types.RoomNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventNID)
return err
})
return
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,

View file

@ -80,6 +80,12 @@ const updateEventSentToOutputSQL = "" +
const selectEventIDSQL = "" +
"SELECT event_id FROM roomserver_events WHERE event_nid = $1"
const selectRoomNIDForEventIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_id = $1"
const selectRoomNIDForEventNIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_nid = $1"
const bulkSelectStateAtEventAndReferenceSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
" FROM roomserver_events WHERE event_nid IN ($1)"
@ -107,6 +113,8 @@ type eventStatements struct {
selectEventSentToOutputStmt *sql.Stmt
updateEventSentToOutputStmt *sql.Stmt
selectEventIDStmt *sql.Stmt
selectRoomNIDForEventIDStmt *sql.Stmt
selectRoomNIDForEventNIDStmt *sql.Stmt
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
@ -131,6 +139,8 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
{&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
{&s.selectEventIDStmt, selectEventIDSQL},
{&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL},
{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
@ -310,6 +320,22 @@ func (s *eventStatements) selectEventID(
return
}
func (s *eventStatements) selectRoomNIDForEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (roomNID types.RoomNID, err error) {
stmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt)
err = stmt.QueryRowContext(ctx, eventID).Scan(&roomNID)
return
}
func (s *eventStatements) selectRoomNIDForEventNID(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (roomNID types.RoomNID, err error) {
stmt := common.TxStmt(txn, s.selectRoomNIDForEventNIDStmt)
err = stmt.QueryRowContext(ctx, eventNID).Scan(&roomNID)
return
}
func (s *eventStatements) bulkSelectStateAtEventAndReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.StateAtEventAndReference, error) {

View file

@ -533,6 +533,26 @@ func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.Ro
return
}
func (d *Database) RoomNIDForEventID(
ctx context.Context, eventID string,
) (out types.RoomNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.selectRoomNIDForEventID(ctx, txn, eventID)
return err
})
return
}
func (d *Database) RoomNIDForEventNID(
ctx context.Context, eventNID types.EventNID,
) (out types.RoomNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
out, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventNID)
return err
})
return
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,

View file

@ -38,6 +38,8 @@ type Database interface {
GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error)
GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error)
RoomNID(ctx context.Context, roomID string) (types.RoomNID, error)
RoomNIDForEventID(ctx context.Context, eventID string) (types.RoomNID, error)
RoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (types.RoomNID, error)
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error)
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error