Retrieve room version where known in roomserver

This commit is contained in:
Neil Alexander 2020-03-17 15:53:23 +00:00
parent a66c701b29
commit bc8735ebc6
8 changed files with 87 additions and 14 deletions

View file

@ -72,6 +72,10 @@ type RoomEventDatabase interface {
ctx context.Context, transactionID string,
sessionID int64, userID string,
) (string, error)
// Look up the room version for a given room.
GetRoomVersionForRoom(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error)
}
// OutputRoomEventWriter has the APIs needed to write an event to the output logs.

View file

@ -253,8 +253,10 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
latestEventIDs[i] = u.latest[i].EventID
}
// TODO: Room version here
roomVersion := gomatrixserverlib.RoomVersionV1
roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.roomNID)
if err != nil {
return nil, err
}
ore := api.OutputNewRoomEvent{
Event: u.event.Headered(roomVersion),

View file

@ -93,6 +93,10 @@ type RoomserverQueryAPIDatabase interface {
GetRoomVersionForRoom(
ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error)
// Look up the room NID that an event ID appears in.
GetRoomNIDForEventID(
ctx context.Context, eventID string,
) (types.RoomNID, error)
}
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
@ -234,8 +238,15 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
}
for _, event := range events {
// TODO: Room version here
roomVersion := gomatrixserverlib.RoomVersionV1
roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID())
if nerr != nil {
return nerr
}
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
}
@ -516,8 +527,15 @@ func (r *RoomserverQueryAPI) QueryMissingEvents(
response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
for _, event := range loadedEvents {
if !eventsToFilter[event.EventID()] {
// TODO: Room version here
roomVersion := gomatrixserverlib.RoomVersionV1
roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID())
if nerr != nil {
return nerr
}
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
}
@ -562,8 +580,15 @@ func (r *RoomserverQueryAPI) QueryBackfill(
}
for _, event := range loadedEvents {
// TODO: Room version here
roomVersion := gomatrixserverlib.RoomVersionV1
roomNID, nerr := r.DB.GetRoomNIDForEventID(ctx, event.EventID())
if nerr != nil {
return nerr
}
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
}
@ -653,6 +678,11 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain(
}
response.RoomExists = true
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomNID)
if err != nil {
return err
}
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil {
switch err.(type) {
@ -683,16 +713,10 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain(
}
for _, event := range stateEvents {
// TODO: Room version here
roomVersion := gomatrixserverlib.RoomVersionV1
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
}
for _, event := range authEvents {
// TODO: Room version here
roomVersion := gomatrixserverlib.RoomVersionV1
response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(roomVersion))
}

View file

@ -38,6 +38,7 @@ type Database interface {
GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error)
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
GetRoomIDForAlias(ctx context.Context, alias string) (string, error)
GetRoomNIDForEventID(ctx context.Context, eventID string) (types.RoomNID, error)
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
GetCreatorIDForAlias(ctx context.Context, alias string) (string, error)
RemoveRoomAlias(ctx context.Context, alias string) error

View file

@ -116,6 +116,9 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
const selectRoomNIDForEventIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_id = $1"
type eventStatements struct {
insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt
@ -130,6 +133,7 @@ type eventStatements struct {
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDForEventIDStmt *sql.Stmt
}
func (s *eventStatements) prepare(db *sql.DB) (err error) {
@ -152,6 +156,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL},
}.prepare(db)
}
@ -424,3 +429,11 @@ func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
}
return nids
}
func (s *eventStatements) selectRoomNIDForEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (roomNID types.RoomNID, err error) {
selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt)
err = selectStmt.QueryRowContext(ctx, eventID).Scan(&roomNID)
return
}

View file

@ -747,6 +747,14 @@ func (d *Database) GetRoomVersionForRoom(
)
}
func (d *Database) GetRoomNIDForEventID(
ctx context.Context, eventID string,
) (types.RoomNID, error) {
return d.statements.selectRoomNIDForEventID(
ctx, nil, eventID,
)
}
type transaction struct {
ctx context.Context
txn *sql.Tx

View file

@ -96,6 +96,9 @@ const bulkSelectEventNIDSQL = "" +
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
const selectRoomNIDForEventIDSQL = "" +
"SELECT room_nid FROM roomserver_events WHERE event_id = $1"
type eventStatements struct {
db *sql.DB
insertEventStmt *sql.Stmt
@ -111,6 +114,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
selectRoomNIDForEventIDStmt *sql.Stmt
}
func (s *eventStatements) prepare(db *sql.DB) (err error) {
@ -134,6 +138,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectRoomNIDForEventIDStmt, selectRoomNIDForEventIDSQL},
}.prepare(db)
}
@ -476,3 +481,11 @@ func eventNIDsAsArray(eventNIDs []types.EventNID) string {
b, _ := json.Marshal(eventNIDs)
return string(b)
}
func (s *eventStatements) selectRoomNIDForEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (roomNID types.RoomNID, err error) {
selectStmt := common.TxStmt(txn, s.selectRoomNIDForEventIDStmt)
err = selectStmt.QueryRowContext(ctx, eventID).Scan(&roomNID)
return
}

View file

@ -902,6 +902,14 @@ func (d *Database) GetRoomVersionForRoom(
)
}
func (d *Database) GetRoomNIDForEventID(
ctx context.Context, eventID string,
) (types.RoomNID, error) {
return d.statements.selectRoomNIDForEventID(
ctx, nil, eventID,
)
}
type transaction struct {
ctx context.Context
txn *sql.Tx